├── src ├── props.cpp ├── operators │ ├── input.cpp │ ├── random.cpp │ ├── special.cpp │ ├── optimized.cpp │ ├── linalg.cpp │ ├── constant.cpp │ ├── reduction.cpp │ ├── arithmetic.cpp │ ├── logical.cpp │ ├── shape.cpp │ ├── elementwise.cpp │ └── abstract.cpp ├── api │ ├── random.cpp │ ├── grad.cpp │ ├── optimized.cpp │ ├── debug.cpp │ ├── special.cpp │ └── elementwise.cpp ├── definitions.cpp ├── node.cpp ├── exceptions.cpp ├── print.cpp ├── os.cpp ├── helpers.cpp └── abstract_operator.cpp ├── include ├── optimization │ └── simple.h ├── optimization.h ├── helpers.h ├── export │ ├── protobuf.h │ ├── graphviz.h │ ├── cytoscape.h │ └── json.h ├── export.h ├── api │ ├── grad.h │ ├── index.h │ ├── debug.h │ ├── shape.h │ ├── special.h │ ├── elementwise.h │ ├── optimized.h │ ├── linalg.h │ └── input.h ├── operators.h ├── api.h ├── graph_ir.h ├── operators │ ├── random.h │ ├── input.h │ ├── debug.h │ ├── constant.h │ ├── optimized.h │ ├── logical.h │ └── arithmetic.h ├── print.h ├── os.h ├── shared.h ├── props.h ├── node.h ├── backend.h ├── definitions.h ├── utils.h └── enums.h ├── mock_backend ├── src │ ├── evaluator.cpp │ └── backend.cpp └── CMakeLists.txt ├── .gitignore ├── .gitmodules ├── artefacts └── default_promotion_table.csv ├── README.md ├── examples └── src │ └── test.cpp └── CMakeLists.txt /src/props.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 24/10/16. 3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ 9 | } 10 | } 11 | 12 | -------------------------------------------------------------------------------- /src/operators/input.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 06/10/16. 
3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ 9 | 10 | 11 | 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /include/optimization/simple.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 19/10/16. 3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_OPTIMIZATION_SIMPLE_H 6 | #define METADIFF_GRAPH_IR_OPTIMIZATION_SIMPLE_H 7 | 8 | #endif //METADIFF_GRAPH_IR_OPTIMIZATION_SIMPLE_H 9 | -------------------------------------------------------------------------------- /include/optimization.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 19/10/16. 3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_OPTIMIZATION_H 6 | #define METADIFF_GRAPH_IR_OPTIMIZATION_H 7 | 8 | #include "optimization/simple.h" 9 | 10 | #endif //METADIFF_GRAPH_IR_OPTIMIZATION_H 11 | -------------------------------------------------------------------------------- /include/helpers.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 06/10/16. 3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_HELPERS_H 6 | #define METADIFF_GRAPH_IR_HELPERS_H 7 | 8 | namespace md{ 9 | namespace gir{ 10 | 11 | } 12 | } 13 | #endif //METADIFF_GRAPH_IR_HELPERS_H 14 | -------------------------------------------------------------------------------- /include/export/protobuf.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 25/10/16.
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_PROTOBUF_H 6 | #define METADIFF_GRAPH_IR_PROTOBUF_H 7 | 8 | namespace md{ 9 | namespace protobuf{ 10 | // TODO 11 | } 12 | } 13 | #endif //METADIFF_GRAPH_IR_PROTOBUF_H 14 | -------------------------------------------------------------------------------- /include/export/graphviz.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 19/10/16. 3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_VISUAL_GRAPHVIZ_H 6 | #define METADIFF_GRAPH_IR_VISUAL_GRAPHVIZ_H 7 | 8 | namespace md{ 9 | namespace graphviz{ 10 | // TODO 11 | } 12 | } 13 | #endif //METADIFF_GRAPH_IR_VISUAL_GRAPHVIZ_H 14 | -------------------------------------------------------------------------------- /include/export.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 19/10/16. 3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_EXPORT_H 6 | #define METADIFF_GRAPH_IR_EXPORT_H 7 | 8 | #include "export/graphviz.h" 9 | #include "export/cytoscape.h" 10 | #include "export/json.h" 11 | #include "export/protobuf.h" 12 | 13 | #endif //METADIFF_GRAPH_IR_EXPORT_H 14 | -------------------------------------------------------------------------------- /mock_backend/src/evaluator.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 07/12/16. 3 | // 4 | 5 | #include "mock.h" 6 | 7 | namespace md { 8 | namespace backend { 9 | namespace mock { 10 | std::shared_ptr MockBackend::make_in_memory_function(GraphFunction const &gf) { 11 | // In-memory functions are produced by delegating to the source-generation path (no extra output). 12 | return make_source_gen_function(gf); 13 | } 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /include/api/grad.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 27/10/16.
3 | // 4 | 5 | #ifndef GRAPH_IR_API_GENERAL_H 6 | #define GRAPH_IR_API_GENERAL_H 7 | #include "type_traits" 8 | namespace md{ 9 | namespace api{ 10 | /** @brief Calculates the gradient of f with respect to the variables w 11 | * 12 | * @param f 13 | * @param w 14 | * @return 15 | */ 16 | NodeVec gradient(Node const f, NodeVec const & w); 17 | } 18 | } 19 | 20 | #endif //GRAPH_IR_API_GENERAL_H 21 | -------------------------------------------------------------------------------- /src/api/random.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 25/10/16. 3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ 9 | 10 | Node GraphInternal::random_uniform(Shape shape) { 11 | Operator op = std::make_shared(this, shape); 12 | return derived_node(op); 13 | } 14 | 15 | Node GraphInternal::random_normal(Shape shape) { 16 | Operator op = std::make_shared(this, shape); 17 | return derived_node(op); 18 | } 19 | } 20 | } 21 | 22 | -------------------------------------------------------------------------------- /src/operators/random.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 25/10/16. 
3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ // Graph-level factories for the random-valued operators declared in operators/random.h 9 | Node GraphInternal::random_uniform(Shape shape) { // New node of the given shape wrapping a random operator (presumably op::RandomUniform) 10 | Operator op = std::make_shared(this, shape); 11 | return derived_node(op); 12 | } 13 | 14 | Node GraphInternal::random_normal(Shape shape) { // Same pattern for normally distributed values (presumably op::RandomNormal) 15 | Operator op = std::make_shared(this, shape); 16 | return derived_node(op); 17 | } 18 | } 19 | } 20 | 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Precompiled Headers 8 | *.gch 9 | *.pch 10 | 11 | # Compiled Dynamic libraries 12 | *.so 13 | *.dylib 14 | *.dll 15 | 16 | # Fortran module files 17 | *.mod 18 | *.smod 19 | 20 | # Compiled Static libraries 21 | *.lai 22 | *.la 23 | *.a 24 | *.lib 25 | 26 | # Executables 27 | *.exe 28 | *.out 29 | *.app 30 | 31 | # IDE 32 | .idea/ 33 | # Runtime directory 34 | build/ 35 | # Docs 36 | doxygen/ 37 | # Cmake build 38 | cmake-build-default/ 39 | cmake-build-debug/ 40 | cmake-build-release/ 41 | 42 | -------------------------------------------------------------------------------- /include/api/index.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 24/11/16.
3 | // 4 | 5 | #ifndef GRAPH_IR_API_INDEX_H 6 | #define GRAPH_IR_API_INDEX_H 7 | 8 | namespace md{ 9 | namespace api{ 10 | 11 | Node slice(Node node, Axes axes, std::vector> slices); 12 | 13 | Node slice(Node node, int axis, SymInt start, SymInt end); 14 | 15 | Node index(Node node, Axes axes, std::vector indexes); 16 | 17 | Node index(Node node, int axis, Node index); 18 | 19 | Node cross_index(Node node, int axis, Node index); 20 | } 21 | } 22 | #endif //GRAPH_IR_API_INDEX_H 23 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/spdlog"] 2 | path = external/spdlog 3 | url = git@github.com:gabime/spdlog.git 4 | [submodule "external/rapidjson"] 5 | path = external/rapidjson 6 | url = git@github.com:miloyip/rapidjson.git 7 | [submodule "external/symbolic-integers"] 8 | path = external/symbolic-integers 9 | url = git@github.com:Metadiff/symbolic-integers.git 10 | [submodule "external/filesystem"] 11 | path = external/filesystem 12 | url = https://github.com/boostorg/filesystem/ 13 | [submodule "external/system"] 14 | path = external/system 15 | url = git@github.com:boostorg/system.git 16 | [submodule "external/fmt"] 17 | path = external/fmt 18 | url = git@github.com:fmtlib/fmt.git 19 | -------------------------------------------------------------------------------- /include/operators.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 30/09/16. 
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_OPERATORS_H 6 | #define METADIFF_GRAPH_IR_OPERATORS_H 7 | 8 | #include "operators/abstract.h" 9 | #include "operators/input.h" 10 | #include "operators/constant.h" 11 | #include "operators/special.h" 12 | #include "operators/shape.h" 13 | #include "operators/logical.h" 14 | #include "operators/arithmetic.h" 15 | #include "operators/reduction.h" 16 | #include "operators/elementwise.h" 17 | #include "operators/linalg.h" 18 | #include "operators/random.h" 19 | #include "operators/index.h" 20 | #include "operators/debug.h" 21 | #include "operators/optimized.h" 22 | 23 | 24 | #endif //METADIFF_GRAPH_IR_OPERATORS_H 25 | -------------------------------------------------------------------------------- /artefacts/default_promotion_table.csv: -------------------------------------------------------------------------------- 1 | X,b8,u8,u16,u32,u64,i8,i16,i32,i64,f8,f16,f32,f64 2 | b8,b8,u8,u16,u32,u64,i8,i16,i32,i64,f8,f16,f32,f64 3 | u8,u8,u8,u16,u32,u64,i8,i16,i32,i64,f8,f16,f32,f64 4 | u16,u16,u16,u16,u32,u64,i16,i16,i32,i64,f16,f16,f32,f64 5 | u32,u32,u32,u32,u32,u64,i32,i32,i32,i64,f32,f32,f32,f64 6 | u64,u64,u64,u64,u64,u64,i64,i64,i64,i64,f64,f64,f64,f64 7 | i8,i8,i8,i16,i32,i64,i8,i16,i32,i64,f8,f16,f32,f64 8 | i16,i16,i16,i16,i32,i64,i16,i16,i32,i64,f16,f16,f32,f64 9 | i32,i32,i32,i32,i32,i64,i32,i32,i32,i64,f32,f32,f32,f64 10 | i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,f64,f64,f64,f64 11 | f8,f8,f8,f16,f32,f64,f8,f16,f32,f64,f8,f16,f32,f64 12 | f16,f16,f16,f16,f32,f64,f16,f16,f32,f64,f16,f16,f32,f64 13 | f32,f32,f32,f32,f32,f64,f32,f32,f32,f64,f32,f32,f32,f64 14 | f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64 15 | -------------------------------------------------------------------------------- /include/export/cytoscape.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 19/10/16. 
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_VISUAL_CYTOSCAPE_H 6 | #define METADIFF_GRAPH_IR_VISUAL_CYTOSCAPE_H 7 | 8 | #include "set" 9 | 10 | namespace md{ 11 | using namespace gir; 12 | namespace cytoscape{ 13 | typedef std::vector> Edges; 14 | typedef std::set GroupSet; 15 | 16 | void export_graph(Graph g, std::ostream& s); 17 | 18 | std::pair export_nodes(Graph g, std::ostream& s); 19 | 20 | void export_edges(Edges& edges, std::ostream& s); 21 | 22 | void export_groups(Graph g, GroupSet& groups, std::ostream& s); 23 | 24 | void export_header(std::string name, std::ostream& s); 25 | 26 | void export_footer(std::string name, std::ostream& s); 27 | 28 | } 29 | } 30 | #endif //METADIFF_GRAPH_IR_EXPORT_CYTOSCAPEa_H 31 | -------------------------------------------------------------------------------- /src/api/grad.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 27/10/16. 3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace api{ 9 | NodeVec gradient(Node const f, NodeVec const & w) { 10 | // If no parameters return empty vector as well 11 | if(w.size() == 0){ 12 | return NodeVec(); 13 | } 14 | Graph g = f.g(); 15 | // Verify that the objective is a scalar 16 | if (f.order() != 0) { 17 | op_logger("Grad")->error("Requested gradient with respect to a non-scalar function."); 18 | throw InvalidOperatorArgument(NodeVec{f}, "Grad", "Requested gradient with respect to a non-scalar function."); 19 | } 20 | NodeVec u = {g->constant(1)}; 21 | u[0]->grad_level = f->grad_level + ((unsigned int)(1)); 22 | return g->backward_diff(NodeVec{f}, u, w); 23 | }; 24 | } 25 | } -------------------------------------------------------------------------------- /mock_backend/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) # Could be even lower. 
Works with 3.0 2 | 3 | # Project name and versions 4 | project(MockBackend) 5 | set(MockBackend_MAJOR_VERSION 0) 6 | set(MockBackend_MINOR_VERSION 0) 7 | set(MockBackend_PATCH_VERSION 1) 8 | set(MockBackend_VERSION ${MockBackend_MAJOR_VERSION}.${MockBackend_MINOR_VERSION}.${MockBackend_PATCH_VERSION}) 9 | 10 | # Include directory 11 | include_directories(include) 12 | 13 | # Source files 14 | set(MOCK_BACKEND_SOURCES 15 | ${PROJECT_SOURCE_DIR}/src/backend.cpp 16 | ${PROJECT_SOURCE_DIR}/src/source_gen.cpp 17 | ${PROJECT_SOURCE_DIR}/src/evaluator.cpp 18 | ) 19 | 20 | # Build library: shared by default, or follow the value of GRAPH_IR_SHARED when the parent project defines it 21 | if(DEFINED GRAPH_IR_SHARED) 22 | set(MOCK_BACKEND_SHARED ${GRAPH_IR_SHARED}) 23 | else() 24 | set(MOCK_BACKEND_SHARED 1) 25 | endif() 26 | if(MOCK_BACKEND_SHARED) 27 | add_library(mock_backend SHARED ${MOCK_BACKEND_SOURCES}) 28 | else() 29 | add_library(mock_backend STATIC ${MOCK_BACKEND_SOURCES}) 30 | endif() 31 | -------------------------------------------------------------------------------- /include/api.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 26/10/16.
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_API_H 6 | #define METADIFF_GRAPH_IR_API_H 7 | 8 | namespace md { 9 | namespace api { 10 | using namespace gir; // makes the gir names usable unqualified inside md::api 11 | /** Creates and returns a new, independent computation graph (presumably a fresh GraphInternal). */ 12 | inline Graph create_graph() { 13 | return std::make_shared(); 14 | } 15 | // Disabled sketch of a lazily created default graph. NOTE(review): the check-then-assign below is unsynchronized if re-enabled. 16 | // inline Graph default_graph() { 17 | // static std::shared_ptr graph; 18 | // if (not graph) { 19 | // graph = create_graph(); 20 | // } 21 | // return graph; 22 | // } 23 | } 24 | } 25 | 26 | #include "api/input.h" 27 | #include "api/constant.h" 28 | #include "api/special.h" 29 | #include "api/shape.h" 30 | #include "api/logical.h" 31 | #include "api/arithmetic.h" 32 | #include "api/reduction.h" 33 | #include "api/elementwise.h" 34 | #include "api/linalg.h" 35 | #include "api/index.h" 36 | #include "api/debug.h" 37 | #include "api/optimized.h" 38 | #include "api/grad.h" 39 | 40 | #endif //METADIFF_GRAPH_IR_API_H 41 | -------------------------------------------------------------------------------- /src/definitions.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 24/10/16.
3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir { 9 | 10 | 11 | // NodeGroup::NodeGroup(const std::string name, 12 | // const std::weak_ptr parent) : 13 | // name(name), 14 | // graph(parent.lock()->graph), 15 | // parent(parent) { 16 | // if(parent.lock()->is_base()){ 17 | // full_name = name; 18 | // } else { 19 | // full_name = parent.lock()->full_name; 20 | // full_name += graph->props.group_delimiter; 21 | // full_name += name; 22 | // } 23 | // std::replace(full_name.begin(), full_name.end(), ' ', '_'); 24 | // }; 25 | // 26 | // NodeGroup::NodeGroup(const std::string name, 27 | // const GraphInPtr graph): 28 | // name(name), graph(graph), full_name(name) {}; 29 | // 30 | // bool NodeGroup::is_base() const{ 31 | // return parent.expired(); 32 | // }; 33 | // } 34 | // gir::Graph const default_graph = std::make_shared(); 35 | } 36 | } -------------------------------------------------------------------------------- /include/graph_ir.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 29/09/16. 
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_H 6 | #define METADIFF_GRAPH_IR_H 7 | 8 | // External includes 9 | #include "type_traits" 10 | #include "string" 11 | #include "symbolic_integers.h" 12 | #include "spdlog/spdlog.h" 13 | #include "spdlog/sinks/dist_sink.h" 14 | #include "boost/config.hpp" 15 | #include "boost/filesystem.hpp" 16 | #include "spdlog/fmt/fmt.h" 17 | #include "dlfcn.h" // POSIX dynamic-loading header (dlopen/dlsym); not available on Windows 18 | // Pre-C++17 stand-ins mirroring the reference implementations of std::disjunction / std::conjunction. 19 | namespace md{ 20 | template struct disjunction : std::false_type { }; 21 | template struct disjunction : B1 { }; 22 | // disjunction: false_type for an empty pack, otherwise the first trait whose value is true (or the last trait). 23 | template 24 | struct disjunction 25 | : std::conditional> { }; 26 | // conjunction: true_type for an empty pack, otherwise the first trait whose value is false (or the last trait). 27 | template struct conjunction : std::true_type { }; 28 | template struct conjunction : B1 { }; 29 | template 30 | struct conjunction 31 | : std::conditional, B1> {}; 32 | } 33 | 34 | #include "enums.h" 35 | #include "definitions.h" 36 | #include "props.h" 37 | //#include "shared.h" 38 | #include "exceptions.h" 39 | #include "export.h" 40 | #include "node.h" 41 | #include "utils.h" 42 | #include "print.h" 43 | #include "graph.h" 44 | #include "api.h" 45 | #include "operators.h" 46 | #include "backend.h" 47 | //#include "mock.h" 48 | 49 | #endif //METADIFF_GRAPH_IR_H 50 | -------------------------------------------------------------------------------- /src/operators/special.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 18/10/16.
3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ 9 | Node GraphInternal::cast(Node node, DataType data_type) { 10 | // Already the requested data type: return an alias instead of adding a cast 11 | if(node->data_type == data_type){ 12 | return alias(node); 13 | } 14 | // Otherwise add an explicit cast operator to the target data type 15 | return derived_node(std::make_shared(this, node, data_type)); 16 | } 17 | 18 | Node GraphInternal::alias(Node node) { 19 | // The alias always refers to get_base_node(node), so aliases of aliases presumably collapse onto one base node 20 | return apply(get_base_node(node)); 21 | } 22 | 23 | Node GraphInternal::broadcast(Node node, Shape shape) { 24 | // Shape already matches: no broadcast needed, return an alias 25 | if(node->shape == shape){ 26 | return alias(node); 27 | } 28 | // Otherwise add an explicit broadcast operator to the target shape 29 | return derived_node(std::make_shared(this, node, shape)); 30 | } 31 | 32 | Node GraphInternal::make_constant(Node node){ 33 | // A non-differentiable node is already constant with respect to gradients: just alias it 34 | if(not node->is_differentiable){ 35 | return alias(node); 36 | } 37 | // Otherwise wrap it in an operator (presumably one that blocks gradient flow) 38 | return apply(node); 39 | } 40 | 41 | Node GraphInternal::select(Node condition, Node if_true, Node if_false){ 42 | // TODO: if if_true and if_false are the same node, return it directly 43 | return apply(condition, if_true, if_false); 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/operators/optimized.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 21/10/16.
3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | // TODO implement checks for all operators here; the factories below forward their inputs unvalidated 8 | 9 | namespace md{ 10 | namespace gir{ 11 | Node GraphInternal::softplus(Node node, double threshold){ 12 | Operator op = std::make_shared(this, node, threshold); 13 | return derived_node(op); 14 | } 15 | // Reduction over the given axes; threshold is forwarded to the operator unchanged 16 | Node GraphInternal::log_sum_exp(Node node, Axes axes, double threshold){ 17 | Operator op = std::make_shared(this, node, axes, threshold); 18 | return derived_node(op); 19 | } 20 | // Single-axis overload: axis == 100 is a sentinel that lets auto_infer_axes choose the axes 21 | Node GraphInternal::log_sum_exp(Node node, int axis, double threshold){ 22 | if(axis == 100){ // NOTE(review): magic number; presumably mirrors a default argument in the declaration - consider a named constant 23 | return log_sum_exp(node, auto_infer_axes(node), threshold); 24 | } else { 25 | return log_sum_exp(node, {axis}, threshold); 26 | } 27 | } 28 | 29 | Node GraphInternal::sigmoid(Node node){ 30 | return apply(node); 31 | } 32 | 33 | Node GraphInternal::softmax(Node node){ 34 | Operator op = std::make_shared(this, node); 35 | return derived_node(op); 36 | } 37 | // The *_logits variants take q as logits rather than probabilities (exact semantics live in the operator classes) 38 | Node GraphInternal::binary_cross_entropy_logits(Node p, Node q_logits){ 39 | Operator op = std::make_shared(this, p, q_logits); 40 | return derived_node(op); 41 | } 42 | 43 | Node GraphInternal::categorical_cross_entropy_logits(Node p, Node q_logits){ 44 | Operator op = std::make_shared(this, p, q_logits); 45 | return derived_node(op); 46 | } 47 | } 48 | } -------------------------------------------------------------------------------- /include/operators/random.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 25/10/16.
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_OPERATORS_RANDOM_H 6 | #define METADIFF_GRAPH_IR_OPERATORS_RANDOM_H 7 | 8 | namespace md{ 9 | namespace gir{ 10 | namespace op { 11 | /** Node filled with uniformly distributed random numbers. The element type is FLOAT at the graph's max_float precision; the range of the distribution is not specified here. */ 12 | class RandomUniform : public ConstantOperator { 13 | public: 14 | Shape shape; 15 | // AbstractOperator is initialised directly below, which is only legal because it is a virtual base of ConstantOperator 16 | RandomUniform(GraphInPtr graph, Shape shape) : 17 | AbstractOperator(graph, "RandomUniform"), 18 | ConstantOperator(DataType(FLOAT, graph->props.max_float)), 19 | shape(shape) {}; 20 | // Random operators have no inputs, so ancestors is unused when copying into another graph 21 | Operator copy_to(GraphInPtr graph, NodeVec ancestors) const { 22 | return std::make_shared(graph, shape); 23 | } 24 | // The output shape is exactly the one requested at construction 25 | Shape get_shape() const { 26 | return shape; 27 | } 28 | 29 | }; 30 | 31 | /** Node filled with normally distributed random numbers. Same element type, construction and copy semantics as RandomUniform; the distribution parameters are not specified here. */ 32 | class RandomNormal : public ConstantOperator { 33 | public: 34 | Shape shape; 35 | 36 | RandomNormal(GraphInPtr graph, Shape shape) : 37 | AbstractOperator(graph, "RandomNormal"), ConstantOperator(DataType(FLOAT, graph->props.max_float)), 38 | shape(shape) {}; 39 | // No inputs: ancestors is unused 40 | Operator copy_to(GraphInPtr graph, NodeVec ancestors) const { 41 | return std::make_shared(graph, shape); 42 | } 43 | 44 | Shape get_shape() const { 45 | return shape; 46 | } 47 | 48 | }; 49 | } 50 | } 51 | } 52 | #endif //METADIFF_GRAPH_IR_OPERATORS_RANDOM_H 53 | -------------------------------------------------------------------------------- /include/print.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 30/09/16.
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_PRINT_H 6 | #define METADIFF_GRAPH_IR_PRINT_H 7 | // Text formatting for the core types: each to_string overload (defined elsewhere) is paired with a stream operator that forwards to it. 8 | namespace md{ 9 | namespace gir{ 10 | std::string to_string(Precision const precision); 11 | // Every operator<< in this header returns the stream it was given, so insertions can be chained. 12 | inline std::ostream &operator<<(std::ostream &f, Precision const precision) { 13 | return f << to_string(precision); 14 | } 15 | 16 | std::string to_string(DataType const data_type); 17 | 18 | inline std::ostream &operator<<(std::ostream &f, DataType const data_type) { 19 | return f << to_string(data_type); 20 | } 21 | 22 | std::string to_string(DeviceType const device_type); 23 | 24 | inline std::ostream &operator<<(std::ostream &f, DeviceType const device_type) { 25 | return f << to_string(device_type); 26 | } 27 | 28 | std::string to_string(Policy const policy); 29 | 30 | inline std::ostream &operator<<(std::ostream &f, Policy const policy) { 31 | return f << to_string(policy); 32 | } 33 | 34 | std::string to_string(Shape const shape); 35 | 36 | inline std::ostream &operator<<(std::ostream &f, Shape const shape) { 37 | return f << to_string(shape); 38 | } 39 | 40 | std::string to_string(Device const device); 41 | 42 | inline std::ostream &operator<<(std::ostream &f, Device const device) { 43 | return f << to_string(device); 44 | } 45 | 46 | std::string to_string(Node const node); 47 | 48 | inline std::ostream &operator<<(std::ostream &f, Node const node) { 49 | return f << to_string(node); 50 | } 51 | 52 | std::string to_string(NodeVec const & nodes); 53 | 54 | inline std::ostream &operator<<(std::ostream &f, NodeVec const & nodes) { 55 | return f << to_string(nodes); 56 | } 57 | } 58 | } 59 | #endif //METADIFF_GRAPH_IR_PRINT_H 60 | -------------------------------------------------------------------------------- /include/export/json.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 25/10/16.
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_JSON_H 6 | #define METADIFF_GRAPH_IR_JSON_H 7 | // RAPIDJSON_HAS_STDSTRING must be defined before the first RapidJSON include to enable std::string support 8 | #define RAPIDJSON_HAS_STDSTRING 1 9 | #include "rapidjson/rapidjson.h" 10 | #include "rapidjson/document.h" 11 | #include "rapidjson/writer.h" 12 | #include "rapidjson/prettywriter.h" 13 | 14 | namespace md{ 15 | namespace json{ 16 | using namespace rapidjson; 17 | // NOTE(review): this using-directive in a header makes every RapidJSON name visible inside md::json for all includers 18 | /** Exports the graph to a stream in JSON format */ 19 | void export_graph(Graph const g, std::ostream& s); 20 | 21 | /** Exports a graph to a RapidJSON writer */ 22 | void export_graph(Graph const g, PrettyWriter& writer); 23 | 24 | /** Exports a Properties object to a RapidJSON writer */ 25 | void export_props(Properties const properties, PrettyWriter& writer); 26 | 27 | /** Exports a GraphPolicies object to a RapidJSON writer */ 28 | void export_policies(GraphPolicies const policies, PrettyWriter& writer); 29 | 30 | /** Exports a Shape object to a RapidJSON writer */ 31 | void export_shape(Shape const shape, PrettyWriter& writer); 32 | 33 | /** Exports a SymInt object to a RapidJSON writer */ 34 | void export_sym_int(SymInt const sym_int, PrettyWriter& writer); 35 | 36 | /** Exports an ExecutionData object to a RapidJSON writer */ 37 | void export_execution_data(ExecutionData execution, PrettyWriter& writer); 38 | 39 | /** Exports a vector of NodeData objects to a RapidJSON writer */ 40 | void export_nodes(std::vector> const & nodes, 41 | PrettyWriter& writer); 42 | 43 | /** Exports an Operator object to a RapidJSON writer */ 44 | void export_op(Operator const op, PrettyWriter& writer); 45 | } 46 | } 47 | 48 | #endif //METADIFF_GRAPH_IR_JSON_H 49 | -------------------------------------------------------------------------------- /include/os.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 02/11/16.
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_OS_H 6 | #define METADIFF_GRAPH_IR_OS_H 7 | 8 | namespace md{ 9 | namespace os{ 10 | /** Helper for writing to files: pairs a file name with its C stream (presumably passed as the user data of write_steram_to_file) */ 11 | struct FtpFile { 12 | const char *filename; 13 | FILE *stream; 14 | }; 15 | 16 | /** Path separator used for the file system: a backslash when _WIN32 is defined, a forward slash otherwise */ 17 | const char kPathSeparator = 18 | #ifdef _WIN32 19 | '\\'; 20 | #else 21 | '/'; 22 | #endif 23 | 24 | /** Checks whether the file specified by the path exists on the file system */ 25 | bool exists(std::string path); 26 | 27 | /** Returns false if the path does not exist or if it is not a directory 28 | * TODO - make this cross-platform */ 29 | bool is_dir(std::string path); 30 | 31 | /** Helper function to create directories 32 | * If check is true the directory is created only if it does not exist */ 33 | void create_dir(std::string path, bool check = false); 34 | 35 | /** Creates a temporary directory and returns its path */ 36 | std::string make_temp_dir(); 37 | 38 | /** Joins a list of path components into a single path **/ 39 | std::string join_paths(std::vector paths); 40 | 41 | /** Joins two path components into a single path **/ 42 | std::string join_paths(std::string path1, std::string path2); 43 | 44 | /** Joins three path components into a single path **/ 45 | std::string join_paths(std::string path1, std::string path2, std::string path3); 46 | 47 | /** Writes a buffer to a stream; the (buffer, size, nmemb, stream) signature looks like a libcurl-style write callback - presumably used for downloads. The misspelled name is kept for compatibility. */ 48 | size_t write_steram_to_file(void *buffer, size_t size, size_t nmemb, void *stream); 49 | 50 | /** Returns the size of a file in bytes */ 51 | long long file_size(std::string file_name); 52 | 53 | /** Decompresses a gzip (.gz) file; returns an int status code - NOTE(review): confirm its meaning in the implementation */ 54 | int unpack_gz(std::string gz_path); 55 | 56 | // /** A helper function to download a file from a url */ 57 | // void download_file(std::string url, std::string local_path); 58 | } 59 | } 60 | #endif //METADIFF_GRAPH_IR_OS_H 61 | -------------------------------------------------------------------------------- /include/api/debug.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on
31/10/16. 3 | // 4 | 5 | #ifndef GRAPH_IR_API_DEBUG_H 6 | #define GRAPH_IR_API_DEBUG_H 7 | 8 | namespace md{ 9 | namespace api{ 10 | /** @brief Prints out the values of the monitored node exactly after the calculation of the anchor. 11 | * If the anchor node is empty then the monitored node is used as the anchor. 12 | * Note that this prevents optimization of the monitored node. 13 | * 14 | * @param monitored node whose values are printed; msg is a message passed along with them (presumably printed as a label) 15 | * @param anchor node after whose calculation the print happens; defaults to an empty Node, meaning monitored 16 | * @return a new node (presumably the monitored value passed through; TODO confirm) 17 | */ 18 | Node print(Node monitored, std::string msg, Node anchor = Node()); 19 | 20 | /** @brief Retrieves the monitored value together with the outputs at the end of the function execution 21 | * If the anchor node is empty then the monitored node is used as the anchor. 22 | * Note that this prevents optimization of the monitored node. 23 | * 24 | * @param monitored node whose value is retrieved; msg is a message passed along with it (presumably a label) 25 | * @param anchor node the retrieval is attached to; defaults to an empty Node, meaning monitored 26 | * @return a new node (presumably the monitored value passed through; TODO confirm) 27 | */ 28 | Node retrieve(Node monitored, std::string msg, Node anchor = Node()); 29 | 30 | /** @brief Logs to a file the values of the monitored node exactly after the calculation of the anchor. 31 | * If the anchor node is empty then the monitored node is used as the anchor. 32 | * Note that this prevents optimization of the monitored node. 33 | * 34 | * @param monitored node whose values are logged; msg is a message passed along with them (presumably a label) 35 | * @param anchor node after whose calculation the logging happens; defaults to an empty Node, meaning monitored 36 | * @return a new node (presumably the monitored value passed through; TODO confirm) 37 | */ 38 | Node log_to_file(Node monitored, std::string msg, Node anchor = Node()); 39 | 40 | /** @brief Guards that the values of the monitored node do not go outside the interval [low, high] or become NaN or Inf. 41 | * The check is executed exactly after the calculation of the anchor. 42 | * If the anchor node is empty then the monitored node is used as the anchor. 43 | * Note that this prevents optimization of the monitored node.
44 | * 45 | * @param monitored node whose values are checked; low and high bound the allowed interval, and msg is a message passed along with the check 46 | * @param anchor node after whose calculation the check runs; defaults to an empty Node, meaning monitored 47 | * @return a new node (presumably the monitored value passed through; TODO confirm) 48 | */ 49 | Node guard(Node monitored, std::string msg, double low, double high, Node anchor = Node()); 50 | } 51 | } 52 | #endif //GRAPH_IR_API_DEBUG_H 53 | -------------------------------------------------------------------------------- /src/operators/linalg.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 19/10/16. 3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ 9 | 10 | Node GraphInternal::gemm(NodeVec nodes, std::vector transpositions){ 11 | // TODO: cancel neighbouring factors that are inverses of each other 12 | for(auto & node : nodes){ // replace each operand (a by-value copy) with its base node; range-for avoids a signed/unsigned index comparison 13 | node = get_base_node(node); 14 | } 15 | // Standard 16 | auto op = std::make_shared(this, nodes, transpositions); 17 | return derived_node(op); 18 | } 19 | 20 | Node GraphInternal::dot(Node left, Node right, bool transpose_left, bool transpose_right){ 21 | return gemm({left, right}, {transpose_left, transpose_right}); 22 | } 23 | 24 | Node GraphInternal::matrix_inverse(Node node){ 25 | // inv(inv(x)) = x 26 | auto base = get_base_node(node); 27 | if(base->op->name == "MatrixInv"){ 28 | return api::alias(base->op->get_parents()[0]); 29 | } 30 | // Standard 31 | return apply(node); 32 | } 33 | 34 | Node GraphInternal::matrix_inverse_mul(Node node1, Node node2, bool transpose){ 35 | // No simplification attempted; a single operator presumably computes inv(node1) * node2 with transpose forwarded as-is 36 | Operator op = std::make_shared(this, node1, node2, transpose); 37 | return derived_node(op); 38 | } 39 | 40 | Node GraphInternal::determinant(Node node){ 41 | // If scalar do nothing 42 | if(node.order() == 0){ 43 | return api::alias(node); 44 | } 45 | // Standard 46 | return apply(node); 47 | } 48 | 49 | Node GraphInternal::log_det(Node node){ 50 | // If scalar return just the log 51 | if(node.order() == 0){ 52 | return log(node); 53 | } 54 | // Standard 55 | return apply(node); 56 | } 57 | 58 | Node GraphInternal::trace(Node node){ 59 | // If
scalar do nothing 60 | if(node.order() == 0){ 61 | return api::alias(node); 62 | } 63 | // Standard 64 | return apply(node); 65 | } 66 | } 67 | } -------------------------------------------------------------------------------- /src/operators/constant.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 06/10/16.
range(start, end, props.max_int); 58 | } 59 | 60 | Node GraphInternal::eye(SymInt size, DataType data_type){ 61 | auto op = std::make_shared(this, size, data_type); 62 | return derived_node(op); 63 | } 64 | 65 | Node GraphInternal::eye(SymInt size){ 66 | return eye(size, props.max_float); 67 | } 68 | } 69 | } -------------------------------------------------------------------------------- /include/shared.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 30/09/16. 3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_SHARED_H 6 | #define METADIFF_GRAPH_IR_SHARED_H 7 | 8 | namespace md{ 9 | namespace gir{ 10 | /** A class for initializing all shared variables */ 11 | class Initializer{ 12 | public: 13 | virtual Node initialize(std::array const real_shape) = 0; 14 | }; 15 | 16 | typedef std::shared_ptr Init; 17 | 18 | /** A shared variable is a like a static variable, which is synchronized between devices */ 19 | class SharedVariable { 20 | private: 21 | /** A pointer to the buffer on the host side */ 22 | void * memory_ptr; 23 | /** The real shape after initialization */ 24 | std::array real_shape; 25 | /** Indicator if the variable has been initialized */ 26 | bool initialized = false; 27 | public: 28 | size_t const id; 29 | DataType const data_type; 30 | Shape const shape; 31 | std::string const name; 32 | // Init init; 33 | public: 34 | SharedVariable(size_t const id, DataType const data_type, 35 | Shape const shape, std::string const name): 36 | id(id), 37 | data_type(data_type), 38 | shape(shape), 39 | name(name) {}; 40 | 41 | // void initialize(Init init = Init()); 42 | }; 43 | 44 | typedef std::shared_ptr SharedVar; 45 | 46 | inline std::shared_ptr> get_all_shared(){ 47 | static std::shared_ptr> shared_vars; 48 | if(not shared_vars){ 49 | shared_vars = std::make_shared>(); 50 | } 51 | return shared_vars; 52 | } 53 | 54 | inline SharedVar get_shared(size_t const id){ 55 | return get_all_shared()->at(id); 56 
| } 57 | 58 | inline SharedVar make_shared(DataType data_type, Shape shape, std::string name){ 59 | SharedVar var = std::make_shared(get_all_shared()->size(), data_type, shape, name); 60 | get_all_shared()->push_back(var); 61 | return get_all_shared()->back(); 62 | } 63 | } 64 | } 65 | 66 | #endif //METADIFF_GRAPH_IR_SHARED_H 67 | -------------------------------------------------------------------------------- /src/operators/reduction.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 19/10/16. 3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ 9 | Node GraphInternal::sum(Node node, Axes axes){ 10 | return op::apply_reduction(this, node, axes); 11 | } 12 | 13 | Node GraphInternal::sum(Node node, int axis){ 14 | auto a = auto_infer_axes(node); 15 | if(axis == auto_infer){ 16 | return op::apply_reduction(this, node, auto_infer_axes(node)); 17 | } 18 | return op::apply_reduction(this, node, axis); 19 | } 20 | 21 | Node GraphInternal::mean(Node node, Axes axes){ 22 | return op::apply_reduction(this, node, axes); 23 | } 24 | 25 | Node GraphInternal::mean(Node node, int axis){ 26 | if(axis == auto_infer){ 27 | return op::apply_reduction(this, node, auto_infer_axes(node)); 28 | } 29 | return op::apply_reduction(this, node, axis); 30 | } 31 | 32 | Node GraphInternal::prod(Node node, Axes axes){ 33 | return op::apply_reduction(this, node, axes); 34 | } 35 | 36 | Node GraphInternal::prod(Node node, int axis){ 37 | if(axis == auto_infer){ 38 | return op::apply_reduction(this, node, auto_infer_axes(node)); 39 | } 40 | return op::apply_reduction(this, node, axis); 41 | } 42 | 43 | Node GraphInternal::all_true(Node node, Axes axes){ 44 | return op::apply_reduction(this, node, axes); 45 | } 46 | 47 | Node GraphInternal::all_true(Node node, int axis){ 48 | if(axis == auto_infer){ 49 | return op::apply_reduction(this, node, auto_infer_axes(node)); 50 | } 51 | return 
op::apply_reduction(this, node, axis); 52 | } 53 | 54 | Node GraphInternal::any_true(Node node, Axes axes){ 55 | return op::apply_reduction(this, node, axes); 56 | } 57 | 58 | Node GraphInternal::any_true(Node node, int axis){ 59 | if(axis == auto_infer){ 60 | return op::apply_reduction(this, node, auto_infer_axes(node)); 61 | } 62 | return op::apply_reduction(this, node, axis); 63 | } 64 | } 65 | } 66 | 67 | -------------------------------------------------------------------------------- /include/operators/input.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 03/05/16. 3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_OPERATORS_INPUT_H 6 | #define METADIFF_GRAPH_IR_OPERATORS_INPUT_H 7 | 8 | namespace md { 9 | namespace gir { 10 | namespace op { 11 | /** Input variables */ 12 | class Input : public InputOperator { 13 | public: 14 | DataType data_type; 15 | Shape shape; 16 | 17 | Input(GraphInPtr graph, DataType data_type, Shape shape) : 18 | AbstractOperator(graph, "Input"), 19 | data_type(data_type), shape(shape) {} 20 | 21 | Operator copy_to(GraphInPtr graph, std::vector ancestors) const { 22 | return std::make_shared(graph, data_type, shape); 23 | } 24 | 25 | DataType get_data_type() const { 26 | return data_type; 27 | } 28 | 29 | Shape get_shape() const { 30 | return shape; 31 | } 32 | }; 33 | 34 | /** Parameter variables */ 35 | class Parameter : public InputOperator { 36 | public: 37 | std::string full_name; 38 | DataType data_type; 39 | Shape shape; 40 | 41 | 42 | Parameter(GraphInPtr graph, std::string full_name, DataType data_type, Shape shape) : 43 | AbstractOperator(graph, "Parameter"), 44 | full_name(full_name), data_type(data_type), shape(shape) {} 45 | 46 | Operator copy_to(GraphInPtr graph, NodeVec ancestors) const { 47 | return std::make_shared(graph, full_name, data_type, shape); 48 | } 49 | 50 | DataType get_data_type() const { 51 | return data_type; 52 | } 53 | 54 | Shape get_shape() 
const { 55 | return shape; 56 | } 57 | 58 | // bool equals(Operator const op) const { 59 | // if (name == op->name) { 60 | // auto cast_op = std::static_pointer_cast(op); 61 | // return var->id == cast_op->var->id; 62 | // } 63 | // return false; 64 | // } 65 | }; 66 | } 67 | } 68 | } 69 | #endif //METADIFF_GRAPH_IR_OPERATORS_H 70 | -------------------------------------------------------------------------------- /src/operators/arithmetic.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 04/10/16. 3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ 9 | 10 | Node GraphInternal::add(NodeVec nodes){ 11 | // TODO check for redundancies like x + (-x) 12 | return apply(nodes); 13 | } 14 | 15 | Node GraphInternal::add(Node node1, Node node2){ 16 | return add({node1, node2}); 17 | } 18 | 19 | Node GraphInternal::add(Node node1, Node node2, Node node3){ 20 | return add({node1, node2, node3}); 21 | } 22 | 23 | Node GraphInternal::add(Node node1, Node node2, Node node3, Node node4){ 24 | return add({node1, node2, node3, node4}); 25 | } 26 | 27 | Node GraphInternal::neg(Node node){ 28 | auto base = get_base_node(node); 29 | // The -(-x) = x 30 | if(base->op->name == "Neg"){ 31 | return api::alias(base->op->get_parents()[0]); 32 | } 33 | // Standard 34 | return apply(node); 35 | } 36 | 37 | Node GraphInternal::neg(Node node1, Node node2){ 38 | return add({node1, neg(node2)}); 39 | } 40 | 41 | Node GraphInternal::mul(NodeVec nodes){ 42 | // TODO check for redundancies like x * (1/x) 43 | return apply(nodes); 44 | } 45 | 46 | Node GraphInternal::mul(Node node1, Node node2){ 47 | return mul({node1, node2}); 48 | } 49 | 50 | Node GraphInternal::mul(Node node1, Node node2, Node node3){ 51 | return mul({node1, node2, node3}); 52 | } 53 | 54 | /** Multiplies nodes */ 55 | Node GraphInternal::mul(Node node1, Node node2, Node node3, Node node4){ 56 | return mul({node1, node2, node3, node4}); 57 
| } 58 | 59 | Node GraphInternal::div(Node node){ 60 | auto base = get_base_node(node); 61 | // The -(-x) = x 62 | if(base->op->name == "Div"){ 63 | return api::alias(base->op->get_parents()[0]); 64 | } 65 | // Standard 66 | return apply(node); 67 | } 68 | 69 | Node GraphInternal::div(Node node1, Node node2){ 70 | return mul({node1, div(node2)}); 71 | } 72 | 73 | Node GraphInternal::int_div(Node node1, Node node2){ 74 | return apply(node1, node2); 75 | } 76 | 77 | Node GraphInternal::int_mod(Node node1, Node node2){ 78 | return apply(node1, node2); 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /include/api/shape.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 27/10/16. 3 | // 4 | 5 | #ifndef GRAPH_IR_API_SHAPE_H 6 | #define GRAPH_IR_API_SHAPE_H 7 | 8 | namespace md{ 9 | namespace api{ 10 | /** @brief Either extracts the diagonal of the input matrix, 11 | * or generates a matrix whose diagonal is equal to the input vector. 
12 | * 13 | * @param node A matrix or a vector 14 | * @param g 15 | * @return 16 | */ 17 | Node diag(Node node); 18 | 19 | /** @brief Takes the lower triangular part of a matrix 20 | * if k is an integer takes the lower traingular bit excluding the k major diagonals 21 | * if k < 1 multiplies the diagonal with k 22 | * k can not be less then 0 23 | * 24 | * @param node 25 | * @param k 26 | * @return 27 | */ 28 | Node lower_tri(Node node, double k = 0); 29 | 30 | /** @brief Takes the upper triangular part of a matrix 31 | * if k > 1 it must be an integer takes the upper traingular bit excluding the k major diagonals 32 | * if k < 1 multiplies the diagonal with k 33 | * k can not be less then 0 34 | * @param node 35 | * @param k 36 | * @return 37 | */ 38 | Node upper_tri(Node node, double k = 0); 39 | 40 | /** @brief Reshapes the tensor to a specified shape 41 | * 42 | * @param node 43 | * @param shape 44 | * @param g 45 | * @return 46 | */ 47 | Node reshape(Node node, Shape shape); 48 | 49 | /** @brief Reshapes the tensor to a vector 50 | * 51 | * @param node 52 | * @param shape 53 | * @param g 54 | * @return 55 | */ 56 | Node flatten(Node node); 57 | 58 | /** @brief Reorders the dimensions of the tensor as specified 59 | * 60 | * @param node 61 | * @param axes 62 | * @param g 63 | * @return 64 | */ 65 | Node reorder(Node node, Axes order); 66 | 67 | /** @brief Takes the transpose of a tesor. For 3D and 4D tensor switches the last two dimensions. 
68 | * 69 | * @param node 70 | * @param g 71 | * @return 72 | */ 73 | Node transpose(Node node); 74 | 75 | /** @brief Flips the elements of the tensor along every of the specified axes 76 | * 77 | * @param node 78 | * @param axes 79 | * @return 80 | */ 81 | Node flip(Node node, Axes axes); 82 | 83 | /** @brief Flips the elements of the tensor along the specified axis 84 | * 85 | * @param node 86 | * @param axis 87 | * @return 88 | */ 89 | Node flip(Node node, int axis); 90 | } 91 | } 92 | #endif //GRAPH_IR_API_SHAPE_H 93 | -------------------------------------------------------------------------------- /include/operators/debug.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 21/10/16. 3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_OPERATORS_DEBUG_H 6 | #define METADIFF_GRAPH_IR_OPERATORS_DEBUG_H 7 | 8 | namespace md{ 9 | namespace gir{ 10 | namespace op{ 11 | /** The operator will print the monitored node at runtime, exactly after the execution */ 12 | class Print: public MonitorOperator{ 13 | public: 14 | Print(GraphInPtr const graph, Node anchor, Node monitored, std::string msg): 15 | AbstractOperator(graph, "Print"), UnaryOperator(anchor), MonitorOperator(monitored, msg) {}; 16 | 17 | Operator copy_to(GraphInPtr graph, NodeVec ancestors) const { 18 | return std::make_shared(graph, ancestors[0], ancestors[1], msg); 19 | } 20 | }; 21 | 22 | /** The operator will return the monitored node at runtime */ 23 | class Retrieve: public MonitorOperator{ 24 | public: 25 | Retrieve(GraphInPtr const graph, Node anchor, Node monitored, std::string msg): 26 | AbstractOperator(graph, "Retrieve"), UnaryOperator(anchor), MonitorOperator(monitored, msg) {}; 27 | 28 | Operator copy_to(GraphInPtr graph, NodeVec ancestors) const { 29 | return std::make_shared(graph, ancestors[0], ancestors[1], msg); 30 | } 31 | }; 32 | 33 | /** The operator will print the monitored node at runtime, exactly after the execution */ 34 | class 
LogToFile: public MonitorOperator{ 35 | public: 36 | LogToFile(GraphInPtr const graph, Node anchor, Node monitored, std::string msg): 37 | AbstractOperator(graph, "LogToFile"), UnaryOperator(anchor), MonitorOperator(monitored, msg) {}; 38 | 39 | Operator copy_to(GraphInPtr graph, NodeVec ancestors) const { 40 | return std::make_shared(graph, ancestors[0], ancestors[1], msg); 41 | } 42 | }; 43 | 44 | /** The operator will guard that the monitored value that it does not go out of bounds including nans and infs */ 45 | class Guard: public MonitorOperator{ 46 | public: 47 | double low; 48 | double high; 49 | Guard(GraphInPtr const graph, Node anchor, Node monitored, std::string msg, double low, double high): 50 | AbstractOperator(graph, "Guard"), UnaryOperator(anchor), MonitorOperator(monitored, msg), 51 | low(low), high(high){}; 52 | 53 | Operator copy_to(GraphInPtr graph, NodeVec ancestors) const { 54 | return std::make_shared(graph, ancestors[0], ancestors[1], msg, low, high); 55 | } 56 | }; 57 | } 58 | } 59 | } 60 | #endif //METADIFF_GRAPH_IR_OPERATORS_DEBUG_H 61 | -------------------------------------------------------------------------------- /include/props.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 29/09/16. 
3 | // 4 | 5 | #ifndef METADIFF_GRAPH_IR_PROPS_H 6 | #define METADIFF_GRAPH_IR_PROPS_H 7 | namespace md{ 8 | namespace gir{ 9 | /** A container class for all user specified properties, which affect the response of the API */ 10 | class Properties{ 11 | public: 12 | /** A Http Proxy, if the API would need internet access this can be handy */ 13 | std::string http_proxy; 14 | /** The delimiter used for combining group names */ 15 | std::string scope_delimiter; 16 | /** The default device */ 17 | Device default_device; 18 | /** The maixmum allowed floating numbers precision */ 19 | Precision max_float; 20 | /** The maixmum allowed integer numbers precision */ 21 | Precision max_int; 22 | /** The collection of policies for reaction to errors */ 23 | GraphPolicies policies; 24 | /** Currently not used */ 25 | // DataType promotion_table[13][13]; 26 | /** Default working directory */ 27 | std::string default_work_dir; 28 | 29 | Properties(): 30 | http_proxy("HTTP_PROXY"), 31 | scope_delimiter("::"), 32 | default_device(HOST), 33 | max_float(p32), 34 | max_int(p32), 35 | policies(WARN, WARN, WARN), 36 | default_work_dir("."){ 37 | // for(auto i=0; i<13; ++i){ 38 | // for(auto j=0; j<13; ++j){ 39 | // promotion_table[i][j] = default_promotion_table[i][j]; 40 | // } 41 | // } 42 | }; 43 | 44 | Properties(std::shared_ptr ptr): 45 | http_proxy(ptr->http_proxy), 46 | scope_delimiter(ptr->scope_delimiter), 47 | default_device(ptr->default_device), 48 | max_float(ptr->max_float), 49 | max_int(ptr->max_int), 50 | policies(ptr->policies) { 51 | // for(auto i=0; i<13; ++i){ 52 | // for(auto j=0; j<13; ++j){ 53 | // promotion_table[i][j] = ptr->promotion_table[i][j]; 54 | // } 55 | // } 56 | } 57 | }; 58 | 59 | /** @brief The default properties are defined by the configuration files and environmental flags on the system 60 | * 61 | * @return 62 | */ 63 | inline std::shared_ptr default_properties(){ 64 | static std::shared_ptr props; 65 | if(not props){ 66 | // TODO Load this from a 
file 67 | props = std::make_shared(); 68 | } 69 | return props; 70 | } 71 | } 72 | } 73 | #endif //METADIFF_GRAPH_IR_PROPS_H 74 | -------------------------------------------------------------------------------- /include/api/special.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 27/10/16. 3 | // 4 | 5 | #ifndef GRAPH_IR_API_SPECIAL_H 6 | #define GRAPH_IR_API_SPECIAL_H 7 | 8 | 9 | namespace md{ 10 | namespace api{ 11 | /** @brief Casts the input tensor to a different DataType 12 | * 13 | * @param node 14 | * @param data_type 15 | * @return 16 | */ 17 | Node cast(Node node, DataType data_type); 18 | 19 | /** @brief Boradcasts the input tensor to a specified Shape 20 | * 21 | * @param node 22 | * @param shape 23 | * @throw exception if the new Shape is incosistant with the current one 24 | * @return 25 | */ 26 | Node broadcast(Node node, Shape shape); 27 | 28 | /** @brief Makes a new tensor which is an alias of the input 29 | * 30 | * @param node 31 | * @return 32 | */ 33 | Node alias(Node node); 34 | 35 | /** @brief Makes a new tensor which is a view of the input, but is non-diffentiable 36 | * 37 | * @param node 38 | * @return 39 | */ 40 | Node make_constant(Node node); 41 | 42 | /** @brief Adds to the graph an update of the shared node provided to the update 43 | * 44 | * @param shared 45 | * @param update 46 | */ 47 | void update(Node shared, Node update); 48 | 49 | /** @brief Makes a selection elementwise between two tensors based on the condition 50 | * 51 | * Formally R = select(C, A, B), then 52 | * R[i,j,k,l] = A[i, j, k, l] if C[i, j, k, l] = 1 else B[i, j, k, l] 53 | * The condition argument is expected to be a boolean, if not it will be casted to such. 
54 | * 55 | * @param condition 56 | * @param if_true 57 | * @param if_false 58 | * @return 59 | */ 60 | Node select(Node condition, Node if_true, Node if_false); 61 | 62 | template , std::is_same>::value>> 64 | Node select(Node condition, T if_true, F if_false){ 65 | Graph g = condition.g(); 66 | return select(condition, wrap(if_true, g), wrap(if_false, g)); 67 | }; 68 | 69 | template , std::is_same>::value>> 71 | Node select(C condition, Node if_true, F if_false){ 72 | Graph g = if_true.g(); 73 | return select(wrap(condition, g), if_true, wrap(if_false, g)); 74 | }; 75 | 76 | template , std::is_same>::value>> 78 | Node select(C condition, T if_true, Node if_false){ 79 | Graph g = if_false.g(); 80 | return select(wrap(condition, g), wrap(if_true, g), if_false); 81 | }; 82 | } 83 | } 84 | #endif //GRAPH_IR_API_SPECIAL_H 85 | -------------------------------------------------------------------------------- /src/node.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 03/10/16. 
3 | // 4 | 5 | #include "graph_ir.h" 6 | namespace md{ 7 | namespace gir{ 8 | 9 | NodeData::NodeData(std::weak_ptr const graph, 10 | size_t id, 11 | std::string name, 12 | Device device, 13 | Operator op, 14 | unsigned int grad_level, 15 | std::string scope): 16 | graph(graph), 17 | id(id), 18 | name(name), 19 | data_type(op->get_data_type()), 20 | shape(op->get_shape()), 21 | op(op), 22 | is_input_dependent(op->is_input_dependent()), 23 | is_differentiable(op->is_differentiable()), 24 | grad_level(grad_level), 25 | device(device), 26 | scope(scope) { } 27 | 28 | Graph Node::g() const { 29 | auto ptr = unwrap(); 30 | if (ptr->graph.expired()) { 31 | logger("XXX::node::XXX")->error("Trying to access the graph of a Node, but the pointer has expired"); 32 | ptr->graph.lock(); 33 | } 34 | return ptr->graph.lock(); 35 | } 36 | 37 | // void Node::copy_to(const Graph graph, NodeVec ancestors) const { 38 | // std::shared_ptr ptr = unwrap(); 39 | // g_logger(ptr.g()->name)->trace("Copying node {} to graph {} resulting in node {}", 40 | // ptr->id, graph->name, graph->nodes.size()); 41 | // std::shared_ptr node = std::make_shared(graph, ptr->device); 42 | // node->id = graph->nodes.size(); 43 | // graph->nodes.push_back(node); 44 | // node->device = ptr->device; 45 | // node->name = ptr->name; 46 | // node->is_input_dependent = ptr->is_input_dependent; 47 | // node->is_differentiable = ptr->is_differentiable; 48 | // node->data_type = ptr->data_type; 49 | // node->shape = ptr->shape; 50 | // node->op = ptr->op->copy_to(graph.get(), ancestors); 51 | // node->grad_level = ptr->grad_level; 52 | // node->execution = ptr->execution; 53 | // node->group = ptr->group; 54 | // for (size_t i = 0; i < ancestors.size(); i++) { 55 | // ancestors[i]->children.push_back(node); 56 | // } 57 | // } 58 | 59 | std::shared_ptr Node::unwrap() const{ 60 | if (ptr.expired()) { 61 | logger("XXX::node::XXX")->error("Trying to access the NodeData of a Node, but pointer has expired"); 62 | 
ptr.lock(); 63 | } 64 | return ptr.lock(); 65 | } 66 | 67 | int Node::order() const { 68 | for(auto i=0; i<4; ++i){ 69 | if(unwrap()->shape[3-i] != 1){ 70 | return 4-i; 71 | } 72 | } 73 | return 0; 74 | } 75 | } 76 | } -------------------------------------------------------------------------------- /include/api/elementwise.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 30/10/16. 3 | // 4 | 5 | #ifndef GRAPH_IR_API_ELEMENTWISE_H 6 | #define GRAPH_IR_API_ELEMENTWISE_H 7 | 8 | namespace md{ 9 | namespace api{ 10 | /** @brief Elementwise square 11 | * 12 | * @param node 13 | * @return 14 | */ 15 | Node square(Node node); 16 | 17 | /** @brief Elementwise square root 18 | * 19 | * @param node 20 | * @return 21 | */ 22 | Node sqrt(Node node); 23 | 24 | /** @brief Elementwise exponential 25 | * 26 | * @param node 27 | * @return 28 | */ 29 | Node exp(Node node); 30 | 31 | /** @brief Elementwise natural logarithm 32 | * 33 | * @param node 34 | * @return 35 | */ 36 | Node log(Node node); 37 | 38 | /** @brief Elementwise logarithm in base 10 39 | * 40 | * @param node 41 | * @return 42 | */ 43 | Node log10(Node node); 44 | 45 | /** @brief Elementwise absoulte value 46 | * 47 | * @param node 48 | * @return 49 | */ 50 | Node abs(Node node); 51 | 52 | /** @brief Elementwise calculates the function log(x+1) 53 | * 54 | * @param node 55 | * @return 56 | */ 57 | Node log1p(Node node); 58 | 59 | /** @brief Elementwise trigonometirc sine 60 | * 61 | * @param node 62 | * @return 63 | */ 64 | Node sin(Node node); 65 | 66 | /** @brief Elementwise trigonometirc cosine 67 | * 68 | * @param node 69 | * @return 70 | */ 71 | Node cos(Node node); 72 | 73 | /** @brief Elementwise trigonometirc tangent 74 | * 75 | * @param node 76 | * @return 77 | */ 78 | Node tan(Node node); 79 | 80 | /** @brief Elementwise trigonometirc cotangent 81 | * 82 | * @param node 83 | * @return 84 | */ 85 | Node cot(Node node); 86 | 87 | /** @brief 
Elementwise hyperbolic sine 88 | * 89 | * @param node 90 | * @return 91 | */ 92 | Node sinh(Node node); 93 | 94 | /** @brief Elementwise hyperbolic cosine 95 | * 96 | * @param node 97 | * @return 98 | */ 99 | Node cosh(Node node); 100 | 101 | /** @brief Elementwise hyperbolic tangent 102 | * 103 | * @param node 104 | * @return 105 | */ 106 | Node tanh(Node node); 107 | 108 | /** @brief Elementwise hyperbolic cotangent 109 | * 110 | * @param node 111 | * @return 112 | */ 113 | Node coth(Node node); 114 | 115 | /** @brief Takes the elements of node1 to the power as the elements of node2 116 | * 117 | * @param node 118 | * @return 119 | */ 120 | Node pow(Node node1, Node node2); 121 | } 122 | } 123 | #endif //GRAPH_IR_API_ELEMENTWISE_H 124 | -------------------------------------------------------------------------------- /include/api/optimized.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 31/10/16. 3 | // 4 | 5 | #ifndef GRAPH_IR_OPTIMIZED_H 6 | #define GRAPH_IR_OPTIMIZED_H 7 | 8 | namespace md{ 9 | namespace api{ 10 | /** @brief Elementwise calculates log(1 + exp(x)) 11 | * 12 | * @param node 13 | * @param threshold 14 | * @return 15 | */ 16 | Node softplus(Node node, double threshold = 50); 17 | 18 | /** @brief Elementwise calculautes the logistic function - 1 / (1 + exp(-x)) 19 | * 20 | * @param node 21 | * @return 22 | */ 23 | Node sigmoid(Node node); 24 | 25 | /** @brief Calculates log(sum(exp(x), axes)) in a more stable way 26 | * 27 | * @param node 28 | * @param axes 29 | * @return 30 | */ 31 | Node log_sum_exp(Node node, Axes axes); 32 | 33 | /** @brief Takes the log of the sum of exponentials of all of the elements of the variable 34 | * 35 | * @param node 36 | * @return 37 | */ 38 | Node log_sum_exp(Node node); 39 | 40 | /** @brief See log_sum_exp(node, Axes) 41 | * @param node 42 | * @param axis 43 | * @return 44 | */ 45 | Node log_sum_exp(Node node, int axis); 46 | 47 | /** @brief See 
log_sum_exp(node, Axes) 48 | * 49 | * @param node 50 | * @param axis0 51 | * @param axis1 52 | * @return 53 | */ 54 | Node log_sum_exp(Node node, int axis0, int axis1); 55 | 56 | /** @brief See log_sum_exp(node, Axes) 57 | * 58 | * @param node 59 | * @param axis0 60 | * @param axis1 61 | * @param axis2 62 | * @return 63 | */ 64 | Node log_sum_exp(Node node, int axis0, int axis1, int axis2); 65 | 66 | /** @brief Takes the softmax of the node, normalizing along the axes provieded 67 | * 68 | * @param node 69 | * @return 70 | */ 71 | Node softmax(Node node, Axes axes); 72 | 73 | /** @brief Takes the softmax, normalizing along the last non-unit dimension 74 | * 75 | * @param node 76 | * @return 77 | */ 78 | Node softmax(Node node); 79 | 80 | /** @brief See softmax(node, Axes) 81 | * If the axis=-1 then normalizes all of the elements 82 | * @param node 83 | * @param axis 84 | * @return 85 | */ 86 | Node softmax(Node node, int axis); 87 | 88 | /** @brief See softmax(node, Axes) 89 | * 90 | * @param node 91 | * @param axis0 92 | * @param axis1 93 | * @return 94 | */ 95 | Node softmax(Node node, int axis0, int axis1); 96 | 97 | /** @brief See softmax(node, Axes) 98 | * 99 | * @param node 100 | * @param axis0 101 | * @param axis1 102 | * @param axis2 103 | * @return 104 | */ 105 | Node softmax(Node node, int axis0, int axis1, int axis2); 106 | } 107 | } 108 | #endif //GRAPH_IR_OPTIMIZED_H 109 | -------------------------------------------------------------------------------- /include/api/linalg.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 31/10/16. 
3 | // 4 | 5 | #ifndef GRAPH_IR_API_LINALG_H 6 | #define GRAPH_IR_API_LINALG_H 7 | 8 | namespace md{ 9 | namespace api{ 10 | /** @brief Multiplies the matricies in the same order as provided and transposing which required 11 | * 12 | * @param nodes 13 | * @param transpositions 14 | * @return 15 | */ 16 | Node matrix_mul(NodeVec nodes, std::vector transpositions = {}); 17 | 18 | /** @brief Multiplies two matrices, with option for transposing any of them 19 | * 20 | * @param left 21 | * @param right 22 | * @param transpose_left 23 | * @param transpose_right 24 | * @return 25 | */ 26 | Node dot(Node left, Node right, bool transpose_left = false, bool transpose_right = false); 27 | 28 | /** Takes the matrix inverse of the variable 29 | * 30 | * @param node 31 | * @return 32 | */ 33 | Node matrix_inverse(Node node, bool transpose = false); 34 | 35 | /** @brief Takes the matrix inverse of the first matrix and mutyplies with the second 36 | * 37 | * @param node1 38 | * @param node2 39 | * @param transpose 40 | * @return 41 | */ 42 | Node matrix_inverse_mul(Node node1, Node node2, bool transpose_inv = false, bool transpose_mul = false); 43 | 44 | /** @brief Takes the determinant of the matrix 45 | * 46 | * @param node 47 | * @return 48 | */ 49 | Node determinant(Node node); 50 | 51 | /** @brief Takes the lograithm of the determinant of the matrix 52 | * 53 | * @param node 54 | * @return 55 | */ 56 | Node log_det(Node node); 57 | 58 | /** @brief Takes the trace of the matrix 59 | * 60 | * @param node 61 | * @return 62 | */ 63 | Node trace(Node node); 64 | 65 | /** @brief Takes the Kronecker product between the two 66 | * 67 | * @param node1 68 | * @param node2 69 | * @return 70 | */ 71 | Node kron(Node node1, Node node2); 72 | 73 | /** @brief Calculates forward diff for Cholesky decomposition using blas routines 74 | * See https://arxiv.org/pdf/1602.07527v1.pdf 75 | * 76 | * @param cholesky 77 | * @param parent_derivative 78 | * @param lower 79 | * @return 80 | */ 81 | Node 
cholesky_forward_diff_blas(Node cholesky, Node parent_derivative, bool lower = true); 82 | 83 | /** @brief Calculates backward diff for Cholesky decomposition using blas routines 84 | * See https://arxiv.org/pdf/1602.07527v1.pdf 85 | * 86 | * @param cholesky 87 | * @param my_derivative 88 | * @param lower 89 | * @return 90 | */ 91 | Node cholesky_backward_diff_blas(Node cholesky, Node my_derivative, bool lower = true); 92 | } 93 | 94 | namespace gir { 95 | // Matrix inverse 96 | inline Node operator~(Node node) { 97 | return api::matrix_inverse(node); 98 | } 99 | } 100 | } 101 | 102 | #endif //GRAPH_IR_API_LINALG_H 103 | -------------------------------------------------------------------------------- /src/exceptions.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 06/10/16. 3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md { 8 | namespace gir { 9 | // void UnsupportedGradient::log(std::shared_ptr const logger, 10 | // const spdlog::level::level_enum level) const { 11 | // logger->log(level, 12 | // "Requested gradient with respect to node {}, but it is not a scalar.", 13 | // to_string(nodes[0])); 14 | // } 15 | // 16 | // void OtherError::log(std::shared_ptr const logger, 17 | // const spdlog::level::level_enum level) const { 18 | // logger->log(level, "{}. 
Nodes involved:{}", msg, to_string(nodes)); 19 | // } 20 | // 21 | // void WrongGradient::log(std::shared_ptr const logger, 22 | // const spdlog::level::level_enum level) const { 23 | // logger->log(level, 24 | // "The gradient node with id {} was sent to node with id {} " 25 | // "and operator {}, " 26 | // "but all the parents of that node are constant.", 27 | // nodes[1]->id, nodes[0]->id, op_name); 28 | // } 29 | // 30 | // void ImplicitBroadcast::log(std::shared_ptr const logger, 31 | // const spdlog::level::level_enum level) const { 32 | // logger->log(level, 33 | // "Implicit broadcast in operator {} of nodes:\n{}", 34 | // op_name, 35 | // to_string(nodes)); 36 | // } 37 | // 38 | // void IncompatibleShapes::log(std::shared_ptr const logger, 39 | // const spdlog::level::level_enum level) const { 40 | // logger->log(level, 41 | // "Incompatible shapes in operator {} of nodes: {}", 42 | // op_name, 43 | // to_string(nodes)); 44 | // } 45 | // 46 | // void TypePromotion::log(std::shared_ptr const logger, 47 | // const spdlog::level::level_enum level) const { 48 | // logger->log(level, 49 | // "Promoting expected type {} to {} in operator {}.\nNodes: {}", 50 | // from, 51 | // to, 52 | // op_name, 53 | // to_string(nodes)); 54 | // } 55 | // 56 | // 57 | // void InvalidArguments::log(std::shared_ptr const logger, 58 | // const spdlog::level::level_enum level) const { 59 | // logger->log(level, 60 | // "Invalid arguments to operator {}. 
Reason: {}.\nNodes: {}", 61 | // op_name, 62 | // reason, 63 | // to_string(nodes)); 64 | // } 65 | // 66 | // void MissingRequiredInput::log(std::shared_ptr const logger, 67 | // const spdlog::level::level_enum level) const { 68 | // logger->log(level, "Missing required input when trying to compile the graph.\n" 69 | // "Missing node: {}\nTarget nodes: {}\nProvided inputs: {}", 70 | // to_string(nodes[0]), to_string(targets), to_string(inputs)); 71 | // } 72 | 73 | } 74 | } -------------------------------------------------------------------------------- /src/operators/logical.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 10/10/16. 3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ 9 | 10 | Node GraphInternal::logical_not(Node node){ 11 | auto base = get_base_node(node); 12 | // The not(not(x)) = x 13 | if(base->op->name == "LogicalNot"){ 14 | return api::alias(base->op->get_parents()[0]); 15 | } 16 | // Standard 17 | return apply(node); 18 | } 19 | 20 | Node GraphInternal::logical_and(Node node1, Node node2){ 21 | // TODO check if node1 == node2 than just return that 22 | return apply(get_base_node(node1), get_base_node(node2)); 23 | } 24 | 25 | Node GraphInternal::logical_or(Node node1, Node node2){ 26 | // TODO check if node1 == node2 than just return that 27 | return apply(get_base_node(node1), get_base_node(node2)); 28 | } 29 | 30 | Node GraphInternal::greater_than(Node node1, Node node2){ 31 | // TODO check if node1 == node2 than just return that 32 | return apply(get_base_node(node1), get_base_node(node2)); 33 | } 34 | 35 | Node GraphInternal::less_than(Node node1, Node node2){ 36 | // TODO check if node1 == node2 than just return that 37 | return apply(get_base_node(node1), get_base_node(node2)); 38 | } 39 | 40 | Node GraphInternal::greater_than_or_equal(Node node1, Node node2){ 41 | // TODO check if node1 == node2 than just return that 42 | return 
apply(get_base_node(node1), get_base_node(node2)); 43 | } 44 | 45 | Node GraphInternal::less_than_or_equal(Node node1, Node node2){ 46 | // TODO check if node1 == node2 than just return that 47 | return apply(get_base_node(node1), get_base_node(node2)); 48 | } 49 | 50 | Node GraphInternal::equals(Node node1, Node node2){ 51 | // TODO check if node1 == node2 than just return that 52 | return apply(get_base_node(node1), get_base_node(node2)); 53 | } 54 | 55 | Node GraphInternal::not_equals(Node node1, Node node2){ 56 | // TODO check if node1 == node2 than just return that 57 | return apply(get_base_node(node1), get_base_node(node2)); 58 | } 59 | 60 | Node GraphInternal::approx_equals(Node node1, Node node2, double tol){ 61 | // TODO check if node1 == node2 than just return that 62 | auto op = std::make_shared(this, get_base_node(node1), get_base_node(node2), tol); 63 | return derived_node(op); 64 | } 65 | 66 | Node GraphInternal::isNan(Node node){ 67 | auto base = get_base_node(node); 68 | // The not(not(x)) = x 69 | if(base->op->name == "IsNaN"){ 70 | return api::alias(base->op->get_parents()[0]); 71 | } 72 | // Standard 73 | return apply(node); 74 | } 75 | 76 | Node GraphInternal::isInf(Node node){ 77 | auto base = get_base_node(node); 78 | // The not(not(x)) = x 79 | if(base->op->name == "IsInf"){ 80 | return api::alias(base->op->get_parents()[0]); 81 | } 82 | // Standard 83 | return apply(node); 84 | } 85 | } 86 | } 87 | 88 | -------------------------------------------------------------------------------- /src/api/optimized.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 31/10/16. 
3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace api{ 9 | 10 | Node softplus(Node node, double threshold){ 11 | Graph g = node.g(); 12 | // Standard 13 | Operator op = std::make_shared(g.get(), node, threshold); 14 | return g->derived_node(op); 15 | } 16 | 17 | Node sigmoid(Node node) { 18 | Graph g = node.g(); 19 | // Standard 20 | Operator op = std::make_shared(g.get(), node); 21 | return g->derived_node(op); 22 | } 23 | 24 | Node log_sum_exp(Node node, Axes axes){ 25 | // If no axes do nothing 26 | if(axes.size() == 0){ 27 | return alias(node); 28 | } 29 | // Validate axes 30 | validate_axes(node->shape, axes, "LogSumExp"); 31 | // Verify correctness 32 | Graph g = node.g(); 33 | auto base = get_base_node(node); 34 | if (base->op->name == "Log") { 35 | // If the parent is a log return remove it 36 | auto cast_op = std::dynamic_pointer_cast(base->op); 37 | return log(sum(cast_op->parent, axes)); 38 | } 39 | // Standard 40 | Operator op = std::make_shared(g.get(), node, axes); 41 | return g->derived_node(op); 42 | } 43 | 44 | Node log_sum_exp(Node node){ 45 | return log_sum_exp(node, auto_infer_axes(node->shape)); 46 | } 47 | 48 | Node log_sum_exp(Node node, int axis){ 49 | return log_sum_exp(node, Axes{axis}); 50 | } 51 | 52 | Node log_sum_exp(Node node, int axis0, int axis1){ 53 | return log_sum_exp(node, Axes{axis0, axis1}); 54 | } 55 | 56 | Node log_sum_exp(Node node, int axis0, int axis1, int axis2){ 57 | return log_sum_exp(node, Axes{axis0, axis1, axis2}); 58 | } 59 | 60 | Node softmax(Node node, Axes axes){ 61 | if(node.order() == 0){ 62 | op_logger("Softmax")->error("The input is a scalar."); 63 | throw InvalidOperatorArgument(NodeVec{node}, 64 | "Softmax", "The input is a scalar."); 65 | } 66 | // Validate axes 67 | validate_axes(node->shape, axes, "Softmax"); 68 | // Verify correctness 69 | Graph g = node.g(); 70 | // Standard 71 | Operator op = std::make_shared(g.get(), node, axes); 72 | return g->derived_node(op); 73 | 74 | } 75 
| 76 | Node softmax(Node node){ 77 | int axis = 3; 78 | for(auto i=2; i>=0; --i){ 79 | if(node->shape[i] == 1){ 80 | axis = i; 81 | } else { 82 | break; 83 | } 84 | } 85 | return softmax(node, Axes{axis}); 86 | } 87 | 88 | Node softmax(Node node, int axis){ 89 | if(axis == -1){ 90 | return softmax(node, auto_infer_axes(node->shape)); 91 | } else { 92 | return softmax(node, Axes{axis}); 93 | } 94 | } 95 | 96 | Node softmax(Node node, int axis0, int axis1){ 97 | return softmax(node, Axes{axis0, axis1}); 98 | } 99 | 100 | Node softmax(Node node, int axis0, int axis1, int axis2){ 101 | return softmax(node, Axes{axis0, axis1, axis2}); 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /src/operators/shape.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by alex on 18/10/16. 3 | // 4 | 5 | #include "graph_ir.h" 6 | 7 | namespace md{ 8 | namespace gir{ 9 | Node GraphInternal::diag(Node node){ 10 | // If it is a scalar nothing to do 11 | if(node.order() == 0){ 12 | return api::alias(node); 13 | } 14 | // diag(diag(x)) = x 15 | auto base = get_base_node(node); 16 | if(base->op->name == "Diag"){ 17 | return api::alias(base->op->get_parents()[0]); 18 | } 19 | // Standard 20 | return derived_node(std::make_shared(this, node)); 21 | } 22 | 23 | Node GraphInternal::reshape(Node node, Shape shape){ 24 | // If shapes are equal do nothing 25 | if(node->shape == shape){ 26 | return api::alias(node); 27 | } 28 | // The reshape(reshape(x)) = reshape(x) 29 | auto base = get_base_node(node); 30 | if(base->op->name == "Reshape"){ 31 | return reshape(base->op->get_parents()[0], shape); 32 | } 33 | // Standard 34 | return derived_node(std::make_shared(this, base, shape)); 35 | } 36 | 37 | Node GraphInternal::reorder(Node node, Axes order){ 38 | // For a scalar do nothing 39 | if(node.order() == 0){ 40 | return api::alias(node); 41 | } 42 | bool ordered = true; 43 | 
for(auto i=0; iop->name == "Reorder"){ 56 | std::shared_ptr cast_op = std::dynamic_pointer_cast(base->op); 57 | if(cast_op->order.size() != order.size()){ 58 | return derived_node(std::make_shared(this, base, order)); 59 | } 60 | Axes new_order; 61 | for(auto i=0;i < order.size(); ++i){ 62 | new_order.push_back(cast_op->order[order[i]]); 63 | } 64 | ordered = true; 65 | for(auto i=0; iop->get_parents()[0]); 73 | } 74 | return reorder(base->op->get_parents()[0], new_order); 75 | } 76 | // Standard 77 | return derived_node(std::make_shared(this, base, order)); 78 | } 79 | 80 | Node GraphInternal::transpose(Node node){ 81 | // For a scalar do nothing 82 | if(node.order() == 0){ 83 | return api::alias(node); 84 | } 85 | // Switch the last two dimensions 86 | int dims = node.order(); 87 | dims = dims == 1 ? 2 : dims; 88 | Axes order; 89 | for(auto i=0;i