├── prometeo ├── cgen │ ├── code.h │ ├── __init__.py │ ├── string_repr.py │ ├── op_util.py │ ├── node_util.py │ └── source_repr.py ├── cmdline │ └── __init__.py ├── laparser │ ├── __init__.py │ └── laparser.py ├── mem │ └── __init__.py ├── nonlinear │ ├── __init__.py │ ├── casadi_wrapper.h.in │ ├── casadi_wrapper.c.in │ └── nonlinear.py ├── auxl │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── prmt_list.cpython-36.pyc │ └── plist.py ├── old_mem │ ├── mem_manager.py │ ├── malloc_wrapper.c │ ├── test.py │ ├── malloc_wrapper.py │ ├── memory_manager.py │ └── Makefile ├── linalg │ ├── __init__.py │ ├── pvec_blasfeo_wrapper.py │ ├── pvec.py │ ├── blasfeo_wrapper.py │ ├── pmat_blasfeo_wrapper.py │ └── pmat.py ├── __main__.py ├── cpmt │ ├── pmt_aux.h │ ├── pmt_heap.h │ ├── prometeo.h │ ├── pmt_aux.c │ ├── Makefile │ ├── pvec_blasfeo_wrapper.h │ ├── pmat_blasfeo_wrapper.h │ ├── pvec_blasfeo_wrapper.c │ ├── timing.h │ ├── timing.c │ └── pmat_blasfeo_wrapper.c └── __init__.py ├── docs ├── requirements.txt ├── source │ ├── blas_api │ │ ├── .blas_api.rst.swp │ │ └── blas_api.rst │ ├── python_syntax │ │ ├── .python_syntax.rst.swp │ │ └── python_syntax.rst │ ├── performance │ │ └── performance.rst │ ├── installation │ │ └── installation.rst │ ├── index.rst │ └── conf.py ├── Makefile └── make.bat ├── experimental ├── experimental_examples │ ├── new_test.h │ ├── test_prmt_mat_compact.prmt │ ├── blas_interface.py │ ├── new_test.py │ ├── test_prmt_mat.prmt │ ├── test_prmt_mat.py │ ├── new_test.c │ ├── code.prmt │ └── test_blasfeo_ctypes.py ├── laparser │ ├── infile.txt │ └── typed_record.json ├── pymatrix │ └── pymatrix_example.py ├── heap_computation │ ├── cycle_detect.py │ ├── mem_test_code.py │ └── test_mem.py ├── sized_type_checking │ └── case1.py ├── meta_info │ └── case1.py ├── type_and_tuple_indexing.py ├── blas_api │ └── simple_example.py └── dgemm_example │ ├── parse.py │ └── dgemm.py ├── benchmarks ├── Project.toml ├── riccati_benchmark.pdf ├── 
riccati_benchmark.png ├── riccati_benchmark_blasfeo_api.json ├── riccati_benchmark_numpy.json ├── riccati_benchmark_prometeo.json ├── riccati_benchmark_julia.json ├── riccati.jl ├── run_benchmark_julia.py ├── run_benchmark_numpy.py ├── test_riccati.jl ├── riccati_mass_spring.py.in ├── run_benchmark.py └── test_riccati_numpy.py.in ├── logo └── logo.png ├── gifs ├── helloworld.gif └── helloworld_light.gif ├── figures ├── prometeo-crop.pdf ├── prometeo_crop.png └── simple_ast_annotated.png ├── .gitmodules ├── examples ├── helloworld │ └── helloworld.py ├── fibonacci │ ├── CPU_time.txt │ └── fibonacci.py ├── README.md ├── test │ ├── test_lapack.py │ ├── test_assignments.py │ └── test.py ├── laparser │ └── laparser.py ├── simple_example │ └── simple_example.py ├── pure_python_inline │ └── pure_python_inline.py ├── heap_analysis │ └── heap_analysis.py ├── simple_class │ └── simple_class.py ├── riccati_example │ ├── riccati_numpy.py │ ├── riccati_array.py │ ├── riccati_debug.py │ ├── riccati_compact.py │ ├── riccati_mass_spring.py │ ├── riccati.py │ └── riccati_mass_spring_2.py └── nonlinear │ └── nonlinear.py ├── LICENSE ├── setup.py └── .travis.yml /prometeo/cgen/code.h: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-rtd-theme 2 | -------------------------------------------------------------------------------- /experimental/experimental_examples/new_test.h: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prometeo/cmdline/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmt import * 2 | -------------------------------------------------------------------------------- 
/prometeo/laparser/__init__.py: -------------------------------------------------------------------------------- 1 | from .laparser import * 2 | -------------------------------------------------------------------------------- /prometeo/mem/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ast_analyzer 2 | -------------------------------------------------------------------------------- /prometeo/nonlinear/__init__.py: -------------------------------------------------------------------------------- 1 | from .nonlinear import * 2 | -------------------------------------------------------------------------------- /experimental/laparser/infile.txt: -------------------------------------------------------------------------------- 1 | [[ C = M \ (A.T * B) * (A - B)]] 2 | -------------------------------------------------------------------------------- /benchmarks/Project.toml: -------------------------------------------------------------------------------- 1 | [deps] 2 | MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" 3 | -------------------------------------------------------------------------------- /logo/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/logo/logo.png -------------------------------------------------------------------------------- /prometeo/auxl/__init__.py: -------------------------------------------------------------------------------- 1 | from .plist import * 2 | from ..linalg import pmat 3 | -------------------------------------------------------------------------------- /gifs/helloworld.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/gifs/helloworld.gif -------------------------------------------------------------------------------- /figures/prometeo-crop.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/figures/prometeo-crop.pdf -------------------------------------------------------------------------------- /figures/prometeo_crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/figures/prometeo_crop.png -------------------------------------------------------------------------------- /gifs/helloworld_light.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/gifs/helloworld_light.gif -------------------------------------------------------------------------------- /benchmarks/riccati_benchmark.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/benchmarks/riccati_benchmark.pdf -------------------------------------------------------------------------------- /benchmarks/riccati_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/benchmarks/riccati_benchmark.png -------------------------------------------------------------------------------- /figures/simple_ast_annotated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/figures/simple_ast_annotated.png -------------------------------------------------------------------------------- /docs/source/blas_api/.blas_api.rst.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/docs/source/blas_api/.blas_api.rst.swp -------------------------------------------------------------------------------- /experimental/laparser/typed_record.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "C": "pmat", 3 | "M": "pmat", 4 | "A": "pmat", 5 | "B": "pmat" 6 | } 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/blasfeo"] 2 | path = external/blasfeo 3 | url = https://github.com/zanellia/blasfeo 4 | branch = master 5 | -------------------------------------------------------------------------------- /docs/source/python_syntax/.python_syntax.rst.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/docs/source/python_syntax/.python_syntax.rst.swp -------------------------------------------------------------------------------- /prometeo/auxl/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/prometeo/auxl/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /prometeo/auxl/__pycache__/prmt_list.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zanellia/prometeo/HEAD/prometeo/auxl/__pycache__/prmt_list.cpython-36.pyc -------------------------------------------------------------------------------- /examples/helloworld/helloworld.py: -------------------------------------------------------------------------------- 1 | from prometeo import * 2 | 3 | def main() -> int: 4 | 5 | print('\nhello world!\n') 6 | 7 | return 0 8 | 9 | -------------------------------------------------------------------------------- /prometeo/old_mem/mem_manager.py: -------------------------------------------------------------------------------- 1 | import malloc_wrapper 2 | class mem_manager: 3 | 4 | 
heap_db = {} 5 | use_prmt_heap = False 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /prometeo/old_mem/malloc_wrapper.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | void prmt_malloc(void **mem, int n_bytes){ 5 | *mem = malloc(n_bytes); 6 | return; 7 | } 8 | -------------------------------------------------------------------------------- /experimental/pymatrix/pymatrix_example.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python3.6 2 | import pymatrix as p 3 | 4 | A = p.Matrix(2,2) 5 | B = p.Matrix(2,2) 6 | C = A*B 7 | print (A[1][1]) 8 | print (C) 9 | -------------------------------------------------------------------------------- /prometeo/linalg/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmat_blasfeo_wrapper import * 2 | from .pvec_blasfeo_wrapper import * 3 | from .pmat import * 4 | from .pvec import * 5 | from .blasfeo_wrapper import * 6 | -------------------------------------------------------------------------------- /prometeo/__main__.py: -------------------------------------------------------------------------------- 1 | """prometeo command line tool""" 2 | 3 | import sys 4 | from prometeo.cmdline import pmt_main 5 | 6 | def console_entry() -> None: 7 | return pmt_main() 8 | 9 | if __name__ == '__main__': 10 | pmt_main() 11 | -------------------------------------------------------------------------------- /examples/fibonacci/CPU_time.txt: -------------------------------------------------------------------------------- 1 | Nuitka: 12.525257587432861 s 2 | Python: 27.128 s 3 | prometeo: 0.699902 s 4 | 5 | # efficient implementation (10E7 runs) 6 | Nuitka: 10.039 s 7 | Python: 11.787 s 8 | PyPy3.7: 1.78 s 9 | prometeo: 0.657075 s 10 | -------------------------------------------------------------------------------- 
This folder contains examples using prometeo. After installing prometeo, you should be able to run `pmt dgemm.py` to execute the script using its Python backend. In order to generate, compile and execute C code using prometeo's C backend, run `pmt dgemm.py --cgen True`.
/*
 * Round the address stored in *c_ptr up to the next multiple of num and
 * write the rounded address back through c_ptr.
 *
 * Returns num minus the number of bytes the pointer was advanced, so a
 * pointer that is already a multiple of num yields num.
 *
 * num is used as a divisor and therefore must be non-zero.
 */
int align_char_to(int num, char **c_ptr)
{
    const size_t start = (size_t) *c_ptr;
    /* integer division truncates, so this is the smallest multiple >= start */
    const size_t aligned = (start + num - 1) / num * num;

    *c_ptr = (char *) aligned;
    return num - (int) (aligned - start);
}
7 | 8 | C : pmat = pmat(n,n) 9 | for i in range(5): 10 | C[i,i] = 2.0 11 | 12 | D : pmat = pmat(n,n) 13 | 14 | ipiv : List = plist(int, n) 15 | 16 | pmt_getrf(C, D, ipiv) 17 | pmat_print(D) 18 | 19 | pmt_potrf(C, D) 20 | pmat_print(D) 21 | 22 | return 0 23 | -------------------------------------------------------------------------------- /prometeo/old_mem/test.py: -------------------------------------------------------------------------------- 1 | from malloc_wrapper import * 2 | from memory_manager import * 3 | 4 | memory_manager.use_prmt_heap = 0 5 | memory_manager.prmt_heap_head = prmt_malloc(64*10) 6 | address = memory_manager.prmt_alloc_p(10) 7 | print(address) 8 | memory_manager.use_prmt_heap = 1 9 | address = memory_manager.prmt_alloc_p(10) 10 | print(address) 11 | address = memory_manager.prmt_alloc_p(10) 12 | print(address) 13 | -------------------------------------------------------------------------------- /prometeo/nonlinear/casadi_wrapper.h.in: -------------------------------------------------------------------------------- 1 | #ifndef {{ fun_descriptor.name }}_H_ 2 | #define {{ fun_descriptor.name }}_H_ 3 | 4 | int {{ fun_descriptor.name }}(const double** arg, double** res, int* iw, double* w, int mem); 5 | 6 | void {{ fun_descriptor.name}}_eval(const double* in, double* out, 7 | int (*{{ fun_descriptor.name}})(const double**, double**, int*, double*, int)); 8 | 9 | #endif // {{ fun_descriptor.name }}_H_ 10 | -------------------------------------------------------------------------------- /experimental/sized_type_checking/case1.py: -------------------------------------------------------------------------------- 1 | from prometeo import * 2 | 3 | n : dims = 10 4 | 5 | class ClassA(): 6 | A : pmat = pmat(n,n) 7 | 8 | def method1(self, arg1 : int) -> int: 9 | return 0 10 | 11 | class ClassB(): 12 | attr1 : ClassA = ClassA() 13 | 14 | class ClassC(): 15 | attr2 : ClassB = ClassB() 16 | 17 | def main() -> int: 18 | A : pmat = pmat(n,n) 19 | D : ClassC = 
def main() -> int:

    # 2x2 operands of the expression handed to pparse below
    A: pmat = pmat(nx, nx)
    A[0,0] = 0.8
    A[0,1] = 0.1
    A[1,0] = 0.3
    A[1,1] = 0.8

    B: pmat = pmat(nx, nx)
    B[0,0] = 1.0
    B[0,1] = 2.0
    B[1,0] = 0.0
    B[1,1] = 1.0

    # C receives the result; D is used as allocated, without explicit
    # initialization
    C: pmat = pmat(nx, nx)
    D: pmat = pmat(nx, nx)

    # The backslash is escaped explicitly: the string value is unchanged
    # (exactly one '\'), but it no longer relies on Python tolerating the
    # unrecognized escape sequence '\ ', which triggers a warning.
    pparse('C = A - A.T \\ (B * D).T')

    pmat_print(C)

    # main is annotated '-> int'; return a status like the other examples do
    return 0
class memory_manager:
    """Allocator front end that uses either malloc or a preallocated heap.

    All state lives on the class itself; the class is not meant to be
    instantiated.
    """

    # Placeholders only: both attributes are expected to be overwritten by
    # the caller before prmt_alloc_p is used.
    #   use_prmt_heap  -- 0 selects mw.prmt_malloc, anything else selects the
    #                     preallocated heap
    #   prmt_heap_head -- object with a numeric ``.value`` address (presumably
    #                     the ctypes.c_void_p returned by mw.prmt_malloc --
    #                     TODO confirm)
    use_prmt_heap = []
    prmt_heap_head = []

    @classmethod
    def prmt_alloc_p(cls, n_bytes):
        """Reserve ``n_bytes`` and return a pointer to the start of the block.

        When ``use_prmt_heap == 0`` the request is forwarded to
        ``mw.prmt_malloc``. Otherwise the block is carved out of the
        preallocated heap by advancing ``prmt_heap_head`` by ``n_bytes``;
        no bounds check against the heap size is performed.

        Returns:
            The result of ``mw.prmt_malloc`` in malloc mode, or a new
            ``ctypes.c_void_p`` holding the address the heap head had
            *before* this allocation.
        """
        if cls.use_prmt_heap == 0:
            return mw.prmt_malloc(n_bytes)

        from ctypes import c_void_p

        head = cls.prmt_heap_head
        # Snapshot the address in a fresh pointer object. Returning ``head``
        # itself would hand back the already advanced address, because the
        # in-place update below mutates that very object.
        block_start = c_void_p(head.value)
        head.value = head.value + n_bytes
        return block_start
18 | ln -s lib$(NAME).so.$(MAJOR) lib$(NAME).so 19 | 20 | lib$(NAME).so.$(VERSION): $(NAME).o 21 | $(CC) -shared -Wl,-soname,lib$(NAME).so.$(MAJOR) $^ -o $@ 22 | 23 | clean: 24 | $(RM) $(NAME) *.o *.so* 25 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /examples/simple_example/simple_example.py: -------------------------------------------------------------------------------- 1 | from prometeo import * 2 | 3 | nv : dims = 10 4 | 5 | def foo(a: int) -> int: 6 | return a + 1 7 | 8 | def main() -> int: 9 | 10 | A: pmat = pmat(nv, nv) 11 | for i in range(nv): 12 | for j in range(nv): 13 | A[i, j] = 1.0 14 | 15 | B: pmat = pmat(nv, nv) 16 | for i in range(nv): 17 | B[0, i] = 2.0 18 | 19 | D: pmat = pmat(nv, nv) 20 | 21 | a : int = 1 22 | b : int = 1 23 | c : int = 1 24 | a = (a + b)*c 25 | b = foo(b) 26 | 27 | pmat_print(A) 28 | pmat_print(B) 29 | pmt_gemm(A,B,D) 30 | pmat_print(D) 31 | 32 | return 0 33 | 34 | -------------------------------------------------------------------------------- /benchmarks/riccati_benchmark_blasfeo_api.json: -------------------------------------------------------------------------------- 1 | 
n: dims = 10

# A[i, j] = i*n + j, i.e. the entries count up row by row
A: pmat = pmat(n, n)
for i in range(n):
    for j in range(n):
        A[i,j] = i*n + j

# B starts as all zeros ...
B: pmat = pmat(n, n)
for i in range(n):
    for j in range(n):
        B[i,j] = 0.0

# ... and then gets ones on the diagonal (identity matrix)
for i in range(n):
    B[i,i] = 1.0

# C is zero-initialized; it is passed twice to pmt_gemm_nt below
C: pmat = pmat(n, n)
for i in range(n):
    for j in range(n):
        C[i,j] = 0.0

# NOTE(review): presumably C is both the additive input and the output of
# the product -- confirm against the pmt_gemm_nt signature.
pmt_gemm_nt(A.T, B, C, C)
# pmt_gemm_nt(A, B.T, C, C)

# (a dangling bare `pmt` expression statement was removed here: it had no
# effect at best and raised NameError if the name is not defined)

# print results
print('\n\nB = ')
pmt_print(B)
print('\n\nA = ')
pmt_print(A)
print('\n\nC = ')
pmt_print(C)
58], [0.005066304, 62], [0.005523829, 66], [0.006356812, 70], [0.007234588, 74], [0.007849369, 78], [0.00864877, 82], [0.01056338, 86], [0.01075954, 90], [0.01155611, 94], [0.01255875, 98], [0.01260738, 102], [0.01522686, 106], [0.01645947, 110], [0.01947773, 114], [0.02093554, 118], [0.02471325, 122], [0.02450719, 126], [0.0285146, 130], [0.03252773, 134], [0.03108652, 138], [0.03682599, 142], [0.04005446, 146]] -------------------------------------------------------------------------------- /experimental/type_and_tuple_indexing.py: -------------------------------------------------------------------------------- 1 | class _psize: 2 | _inner_get = False 3 | 4 | def __getitem__(self, index): 5 | if self._inner_get is True: 6 | self._inner_get = False 7 | return self[index] 8 | 9 | self._inner_get = True 10 | return self 11 | 12 | psize = _psize() 13 | 14 | a : psize[0] = 1 15 | a : psize[0][1] = 1 16 | 17 | from typing import Any 18 | class _psize2: 19 | 20 | def __getitem__(self, index): 21 | if isinstance(index, tuple): 22 | if len(index) == 2: 23 | return Any 24 | else: 25 | raise Exception ('pmat dimensions should be a 2-dimensional tuple.') 26 | 27 | psize2 = _psize2() 28 | a : psize2[0,1] = 1 29 | 30 | -------------------------------------------------------------------------------- /examples/pure_python_inline/pure_python_inline.py: -------------------------------------------------------------------------------- 1 | from prometeo import * 2 | # pure > 3 | import numpy as np 4 | # pure < 5 | 6 | n : dims = 10 7 | 8 | def main() -> int: 9 | 10 | A: pmat = pmat(n, n) 11 | for i in range(n): 12 | for j in range(n): 13 | A[i, j] = 1.0 14 | 15 | # pure > 16 | M = np.array([[1.0, 2.0],[0.0, 0.5]]) 17 | print('\neigenvalues of M computed with ' 18 | 'numpy in pure Python block:\n\n', 19 | np.linalg.eigvals(M), '\n') 20 | # pure < 21 | 22 | B: pmat = pmat(n, n) 23 | for i in range(n): 24 | B[0, i] = 2.0 25 | 26 | C: pmat = pmat(n, n) 27 | 28 | pmat_print(A) 29 | 
# Iteratively computes a Fibonacci number.
# The pair (a, b) starts at (0, 1) and is advanced n times; b is returned,
# so fib(0) == 1, fib(1) == 1, fib(2) == 2, fib(3) == 3, fib(4) == 5.
def fib(n : int) -> int:
    a : int = 0
    b : int = 1
    c : int = 0
    for i in range(n):
        # c is a temporary for the next term; afterwards (a, b) <- (b, a + b)
        c = a + b
        a = b
        b = c
    return b
/experimental/experimental_examples/new_test.py: -------------------------------------------------------------------------------- 1 | from prometeo.linalg import * 2 | import sys 3 | 4 | n: int = 10 5 | 6 | A: prmt_mat = prmt_mat(n, n) 7 | for i in range(n): 8 | for j in range(n): 9 | A[i][j] = i*n + j 10 | 11 | B: prmt_mat = prmt_mat(n, n) 12 | for i in range(n): 13 | for j in range(n): 14 | B[i][j] = 0.0 15 | 16 | for i in range(n): 17 | B[i][i] = 1.0 18 | 19 | C: prmt_mat = prmt_mat(n, n) 20 | for i in range(n): 21 | for j in range(n): 22 | C[i][j] = 0.0 23 | 24 | i: int = 0.0 25 | while i < n: 26 | i = i + 1 27 | 28 | if i >= 5: 29 | i = i *2 30 | 31 | prmt_gemm_nt(A, B, C, C) 32 | 33 | # print results 34 | print('\n\nB = ') 35 | prmt_print(B) 36 | print('\n\nA = ') 37 | prmt_print(A) 38 | print('\n\nC = ') 39 | prmt_print(C) 40 | 41 | -------------------------------------------------------------------------------- /experimental/heap_computation/mem_test_code.py: -------------------------------------------------------------------------------- 1 | def function_1(): 2 | function_2() 3 | function_3() 4 | return 5 | 6 | def function_2(): 7 | function_3() 8 | function_1() 9 | return 10 | 11 | def function_3(): 12 | return 13 | 14 | # def function_4(): 15 | # function_5() 16 | # return 17 | 18 | # def function_5(): 19 | # function_6() 20 | # function_7() 21 | # return 22 | 23 | 24 | # def function_6(): 25 | # function_5() 26 | # function_7() 27 | # return 28 | 29 | # def function_7(): 30 | # return 31 | 32 | class class1: 33 | def __init__(self): 34 | self.a = 1 35 | def method_1(): 36 | function_1() 37 | 38 | def main(): 39 | class_instance = class1() 40 | # class1.class2.class3.method_1() 41 | class1.method_1() 42 | for i in range(2): 43 | a = 1 44 | -------------------------------------------------------------------------------- /experimental/experimental_examples/test_prmt_mat.prmt: -------------------------------------------------------------------------------- 
1 | # import prometeo as prmt 2 | from prometeo.linalg import * 3 | import sys 4 | 5 | n: int = 10 6 | 7 | void_p = int 8 | 9 | data_A: void_p = POINTER(c_double)() 10 | bw.d_zeros(byref(data_A), n, n) 11 | for i in range(n*n): 12 | data_A[i] = i 13 | 14 | A: prmt_mat = prmt_mat(n, n) 15 | A.set(data_A) 16 | 17 | B: prmt_mat = prmt_mat(n, n) 18 | data_B: void_p = POINTER(c_double)() 19 | bw.d_zeros(byref(data_B), n, n) 20 | 21 | for i in range(n): 22 | data_B[i*(n + 1)] = 1.0 23 | 24 | B.set(data_B) 25 | 26 | data_C: void_p = POINTER(c_double)() 27 | bw.d_zeros(byref(data_C), n, n) 28 | 29 | C: prmt_mat = prmt_mat(n, n) 30 | C.set(data_C) 31 | dgemm_nt(A, B, C, C) 32 | 33 | # print results 34 | print('\n\nB = ') 35 | B.print() 36 | print('\n\nA = ') 37 | A.print() 38 | print('\n\nC = ') 39 | C.print() 40 | 41 | -------------------------------------------------------------------------------- /experimental/experimental_examples/test_prmt_mat.py: -------------------------------------------------------------------------------- 1 | # import prometeo as prmt 2 | from prometeo.linalg import * 3 | import sys 4 | 5 | n: int = 10 6 | 7 | void_p = int 8 | 9 | data_A: void_p = POINTER(c_double)() 10 | bw.d_zeros(byref(data_A), n, n) 11 | for i in range(n*n): 12 | data_A[i] = i 13 | 14 | A: prmt_mat = prmt_mat(n, n) 15 | A.set(data_A) 16 | 17 | B: prmt_mat = prmt_mat(n, n) 18 | data_B: void_p = POINTER(c_double)() 19 | bw.d_zeros(byref(data_B), n, n) 20 | 21 | for i in range(n): 22 | data_B[i*(n + 1)] = 1.0 23 | 24 | B.set(data_B) 25 | 26 | data_C: void_p = POINTER(c_double)() 27 | bw.d_zeros(byref(data_C), n, n) 28 | 29 | C: prmt_mat = prmt_mat(n, n) 30 | C.set(data_C) 31 | dgemm_nt(A, B, C, C) 32 | 33 | # print results 34 | print('\n\nB = ') 35 | B.print() 36 | print('\n\nA = ') 37 | A.print() 38 | print('\n\nC = ') 39 | C.print() 40 | 41 | -------------------------------------------------------------------------------- /benchmarks/riccati_benchmark_julia.json: 
def plist(list_type, sizes):
    """Build a pre-initialised list of prometeo objects or scalars.

    Args:
        list_type: one of ``'pmat'``, ``'pvec'``, ``'int'`` or ``'float'``.
        sizes: for ``'pmat'`` a sequence of ``(rows, cols)`` pairs; for
            ``'pvec'`` a sequence whose items carry the length at index 0;
            for ``'int'`` and ``'float'`` the number of elements (an int).

    Returns:
        A new list: one ``pmat``/``pvec`` per entry of ``sizes``, or ``sizes``
        copies of ``0`` / ``0.0``.

    Raises:
        Exception: if ``list_type`` is not one of the supported names.
    """
    if list_type == 'pmat':
        return [pmat(size[0], size[1]) for size in sizes]

    if list_type == 'pvec':
        # The module only imports pmat at file level, so the original 'pvec'
        # branch failed with NameError; import the class where it is needed.
        from ..linalg.pvec import pvec
        return [pvec(size[0]) for size in sizes]

    if list_type == 'int':
        return [0] * sizes

    if list_type == 'float':
        return [0.0] * sizes

    raise Exception('Invalid List type: valid types are [pmat, pvec, int, float]. You have {}'.format(list_type))
from prometeo import *

n :dims = 10
m :dims = 10

# (ps + m -1)*(nc + n - 1) + (m + n + bs*nc -1)
# ps = 4, nc = 4
# -> 1 pmat = (204 + 64)*8 = 2144

# worst-case path is [main -> f2 -> f1 -> f3] (10 pmats = 21440 bytes):
# f2 declares 2 pmats, f1 declares 3 and f3 declares 5. f1 never calls f2, so
# the path cannot go through f1 before f2.
# NOTE(review): this assumes the memory of a callee is released when it
# returns - confirm against the heap analysis.

# Declares 3 pmats, then calls f3 and f4.
def f1() -> None:
    A : pmat = pmat(n,n)
    B : pmat = pmat(n,n)
    C : pmat = pmat(n,n)
    f3()
    f4()
    return

# Declares 2 pmats, then calls f1 and f3.
def f2() -> None:
    A : pmat = pmat(n,n)
    B : pmat = pmat(n,n)
    f1()
    f3()
    return

# Leaf function: declares 5 pmats and calls nothing.
def f3() -> None:
    A : pmat = pmat(n,n)
    B : pmat = pmat(n,n)
    C : pmat = pmat(n,n)
    D : pmat = pmat(n,n)
    E : pmat = pmat(n,n)
    return

# f4 and f5 call each other and declare no pmats: a call-graph cycle that
# allocates nothing (presumably here to exercise cycle handling - confirm).
def f4() -> None:
    f5()
    return

def f5() -> None:
    f4()
    return

# Entry point: calls f1, then f2.
def main() -> int:
    f1()
    f2()
    return 0
if __name__ == '__main__':
    # Parse the Python file given with -i/--input, collect its call graph with
    # ast_visitor and print both the call graph and the derived reach map.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', help='Input .py file', required=True)
    args = parser.parse_args()

    # Read through a context manager so the file handle is closed right away
    # (the original left the handle returned by open() unclosed); passing the
    # filename makes syntax errors point at the input file.
    with open(args.input) as input_file:
        tree = ast.parse(input_file.read(), filename=args.input)

    visitor = ast_visitor()
    visitor.visit(tree)
    call_graph = visitor.callees
    print(call_graph)

    reach_map = compute_reach_graph(call_graph)
    print('reach_map:\n', reach_map)
def main() -> int:
    # Walk through the call variants of pmt_gemm on r-by-r matrices, printing
    # the operands first and the result after every call.

    # A: all ones.
    A: pmat = pmat(r, r)
    for i in range(r):
        for j in range(r):
            A[i, j] = 1.0

    # B: only the first row is set (to 2.0).
    B: pmat = pmat(r, r)
    for i in range(r):
        B[0, i] = 2.0

    C: pmat = pmat(r, r)

    print("A:\n")
    pmat_print(A)
    print("B:\n")
    pmat_print(B)

    print("gemm(A, B, C):\n")
    pmt_gemm(A, B, C)
    pmat_print(C)

    # Reset C to a known value before the accumulating call.
    for i in range(r):
        for j in range(r):
            C[i, j] = 5.0

    print("gemm(A, B, C, beta=1.0):\n")
    pmt_gemm(A, B, C, beta=1.0)
    pmat_print(C)

    D: pmat = pmat(r, r)
    pmt_gemm(A, B, C, D)
    print("gemm(A, B, C, D):\n")
    # Print D, the extra operand of the four-matrix form; the original printed
    # C again. NOTE(review): assumes D receives the result - confirm against
    # the pmt_gemm signature.
    pmat_print(D)

    pmt_gemm(A, B.T, C, alpha=0.5, beta=0.5)
    # The label now reports the scalars that are actually passed above (the
    # original text claimed alpha=0.0, beta=0.1).
    print("gemm(A, B.T, C, alpha=0.5, beta=0.5):\n")
    pmat_print(C)

    return 0
# Riccati recursion, Cholesky-based: fills the slices of L from the last
# (index N) down to the first (index 1).
#
# Arguments, as used by the body:
#   N      - number of loop iterations (ii = 1..N)
#   nx, nu - sizes used to address the trailing nx-by-nx block
#            (rows/columns nu+1:nu+nx)
#   BAt    - matrix copied into BAtL at every iteration
#            (presumably [B A]' - confirm with the caller)
#   RSQ    - matrix copied into every slice of L before the update; its
#            trailing block provides the terminal factor
#   L      - 3-D array; slice L[:, :, N+1-ii] is overwritten in place
#   LN, BAtL, M - NOTE(review): LN and BAtL are rebound to freshly allocated
#            arrays inside the function and M is never used, so the arrays
#            passed by the caller for these three arguments are not modified.
function riccati_trf(N, nx, nu, BAt, RSQ, L, LN, BAtL, M)

    # Terminal stage: copy the trailing block of RSQ (slicing allocates a new
    # array) and overwrite the copy with its lower Cholesky factor.
    LN = RSQ[nu+1:nu+nx, nu+1:nu+nx];
    LAPACK.potrf!('L', LN);

    for ii in 1:N
        # Work on a fresh copy so BAt itself is never modified.
        BAtL = copy(BAt);
        if ii==1
            # BLAS.gemm!('N', 'N', 1.0, BAt, LN, 0.0, BAtL);
            # BAtL <- BAtL * LN, with LN taken as lower triangular.
            BLAS.trmm!('R', 'L', 'N', 'N', 1.0, LN, BAtL);
        else
            # BLAS.gemm!('N', 'N', 1.0, BAt, L[nu+1:nu+nx,nu+1:nu+nx,N+2-ii], 0.0, BAtL);
            # display(L[:,:,N+2-ii])
            # println()
            # BLAS.trmm!('R', 'L', 'N', 'N', 1.0, L[nu+1:nu+nx,nu+1:nu+nx,N+2-ii], BAtL);
            # Same product, using the trailing block of the slice written in
            # the previous iteration; the view avoids copying that block.
            BLAS.trmm!('R', 'L', 'N', 'N', 1.0, view(L, nu+1:nu+nx, nu+1:nu+nx, N+2-ii), BAtL);
        end
        # M = copy(RSQ);
        # BLAS.syrk!('L', 'N', 1.0, BAtL, 1.0, M);
        # LAPACK.potrf!('L', M);
        # L[:,:,N+1-ii] = copy(M);
        # In-place version of the commented-out code above: initialise the
        # slice with RSQ, add BAtL*BAtL' to its lower triangle, then factorise.
        L[:,:,N+1-ii] .= RSQ;
        MM = view(L, :, :, N+1-ii);
        BLAS.syrk!('L', 'N', 1.0, BAtL, 1.0, MM);
        LAPACK.potrf!('L', MM);
    end

    #display(L)
    #println()

end
# Build and install the C backend of prometeo (libcpmt) together with BLASFEO.

PMT_FLAGS = -DMEASURE_TIMINGS
CC = gcc
# CC = g++
# CC = clang
CFLAGS = -std=c99
CFLAGS += -fPIC
CFLAGS += $(PMT_FLAGS)
# Installation root, relative to this directory.
PREFIX = ./..
SRCS += pmat_blasfeo_wrapper.c
SRCS += pvec_blasfeo_wrapper.c
SRCS += pmt_aux.c
SRCS += timing.c
CFLAGS +=-I../../external/blasfeo/include/
OPT_LD_FLAGS =

# None of the targets below produces a file with its own name. Declaring them
# phony keeps them working even if a file called e.g. `all` or `clean` appears
# in this directory.
.PHONY: all blasfeo shared install_shared clean clean_all

# Compile every source file to an object file.
all: $(SRCS)
	$(CC) -c $(CFLAGS) $(SRCS)

# Build the static and shared BLASFEO libraries from the submodule.
blasfeo:
	( cd ../../external/blasfeo; $(MAKE) static_library -j4 && $(MAKE) shared_library )

# Link all object files in this directory into libcpmt.so.
shared: all
	$(CC) $(OPT_LD_FLAGS) -shared -o libcpmt.so *.o

# Copy libraries and headers of BLASFEO and prometeo under $(PREFIX).
install_shared: blasfeo all shared
	mkdir -p $(PREFIX)/lib/blasfeo
	cp -f ../../external/blasfeo/lib/libblasfeo.so $(PREFIX)/lib/blasfeo
	mkdir -p $(PREFIX)/include/blasfeo
	cp -f ../../external/blasfeo/include/*.h $(PREFIX)/include/blasfeo
	mkdir -p $(PREFIX)/lib/prometeo
	cp -f ./libcpmt.so $(PREFIX)/lib/prometeo
	mkdir -p $(PREFIX)/include/prometeo
	cp -f ./*.h $(PREFIX)/include/prometeo

clean:
	rm -f *.o
	rm -f *.so

clean_all: clean
	( cd ../../external/blasfeo; $(MAKE) deep_clean )
The figure below shows a comparison of the CPU time necessary to carry out a Riccati factorization using highly optimized hand-written C code with calls to BLASFEO and the one obtained with prometeo transpiled code (from `this example `__). The computation times obtained with NumPy and Julia are added too for comparison - notice however that these last two implementations of the Riccati factorization are **not as easily embeddable** as the C code generated by prometeo and the hand-coded C implementation. All the benchmarks have been run on a Dell XPS-9360 equipped with an i7-7560U CPU running at 2.30 GHz (to avoid frequency fluctuations due to thermal throttling). 6 | 7 | .. image:: ../../../benchmarks/riccati_benchmark.png 8 | :alt: my-picture1 9 | -------------------------------------------------------------------------------- /prometeo/cpmt/pvec_blasfeo_wrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef PROMETEO_PRMT_VEC_BLASFEO_WRAPPER_H_ 2 | #define PROMETEO_PRMT_VEC_BLASFEO_WRAPPER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #ifdef __cplusplus 13 | extern "C" { 14 | #endif 15 | 16 | // (dummy) pvec wrapper to blasfeo_dvec 17 | struct pvec { 18 | struct blasfeo_dvec *bvec; 19 | }; 20 | 21 | struct pvec * c_pmt_create_pvec(int m); 22 | void c_pmt_assign_and_advance_blasfeo_dvec(int m, struct blasfeo_dvec **bvec); 23 | 24 | // BLAS API 25 | // void c_pmt_dgemm(struct pvec *A, struct pvec *B, struct pvec *C, struct pvec *D); 26 | // void c_pmt_dgead(double alpha, struct pvec *A, struct pvec *B); 27 | 28 | // auxiliary 29 | void c_pmt_pvec_fill(struct pvec *a, double fill_value); 30 | void c_pmt_pvec_set_el(struct pvec *a, int i, double value); 31 | double c_pmt_pvec_get_el(struct pvec *a, int i); 32 | void c_pmt_pvec_copy(struct pvec *a, struct pvec *b); 33 | void c_pmt_pvec_print(struct pvec *a); 34 | 35 | #ifdef __cplusplus 36 | } 37 | #endif 38 
def main() -> int:

    # number of repetitions for timing
    nrep : int = 10000

    A: pmat = pmat(nx, nx)
    A[0,0] = 0.8
    A[0,1] = 0.1
    A[1,0] = 0.3
    A[1,1] = 0.8

    B: pmat = pmat(nx, nu)
    B[0,0] = 1.0
    B[1,1] = 1.0

    Q: pmat = pmat(nx, nx)
    Q[0,0] = 1.0
    Q[1,1] = 1.0

    R: pmat = pmat(nu, nu)
    R[0,0] = 1.0
    R[1,1] = 1.0

    # work matrices
    RSQ: pmat = pmat(nxu, nxu)
    Lxx: pmat = pmat(nx, nx)
    M: pmat = pmat(nxu, nxu)
    w_nxu_nx: pmat = pmat(nxu, nx)
    BAt : pmat = pmat(nxu, nx)
    BA : pmat = pmat(nx, nxu)
    # BA = [B A] and BAt = BA'
    pmat_hcat(B, A, BA)
    pmat_tran(BA, BAt)

    # R goes into the leading nu-by-nu block of RSQ, Q into the trailing
    # nx-by-nx block
    RSQ[0:nu,0:nu] = R
    RSQ[nu:nu+nx,nu:nu+nx] = Q

    # array-type Riccati factorization
    for i in range(nrep):
        pmt_potrf(Q, Lxx)
        M[nu:nu+nx,nu:nu+nx] = Lxx
        # The stage loop uses its own index: the original reused `i` here and
        # shadowed the repetition counter of the enclosing loop.
        for k in range(1, N):
            pmt_trmm_rlnn(Lxx, BAt, w_nxu_nx)
            pmt_syrk_ln(w_nxu_nx, w_nxu_nx, RSQ, M)
            pmt_potrf(M, M)
            Lxx[0:nx,0:nx] = M[nu:nu+nx,nu:nu+nx]

    return 0
8 | 9 | manual installation 10 | ******************* 11 | 12 | If you want to install prometeo by building the sources on your local machine you can proceed as follows: 13 | 14 | - Run `git submodule update --init` to clone the submodules. 15 | - Run `make install_shared` from `<prometeo_root>/prometeo/cpmt` to compile and install the shared library associated with the C backend. Notice that the default installation path is `<prometeo_root>/prometeo/cpmt/install`. 16 | - You need Python 3.6 or later. 17 | - Optional: to keep things clean you can set up a virtual environment with `virtualenv --python=<path_to_python3.6> <path_to_new_virtualenv>`. 18 | - Run `pip install -e .` from `<prometeo_root>` to install the Python package. 19 | 20 | Finally, you can run the examples in `<prometeo_root>/examples` with `pmt <example_name>.py --cgen=<True/False>`, where the `--cgen` flag determines whether the code is executed by the Python interpreter or C code is generated, compiled and run. 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2021, Andrea Zanelli 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /prometeo/nonlinear/casadi_wrapper.c.in: -------------------------------------------------------------------------------- 1 | #include "casadi_wrapper_{{ fun_descriptor.name }}.h" 2 | 3 | void {{ fun_descriptor.name }}_eval(const double* inputs, double* outputs, 4 | int (*{{ fun_descriptor.name }})(const double**, double**, int*, double*, int)) { 5 | 6 | {% set arg_count = 0 %} 7 | {% set offset = 0 %} 8 | {% for arg in fun_descriptor.args %} 9 | const double *in_{{ arg_count }} = inputs + {{ offset }}; 10 | {% set arg_count = arg_count + 1 %} 11 | {% set offset = offset + arg['size'][0]*arg['size'][1] %} 12 | {% endfor %} 13 | 14 | {% set out_count = 0 %} 15 | {% set offset = 0 %} 16 | {% for out in fun_descriptor.outs %} 17 | const double *out_{{ out_count }} = inputs + {{ offset }}; 18 | {% set out_count = out_count + 1 %} 19 | {% set offset = offset + out['size'][0]*out['size'][1] %} 20 | {% endfor %} 21 | 22 | const double* casadi_arg[{{ fun_descriptor.args|length }}]; 23 | double* casadi_res[{{ fun_descriptor.outs|length }}]; 24 | 25 | {% for argnum in range(0, fun_descriptor.args|length) %} 26 | casadi_arg[{{ argnum }}] = in_{{ argnum }}; 27 | {% endfor %} 28 | 29 | {% for outnum in range(0, fun_descriptor.outs|length) %} 30 | casadi_res[{{ outnum }}] = out_{{ outnum }}; 31 | {% endfor %} 32 | 33 | int* iw = 0; 34 | double* w = 0; 35 | int 
mem = 0; 36 | 37 | {{ fun_descriptor.name }}(casadi_arg, casadi_res, iw, w, mem); 38 | } 39 | -------------------------------------------------------------------------------- /benchmarks/run_benchmark_julia.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import subprocess 3 | import json 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | 7 | NM = range(2,150,4) 8 | # NM = range(2,20,2) 9 | NREP_small = 10000 10 | NREP_medium = 100 11 | NREP_large = 10 12 | AVG_CPU_TIME = [] 13 | res_file = 'riccati_benchmark_julia.json' 14 | RUN = True 15 | UPDATE_res = True 16 | 17 | if not UPDATE_res: 18 | print('Warning: not updating result file!') 19 | 20 | if RUN: 21 | # get MKL 22 | subprocess.run(["julia", "-q", "-e", "import Pkg; Pkg.activate(\".\"); Pkg.instantiate()"], check=True) 23 | for i in range(len(NM)): 24 | print('running Riccati benchmark for case NM = {}'.format(NM[i])) 25 | code = "" 26 | if NM[i] < 30: 27 | NREP = NREP_small 28 | elif NM[i] < 100: 29 | NREP = NREP_medium 30 | else: 31 | NREP = NREP_large 32 | 33 | proc = subprocess.Popen([f"julia -q --project=. 
from os import path

from setuptools import setup, find_packages, dist


class BinaryDistribution(dist.Distribution):
    """Distribution that always reports extension modules.

    The package ships prebuilt shared libraries (see ``package_data`` below).
    """

    def has_ext_modules(self):
        return True


# read the contents of your README file
this_directory = path.abspath(path.dirname(__file__))
with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setup(name='prometeo-dsl',
      version='0.0.11',
      python_requires='>=3.8, <3.10',
      description='Python-to-C transpiler and domain specific language for embedded high-performance computing',
      url='http://github.com/zanellia/prometeo',
      author='Andrea Zanelli',
      long_description=long_description,
      long_description_content_type='text/markdown',
      # Must match the LICENSE file shipped with the sources (BSD 2-Clause);
      # the metadata used to say 'LGPL'.
      license='BSD-2-Clause',
      packages=find_packages(),
      entry_points={'console_scripts': ['pmt=prometeo.__main__:console_entry']},
      install_requires=[
          'astpretty',
          'strip_hints',
          'astunparse',
          'numpy',
          'scipy',
          'multipledispatch',
          'pyparsing',
          'casadi==3.5.5',
          'jinja2',
          'numexpr',
      ],
      package_data={'prometeo': [
          'lib/prometeo/libcpmt.so',
          'lib/blasfeo/libblasfeo.so',
          'include/prometeo/*',
          'include/blasfeo/*',
      ]},
      # include_package_data=True,
      zip_safe=False,
      distclass=BinaryDistribution,
      )
def c_pmt_set_blasfeo_dvec(v, data: POINTER(c_double)):
    """Pass the raw buffer *data* and the vector *v* to ``blasfeo_pack_dvec``."""
    bw.blasfeo_pack_dvec(v.m, data, byref(v), 0)


def c_pmt_set_blasfeo_dvec_el(value, v, ai):
    """Forward *value* and index *ai* of *v* to ``blasfeo_dvecin1``."""
    bw.blasfeo_dvecin1(value, byref(v), ai)


def c_pmt_get_blasfeo_dvec_el(v, ai):
    """Return what ``blasfeo_dvecex1`` yields for index *ai* of *v*."""
    return bw.blasfeo_dvecex1(byref(v), ai)


def c_pmt_set_pmt_blasfeo_dvec(data, v, ai):
    """Pass *data* and *v* to ``blasfeo_pack_dvec``; *ai* is accepted but unused."""
    bw.blasfeo_pack_dvec(v.m, data, byref(v), 0)


def c_pmt_create_blasfeo_dvec(m: int):
    """Create and return a ``blasfeo_dvec`` of length *m* with all entries 0.0."""
    # Backing memory of the size requested by BLASFEO.
    mem_size = bw.blasfeo_memsize_dvec(m)
    raw_mem = c_void_p()
    bw.v_zeros_align(byref(raw_mem), mem_size)
    raw_mem_chars = cast(raw_mem, c_char_p)

    # Buffer obtained from d_zeros, packed into the vector below.
    zeros = (POINTER(c_double) * 1)()
    bw.d_zeros(byref(zeros), m, 1)

    vec = blasfeo_dvec()
    vec_ref = byref(vec)
    bw.blasfeo_allocate_dvec(m, vec_ref)
    bw.blasfeo_create_dvec(m, vec_ref, raw_mem_chars)
    bw.blasfeo_pack_dvec(m, zeros, vec_ref, 0)
    # initialize to 0.0
    bw.blasfeo_dvecse(m, 0.0, vec_ref, 0)
    return vec


def c_pmt_vecpe(m, ipiv, a):
    """Call ``blasfeo_dvecpe`` with *m*, *ipiv* and the vector wrapped by *a*."""
    bw.blasfeo_dvecpe(m, ipiv, byref(a.blasfeo_dvec))


# auxiliary functions
def c_pmt_print_blasfeo_dvec(v):
    """Print the BLASFEO vector wrapped by *v* via ``blasfeo_print_dvec``."""
    wrapped = v.blasfeo_dvec
    bw.blasfeo_print_dvec(wrapped.m, byref(wrapped), 0)
def iter_all_ast(node):
    """Yield every AST node below *node*, depth first, parents before children.

    The original version iterated over the result of its own recursive call,
    but never yielded or returned anything, so it raised ``TypeError`` as soon
    as a node had an AST child. It is now a real generator; callers that want
    the former debug output can simply ``print`` each yielded node.

    Args:
        node: the ``ast.AST`` node whose descendants are enumerated; the node
            itself is not yielded.

    Yields:
        Each descendant ``ast.AST`` node, in field order.
    """
    for _field, value in ast.iter_fields(node):
        # A field holds either a list of children or a single child; anything
        # that is not an AST node (names, constants, None) is skipped.
        children = value if isinstance(value, list) else [value]
        for child in children:
            if isinstance(child, ast.AST):
                yield child
                yield from iter_all_ast(child)
import numpy as np
import subprocess
import json
import matplotlib
import matplotlib.pyplot as plt

# Problem sizes to benchmark.
NM = range(2,150,4)
# NM = range(2,20,2)
# Number of repetitions, chosen by problem size (see the loop below).
NREP_small = 10000
NREP_medium = 100
NREP_large = 10
# Collected [average CPU time, NM] pairs.
AVG_CPU_TIME = []
res_file = 'riccati_benchmark_numpy_blasfeo.json'
RUN = True
UPDATE_res = True

if not UPDATE_res:
    print('Warning: not updating result file!')

if RUN:
    for nm in NM:
        print('running Riccati benchmark for case NM = {}'.format(nm))
        if nm < 30:
            NREP = NREP_small
        elif nm < 100:
            NREP = NREP_medium
        else:
            NREP = NREP_large

        # Instantiate the benchmark script from its template.
        with open('test_riccati_numpy.py.in') as template:
            code = template.read()
        code = code.replace('NM', str(nm))
        code = code.replace('NREP', str(NREP))

        with open('test_riccati_numpy.py', 'w+') as bench_file:
            bench_file.write(code)

        # An argument list without a shell: combining shell=True with a list,
        # as before, only works by accident for single-element lists.
        proc = subprocess.Popen(['python', 'test_riccati_numpy.py'],
                                stdout=subprocess.PIPE)

        try:
            outs, errs = proc.communicate()
        except subprocess.TimeoutExpired:
            # The original named an undefined `TimeOutExpired`, which would
            # have raised NameError instead of handling the timeout. This
            # branch only runs if a timeout is passed to communicate().
            proc.kill()
            print('Exception raised at NM = {}'.format(nm))
            outs, errs = proc.communicate()

        # The benchmark script is expected to print a single number.
        AVG_CPU_TIME.append([float(outs.decode()), nm])

    # Only written when the benchmark actually ran, so an existing result file
    # is never replaced by an empty list.
    if UPDATE_res:
        with open(res_file, 'w+') as res:
            json.dump(AVG_CPU_TIME, res)
/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. prometeo documentation master file, created by 2 | sphinx-quickstart on Tue Aug 18 13:58:41 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to prometeo's documentation! 7 | ==================================== 8 | This is prometeo, an experimental modeling tool for embedded high-performance computing. prometeo provides a domain specific language (DSL) based on a subset of the Python language that allows one to conveniently write scientific computing programs in a high-level language (Python itself) that can be transpiled to high-performance self-contained C code easily deployable on embedded devices. 9 | 10 | features 11 | -------- 12 | 13 | 1. **Python compatible syntax :** prometeo is a DSL embedded into the Python language. prometeo programs can be executed from the Python interpreter. 14 | 15 | 2. **efficient :** prometeo programs transpile to high-performance C code. 16 | 3. **statically typed :** prometeo uses Python's native type hints to strictly enforce static typing. 17 | 4. **deterministic memory usage :** a specific program structure is required and enforced through static analysis. In this way prometeo transpiled programs have a guaranteed maximum heap usage. 18 | 5. **fast memory management :** thanks to its static analysis, prometeo can avoid allocating and garbage-collecting memory, resulting in faster and safer execution. 19 | 6. **self-contained and embeddable :** unlike other similar tools and languages, prometeo targets specifically embedded applications and programs written in prometeo transpile to self-contained C code that does not require linking against the Python run-time library. 20 | 21 | .. 
# Entry point of the nonlinear-function example: builds a small matrix and
# vector, wraps a CasADi expression involving them in a `pfun`, evaluates it
# once and prints the result. Returns 0.
def main() -> int:

    # m-by-n matrix; each entry is written twice, so the values that remain
    # are A[0,0] = 0.0 and A[0,1] = 1.0
    A: pmat = pmat(m, n)
    A[0,0] = 1.0
    A[0,1] = 0.0
    A[0,0] = 0.0
    A[0,1] = 1.0

    # vector of length n; after the repeated writes v[0] = 0.0 and v[1] = 1.0
    v: pvec = pvec(n)
    v[0] = 1.0
    v[0] = 0.0
    v[1] = 0.0
    v[1] = 1.0

    # TODO(andrea): how about using something like this?
    # @casadi
    # def test_fun([[n,n]], [[m,n]], [A, v])
    # x : ca = ca.SX.sym('x', 1, 1)
    # exp = ca.mtimes(A, v) + sin(x)
    # test_fun : pfun2 = pfun2('test_fun', [x], [exp], [A, v])

    # symbolic 2-by-1 CasADi variable used as the argument of the expression
    x : ca = ca.SX.sym('x', 2, 1)
    # the expression is passed as a string together with a dict mapping the
    # names used inside it to the objects defined above
    test_fun : pfun = pfun('test_fun', 'ca.mtimes(A, v) + ca.sin(x[0,0]) + ca.dot(x,x)', \
            {'A': A, 'v': v, 'x': x})

    # NOTE(review): `res` is m-by-m (1x1 here) while `x` is declared 2x1, and
    # `res` is used both as the argument and as the target - confirm that
    # this is the intended call convention of `pfun`
    res : pmat = pmat(m, m)
    res[0,0] = 0.1
    res = test_fun(res)

    print(res)

    # test_jac : pfun = pfun('test_jac', 'ca.jacobian( \
    # ca.mtimes(A, v) + sin(x), x)', \
    # {'A': A, 'v': v, 'x': x})

    # res = test_jac(1.0)

    # # C : my_class = my_class()

    # def myfun(a : int) -> int:
    # print('YO')
    # return 0

    # print(res)

    return 0
# author: Tommaso Sartor
#
# Timing script for the Riccati factorization implemented in ./riccati.jl.
# Takes two command-line arguments (number of masses, number of repetitions)
# and prints the average time per call of riccati_trf, in seconds.

include("./riccati.jl")
using .riccati

# main script
# both arguments are required; the message names them, in the order in which
# they are parsed just below
length(ARGS) != 2 && error("Usage: julia test_riccati.jl <nmass> <nrep>")

nmass = parse(Int, ARGS[1])
nrep = parse(Int, ARGS[2])

nx = 2*nmass
nu = nmass
N = 5

# data

# mass spring system
Ts = 0.5

Ac = zeros(nx, nx)
for ii in 1:nmass
    Ac[ii, nmass+ii] = 1.0
end
for ii in 1:nmass
    Ac[nmass+ii, ii] = -2.0
end
for ii in 1:nmass-1
    Ac[nmass+ii+1, ii] = 1.0
end
for ii in 1:nmass-1
    Ac[nmass+ii, ii+1] = 1.0
end

Bc = zeros(nx, nu)
for ii in 1:nu
    Bc[nmass+ii, ii] = 1.0
end

#display(Ac)
#println()
#display(Bc)
#println()

MM = [ Ts*Ac Ts*Bc; zeros(nu, nx+nu) ]

#display(MM)
#println()

# NOTE: the matrix exponential is disabled and MM is overwritten with random
# data, so Ac, Bc and Ts do not influence the matrices that are factorized
#MM = exp( MM )
MM = randn(nu+nx, nu+nx)

#display(MM)
#println()

A = MM[1:nx,1:nx]
B = MM[1:nx, nx+1:end]

#display(A)
#println()
#display(B)
#println()

Q = zeros(nx, nx)
for ii in 1:nx
    Q[ii, ii] = 1.0
end

R = zeros(nu, nu)
for ii in 1:nu
    R[ii, ii] = 2.0
end

#display(Q)
#println()
#display(R)
#println()

x0 = zeros(nx, 1)
x0[1] = 3.5
x0[2] = 3.5

#display(x0)
#println()


# work matrices

BAt = [transpose(B); transpose(A)]
#display(BAt)
#println()
RSQ = [R zeros(nu,nx); zeros(nx,nu) Q]

BAtL = zeros(nu+nx, nx)
L = zeros(nu+nx, nu+nx, N)
LN = zeros(nx, nx)
M = zeros(nu+nx, nu+nx)


# riccati recursion, square root algorithm

# untimed pass over the same calls (presumably a warm-up so that compilation
# is not included in the measurement)
for rep in 1:nrep
    riccati_trf(N, nx, nu, BAt, RSQ, L, LN, BAtL, M)
end

time_start = time_ns()

for rep in 1:nrep
    riccati_trf(N, nx, nu, BAt, RSQ, L, LN, BAtL, M)
end

time_end = time_ns()

# average time per call, converted from nanoseconds to seconds
println((time_end-time_start)/nrep*1e-9)



#display(L)
#println()
#display(LN)
#println()
14 | A: pmat = pmat(nx, nx) 15 | Ac11 : pmat = pmat(nm,nm) 16 | for i in range(nm): 17 | Ac11[i,i] = 1.0 18 | 19 | Ac12 : pmat = pmat(nm,nm) 20 | for i in range(nm): 21 | Ac12[i,i] = 1.0 22 | 23 | Ac21 : pmat = pmat(nm,nm) 24 | for i in range(nm): 25 | Ac21[i,i] = -2.0 26 | 27 | for i in range(nm-1): 28 | Ac21[i+1,i] = 1.0 29 | Ac21[i,i+1] = 1.0 30 | 31 | Ac22 : pmat = pmat(nm,nm) 32 | for i in range(nm): 33 | Ac22[i,i] = 1.0 34 | 35 | for i in range(nm): 36 | for j in range(nm): 37 | A[i,j] = Ac11[i,j] 38 | 39 | for i in range(nm): 40 | for j in range(nm): 41 | A[i,nm+j] = Ac12[i,j] 42 | 43 | for i in range(nm): 44 | for j in range(nm): 45 | A[nm+i,j] = Ac21[i,j] 46 | 47 | for i in range(nm): 48 | for j in range(nm): 49 | A[nm+i,nm+j] = Ac22[i,j] 50 | 51 | 52 | tmp : float = 0.0 53 | for i in range(nx): 54 | tmp = A[i,i] 55 | tmp = tmp + 1.0 56 | A[i,i] = tmp 57 | 58 | B: pmat = pmat(nx, nu) 59 | 60 | for i in range(nu): 61 | B[nm+i,i] = 1.0 62 | 63 | Q: pmat = pmat(nx, nx) 64 | for i in range(nx): 65 | Q[i,i] = 1.0 66 | 67 | R: pmat = pmat(nu, nu) 68 | for i in range(nu): 69 | R[i,i] = 1.0 70 | 71 | RSQ: pmat = pmat(nxu, nxu) 72 | Lxx: pmat = pmat(nx, nx) 73 | M: pmat = pmat(nxu, nxu) 74 | w_nxu_nx: pmat = pmat(nxu, nx) 75 | BAt : pmat = pmat(nxu, nx) 76 | BA : pmat = pmat(nx, nxu) 77 | pmat_hcat(B, A, BA) 78 | pmat_tran(BA, BAt) 79 | 80 | RSQ[0:nu,0:nu] = R 81 | RSQ[nu:nu+nx,nu:nu+nx] = Q 82 | 83 | # array-type Riccati factorization 84 | for i in range(nrep): 85 | pmt_potrf(Q, Lxx) 86 | M[nu:nu+nx,nu:nu+nx] = Lxx 87 | for i in range(1, N): 88 | pmt_trmm_rlnn(Lxx, BAt, w_nxu_nx) 89 | pmt_syrk_ln(w_nxu_nx, w_nxu_nx, RSQ, M) 90 | pmt_potrf(M, M) 91 | Lxx[0:nx,0:nx] = M[nu:nu+nx,nu:nu+nx] 92 | 93 | return 0 94 | -------------------------------------------------------------------------------- /examples/riccati_example/riccati_mass_spring.py: -------------------------------------------------------------------------------- 1 | from prometeo import * 2 | 3 | nm: dims = 
# Entry point of the mass-spring Riccati example: builds the matrices A, B, Q
# and R from the module-level dimensions nm, nx, nu and nxu, then repeats the
# array-type Riccati factorization `nrep` times. Returns 0.
def main() -> int:
    # number of repetitions for timing
    nrep : int = 10000

    # set up dynamics TODO(needs discretization!)
    A: pmat = pmat(nx, nx)
    # Ac11, Ac12 and Ac22 get 1.0 on their diagonal
    Ac11 : pmat = pmat(nm,nm)
    for i in range(nm):
        Ac11[i,i] = 1.0

    Ac12 : pmat = pmat(nm,nm)
    for i in range(nm):
        Ac12[i,i] = 1.0

    # Ac21 gets -2.0 on the diagonal and 1.0 on the first sub- and
    # super-diagonal
    Ac21 : pmat = pmat(nm,nm)
    for i in range(nm):
        Ac21[i,i] = -2.0

    for i in range(nm-1):
        Ac21[i+1,i] = 1.0
        Ac21[i,i+1] = 1.0

    Ac22 : pmat = pmat(nm,nm)
    for i in range(nm):
        Ac22[i,i] = 1.0

    # copy the four nm-by-nm blocks into A, element by element:
    # top-left Ac11, top-right Ac12, bottom-left Ac21, bottom-right Ac22
    for i in range(nm):
        for j in range(nm):
            A[i,j] = Ac11[i,j]

    for i in range(nm):
        for j in range(nm):
            A[i,nm+j] = Ac12[i,j]

    for i in range(nm):
        for j in range(nm):
            A[nm+i,j] = Ac21[i,j]

    for i in range(nm):
        for j in range(nm):
            A[nm+i,nm+j] = Ac22[i,j]


    # add 1.0 to every diagonal entry of A, going through a scalar temporary
    tmp : float = 0.0
    for i in range(nx):
        tmp = A[i,i]
        tmp = tmp + 1.0
        A[i,i] = tmp

    B: pmat = pmat(nx, nu)

    # 1.0 at positions (nm+i, i), i.e. on the diagonal of the lower block of B
    for i in range(nu):
        B[nm+i,i] = 1.0

    # Q and R get 1.0 on their diagonal
    Q: pmat = pmat(nx, nx)
    for i in range(nx):
        Q[i,i] = 1.0

    R: pmat = pmat(nu, nu)
    for i in range(nu):
        R[i,i] = 1.0

    # work matrices for the factorization
    RSQ: pmat = pmat(nxu, nxu)
    Lxx: pmat = pmat(nx, nx)
    M: pmat = pmat(nxu, nxu)
    w_nxu_nx: pmat = pmat(nxu, nx)
    BAt : pmat = pmat(nxu, nx)
    BA : pmat = pmat(nx, nxu)
    # BA receives B and A side by side and BAt its transpose
    # (going by the helper names - confirm against pmat_hcat / pmat_tran)
    pmat_hcat(B, A, BA)
    pmat_tran(BA, BAt)

    # R and Q are placed on the diagonal blocks of RSQ
    RSQ[0:nu,0:nu] = R
    RSQ[nu:nu+nx,nu:nu+nx] = Q

    # array-type Riccati factorization
    for i in range(nrep):
        pmt_potrf(Q, Lxx)
        M[nu:nu+nx,nu:nu+nx] = Lxx
        # note: the inner loop reuses the name `i` of the repetition loop
        for i in range(1, N):
            pmt_trmm_rlnn(Lxx, BAt, w_nxu_nx)
            pmt_syrk_ln(w_nxu_nx, w_nxu_nx, RSQ, M)
            # M is passed as both input and output of the factorization
            pmt_potrf(M, M)
            # the lower-right nx-by-nx block of M is fed to the next stage
            Lxx[0:nx,0:nx] = M[nu:nu+nx,nu:nu+nx]

    return 0
# Container for the stage matrices of the example problem (A, B, Q, R) and
# for the matrices P produced by the backward recursion in `factorize`.
class qp_data:
    def __init__(self) -> None:
        # one list per matrix family, each created from the module-level
        # `sizes` description
        self.A: List = plist('pmat', sizes)
        self.B: List = plist('pmat', sizes)
        self.Q: List = plist('pmat', sizes)
        self.R: List = plist('pmat', sizes)
        self.P: List = plist('pmat', sizes)

    # Backward recursion over the stages: starts from P[N-1] = Q[N-1] and
    # fills P[N-2] ... P[0]; prints the factor L of every stage.
    def factorize(self) -> None:
        # work matrices
        M: pmat = pmat(nxu, nxu)
        Mxx: pmat = pmat(nx, nx)
        L: pmat = pmat(nxu, nxu)
        Q: pmat = pmat(nx, nx)
        R: pmat = pmat(nu, nu)
        BA: pmat = pmat(nx, nxu)
        BAtP: pmat = pmat(nxu, nx)
        # terminal condition: P[N-1] is a copy of Q[N-1]
        pmat_copy(self.Q[N-1], self.P[N-1])

        for i in range(1, N):
            # BA is built from B[N-i] and A[N-i]; BAtP is zeroed and then
            # used as both the C and D argument of the transposed product
            # (presumably BAtP = BA^T * P[N-i] - confirm against pmt_gemm_tn)
            pmat_hcat(self.B[N-i], self.A[N-i], BA)
            pmat_fill(BAtP, 0.0)
            pmt_gemm_tn(BA, self.P[N-i], BAtP, BAtP)

            # M starts as a zero matrix with R and Q on its diagonal blocks
            pmat_copy(self.Q[N-i], Q)
            pmat_copy(self.R[N-i], R)
            pmat_fill(M, 0.0)
            M[0:nu,0:nu] = R[0:nu,0:nu]
            M[nu:nu+nx,nu:nu+nx] = Q[0:nx,0:nx]

            # accumulate the product of BAtP and BA into M, then factorize M
            # into L
            pmt_gemm_nn(BAtP, BA, M, M)
            pmat_fill(L, 0.0)
            pmt_potrf(M, L)
            pmat_print(L)

            # lower-right nx-by-nx block of the factor
            Mxx[0:nx, 0:nx] = L[nu:nu+nx, nu:nu+nx]

            # P[N-i-1] is zeroed and then receives the product of Mxx with
            # its own transpose
            pmat_fill(self.P[N-i-1], 0.0)
            pmt_gemm_nt(Mxx, Mxx, self.P[N-i-1], self.P[N-i-1])
            # pmat_print(self.P[N-i-1])
range(N): 78 | qp.A[i] = A 79 | 80 | for i in range(N): 81 | qp.B[i] = B 82 | 83 | for i in range(N): 84 | qp.Q[i] = Q 85 | 86 | for i in range(N): 87 | qp.R[i] = R 88 | 89 | qp.factorize() 90 | 91 | return 0 92 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | group: travis_latest 3 | python: 4 | - "3.8" 5 | - "3.9" 6 | cache: pip 7 | jobs: 8 | include: 9 | - name: Linux Build + examples 10 | os: linux 11 | dist: xenial 12 | services: 13 | - docker 14 | env: 15 | - CC="gcc" 16 | - name: macOS Build + examples 17 | os: osx 18 | osx_image: xcode13.3 19 | language: shell 20 | 21 | install: 22 | - if [ "$TRAVIS_OS_NAME" = "osx" ]; then pushd prometeo/cpmt; make install_shared 23 | CC=clang TARGET=GENERIC OPT_LD_FLAGS=-Wl,-undefined,dynamic_lookup; popd; export 24 | DYLD_LIBRARY_PATH=$(pwd)/prometeo/lib/blasfeo:$DYLD_LIBRARY_PATH; export DYLD_LIBRARY_PATH=$(pwd)/prometeo/lib/prometeo:$DYLD_LIBRARY_PATH; 25 | export LD_LIBRARY_PATH=$(pwd)/prometeo/lib/blasfeo:$LD_LIBRARY_PATH; export LD_LIBRARY_PATH=$(pwd)/prometeo/lib/prometeo:$LD_LIBRARY_PATH; 26 | else pushd prometeo/cpmt; make install_shared; popd; export LD_LIBRARY_PATH=$(pwd)/prometeo/lib/blasfeo:$LD_LIBRARY_PATH; 27 | export LD_LIBRARY_PATH=$(pwd)/prometeo/lib/prometeo:$LD_LIBRARY_PATH; fi 28 | - python3 -m pip install . 
# Small example class with one attribute of each basic kind (int, float,
# pmat) and one method taking fixed-size pmat arguments.
class p_class:
    # class-level attributes with their initial values
    attr_1: int = 1
    attr_2: float = 3.0
    attr_3: pmat = pmat(10, 10)

    # Multiplies A and B, stores the result under the name C and prints it.
    # NOTE(review): when run by the Python interpreter the assignment rebinds
    # the local name C, so the caller's C is not modified - confirm whether
    # the transpiled code behaves the same way.
    def method_2(self, A: pmat[2,2], B: pmat[2,2], C: pmat[2,2]) -> None:
        C = A * B
        pmat_print(C)
        return
10) 19 | return 20 | 21 | def main() -> None: 22 | 23 | n_list: List = plist(int, 10) 24 | n_list[0] = 1 25 | 26 | test_class: p_class = p_class() 27 | test_class.attr_1 = 2 28 | 29 | j: int = 0 30 | for i in range(10): 31 | j = j + 1 32 | 33 | while j > 0: 34 | j = j - 1 35 | 36 | A: pmat = pmat(n, n) 37 | A[0,2] = -2.0 38 | 39 | for i in range(2): 40 | A[0,i] = A[0,i] 41 | 42 | pmat_fill(A, 1.0) 43 | 44 | B: pmat = pmat(n, n) 45 | for i in range(2): 46 | B[0,i] = A[0,i] 47 | pmat_fill(B, 2.0) 48 | 49 | C: pmat = pmat(n, n) 50 | 51 | test_class.method_2(A, B, C) 52 | 53 | pmat_list: List = plist(pmat, sizes) 54 | pmat_list[0] = A 55 | 56 | C = A * B 57 | pmat_print(C) 58 | C = A + B 59 | pmat_print(C) 60 | C = A - B 61 | pmat_print(C) 62 | 63 | function1(A, B, C) 64 | function1(pmat_list[0], B, C) 65 | 66 | pmat_fill(A, 0.0) 67 | for i in range(10): 68 | A[i,i] = 1.0 69 | 70 | pmat_print(A) 71 | 72 | a : pvec = pvec(10) 73 | a[1] = 3.0 74 | b : pvec = pvec(3) 75 | b[0] = a[1] 76 | b[1] = A[0, 2] 77 | A[0,2] = a[0] 78 | 79 | el : float = 1.0 80 | el = a[1] 81 | el = A[1, 1] 82 | pvec_print(a) 83 | pvec_print(b) 84 | 85 | c : pvec = pvec(10) 86 | c = A * a 87 | pvec_print(c) 88 | 89 | # # test LU solve 90 | # ipiv: List = plist(int, 2) 91 | fact : pmat = pmat(2, 2) 92 | M : pmat = pmat(2,2) 93 | # pmt_getrf(M, fact, ipiv) 94 | res: pvec = pvec(2) 95 | # rhs: pvec = pvec(2) 96 | # rhs[0] = 1.0 97 | # rhs[1] = -3.0 98 | # pmt_getrsv(fact, ipiv, rhs) 99 | 100 | # test Cholesky solve 101 | M[0,0] = 1.0 102 | M[0,1] = 0.1 103 | M[1,0] = 0.1 104 | M[1,1] = 1.0 105 | pmt_potrf(M, fact) 106 | # pmt_potrsv(fact, rhs) 107 | 108 | # UNCOMMENT THESE LINES TO EXECUTE 109 | # if __name__ == "__main__": 110 | # execute only if run as a script 111 | # main() 112 | -------------------------------------------------------------------------------- /docs/source/blas_api/blas_api.rst: -------------------------------------------------------------------------------- 1 | BLAS/LAPACK API 2 | 
==================================== 3 | 4 | Below a description of prometeo's BLAS/LAPACK API can be found: 5 | 6 | LEVEL 1 BLAS 7 | ############ 8 | 9 | LEVEL 2 BLAS 10 | ############ 11 | 12 | * General matrix-vector multiplication (GEMV) 13 | 14 | .. math:: 15 | 16 | 17 | z \leftarrow \beta \cdot y + \alpha \cdot \text{op}(A) x 18 | 19 | .. code-block:: python 20 | 21 | pmt_gemv(A[.T], x, [y], z, [alpha=1.0], [beta=0.0]) 22 | 23 | * Solve linear system with (lower or upper) triangular matrix coefficient (TRSV) 24 | 25 | .. math:: 26 | 27 | 28 | \text{op}(A)\,x = b 29 | 30 | 31 | .. code-block:: python 32 | 33 | pmt_trsv(A[.T], b, [lower=True]) 34 | 35 | * Matrix-vector multiplication with (lower or upper) triangular matrix coefficient (TRMV) 36 | 37 | .. math:: 38 | 39 | 40 | z \leftarrow \text{op}(A)\,x 41 | 42 | 43 | .. code-block:: python 44 | 45 | pmt_trmv(A[.T], x, z, [lower=True]) 46 | 47 | LEVEL 3 BLAS 48 | ############ 49 | 50 | * General matrix-matrix multiplication (GEMM) 51 | 52 | .. math:: 53 | 54 | D \leftarrow \beta \cdot C + \alpha \cdot \text{op}(A) \, \text{op}(B) 55 | 56 | .. code-block:: python 57 | 58 | pmt_gemm(A[.T], B[.T], [C], D, [alpha=1.0], [beta=0.0]) 59 | 60 | * Symmetric rank :math:`k` update (SYRK) 61 | 62 | .. math:: 63 | 64 | D \leftarrow \beta \cdot C + \alpha \cdot \text{op}(A) \,\text{op}(B) 65 | 66 | with :math:`C` and :math:`D` lower triangular. 67 | 68 | .. code-block:: python 69 | 70 | pmt_syrk(A[.T], B[.T], [C], D, [alpha=1.0], [beta=0.0]) 71 | 72 | * Triangular matrix-matrix multiplication (TRMM) 73 | 74 | .. math:: 75 | 76 | D \leftarrow \alpha \cdot B\, A^{\top} 77 | 78 | with :math:`B` upper triangular or 79 | 80 | 81 | .. math:: 82 | 83 | D \leftarrow \alpha \cdot A\, B 84 | 85 | with :math:`A` lower triangular. 86 | 87 | .. code-block:: python 88 | 89 | pmt_trmm(A[.T], B, D, [alpha=1.0], [beta=0.0]) 90 | 91 | LAPACK 92 | ####### 93 | 94 | 95 | * Cholesky factorization (POTRF) 96 | 97 | .. 
math:: 98 | 99 | C = D\,D^{\top} 100 | 101 | with :math:`D` lower triangular and :math:`C` symmetric and positive definite 102 | 103 | .. code-block:: python 104 | 105 | pmt_potrf(C, D) 106 | 107 | * LU factorization (GETRF) 108 | 109 | .. math:: 110 | 111 | C = L\,P\,U 112 | 113 | .. code-block:: python 114 | 115 | pmt_getrf(C, D) 116 | 117 | * QR factorization (GEQRF) 118 | 119 | .. math:: 120 | 121 | C = Q\,R 122 | 123 | .. code-block:: python 124 | 125 | pmt_geqrf(C, D) 126 | -------------------------------------------------------------------------------- /prometeo/cpmt/pmat_blasfeo_wrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef PROMETEO_PRMT_MAT_BLASFEO_WRAPPER_H_ 2 | #define PROMETEO_PRMT_MAT_BLASFEO_WRAPPER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include "pmat_blasfeo_wrapper.h" 12 | #include "pvec_blasfeo_wrapper.h" 13 | 14 | #ifdef __cplusplus 15 | extern "C" { 16 | #endif 17 | 18 | // (dummy) pmat wrapper to blasfeo_dmat 19 | struct pmat { 20 | struct blasfeo_dmat *bmat; 21 | }; 22 | 23 | struct pmat * c_pmt_create_pmat(int m, int n); 24 | void c_pmt_assign_and_advance_blasfeo_dmat(int m, int n, struct blasfeo_dmat **bmat); 25 | 26 | // BLAS API 27 | void c_pmt_gemm_nn(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D); 28 | struct pmat * _c_pmt_gemm_nn(struct pmat *A, struct pmat *B); 29 | void c_pmt_gemm_tn(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D); 30 | void c_pmt_gemm_nt(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D); 31 | void c_pmt_trmm_rlnn(struct pmat *A, struct pmat *B, struct pmat *D); 32 | void c_pmt_syrk_ln(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D); 33 | void c_pmt_getrf(struct pmat *A, struct pmat *fact, int *ipiv); 34 | void c_pmt_potrf(struct pmat *A, struct pmat *fact); 35 | void c_pmt_getrsm(struct pmat *fact, int *ipiv, struct pmat *rhs); 36 | 
struct pmat * _c_pmt_getrsm(struct pmat *A, struct pmat *rhs); 37 | void c_pmt_getrsv(struct pmat *fact, int *ipiv, struct pvec *rhs); 38 | void c_pmt_potrsm(struct pmat *fact, struct pmat *rhs); 39 | void c_pmt_potrsv(struct pmat *fact, struct pvec *rhs); 40 | void c_pmt_gead(double alpha, struct pmat *A, struct pmat *B); 41 | struct pmat * _c_pmt_gead(double alpha, struct pmat *A, struct pmat *B); 42 | 43 | void c_pmt_gemv_n(struct pmat *A, struct pvec *b, struct pvec *c, struct pvec *d); 44 | 45 | // auxiliary 46 | void c_pmt_pmat_fill(struct pmat *A, double fill_value); 47 | void c_pmt_pmat_set_el(struct pmat *A, int i, int j, double value); 48 | double c_pmt_pmat_get_el(struct pmat *A, int i, int j); 49 | void c_pmt_pmat_copy(struct pmat *A, struct pmat *B); 50 | struct pmat * _c_pmt_pmat_copy(struct pmat *A, struct pmat *B); 51 | void c_pmt_pmat_tran(struct pmat *A, struct pmat *B); 52 | struct pmat * _c_pmt_pmat_tran(struct pmat *A); 53 | void c_pmt_gecp(int m, int n, struct pmat *A, int ai, int aj, struct pmat *B, int bi, int bj); 54 | void c_pmt_pmat_vcat(struct pmat *A, struct pmat *B, struct pmat *res); 55 | void c_pmt_pmat_hcat(struct pmat *A, struct pmat *B, struct pmat *res); 56 | void c_pmt_pmat_print(struct pmat *A); 57 | 58 | #ifdef __cplusplus 59 | } 60 | #endif 61 | 62 | #endif // PROMETEO_PRMT_MAT_BLASFEO_WRAPPER_H_ 63 | 64 | -------------------------------------------------------------------------------- /prometeo/cpmt/pvec_blasfeo_wrapper.c: -------------------------------------------------------------------------------- 1 | #include "pvec_blasfeo_wrapper.h" 2 | #include "pmt_heap.h" 3 | #include "pmt_aux.h" 4 | #include 5 | #include 6 | 7 | extern void make_int_multiple_of(int multiple_of, int * n); 8 | 9 | struct pvec * c_pmt_create_pvec(int m) { 10 | // assign current address of global heap to pvec pointer 11 | struct pvec *pvec = (struct pvec *) ___c_pmt_8_heap; 12 | void *pvec_address = ___c_pmt_8_heap; 13 | 14 | // advance global heap 
address 15 | ___c_pmt_8_heap += sizeof(struct pvec); 16 | 17 | 18 | // create (zeroed) blasfeo_dvec and advance global heap 19 | c_pmt_assign_and_advance_blasfeo_dvec(m, &(pvec->bvec)); 20 | 21 | return (struct pvec *)(pvec_address); 22 | } 23 | 24 | 25 | void c_pmt_assign_and_advance_blasfeo_dvec(int m, struct blasfeo_dvec **bvec) { 26 | // assign current address of global heap to blasfeo dvec pointer 27 | assert((size_t) ___c_pmt_8_heap % 8 == 0 && "pointer not 8-byte aligned!"); 28 | *bvec = (struct blasfeo_dvec *) ___c_pmt_8_heap; 29 | // 30 | // advance global heap address 31 | ___c_pmt_8_heap += sizeof(struct blasfeo_dvec); 32 | 33 | // assign current address of global heap to memory in blasfeo dvec 34 | char *pmem_ptr = (char *)___c_pmt_64_heap; 35 | // align_char_to(64, &pmem_ptr); 36 | ___c_pmt_64_heap = pmem_ptr; 37 | assert((size_t) ___c_pmt_64_heap % 64 == 0 && "dvec not 64-byte aligned!"); 38 | blasfeo_create_dvec(m, *bvec, ___c_pmt_64_heap); 39 | 40 | // advance global heap address 41 | int memsize = (*bvec)->memsize; 42 | make_int_multiple_of(64, &memsize); 43 | ___c_pmt_64_heap += memsize; 44 | 45 | // zero allocated memory 46 | int i; 47 | double *pa = (*bvec)->pa; 48 | int size = (*bvec)->memsize; 49 | for(i=0; ibvec->m; 59 | 60 | for(int i = 0; i < m; i++) 61 | blasfeo_dvecin1(fill_value, a->bvec, i); 62 | } 63 | 64 | void c_pmt_pvec_set_el(struct pvec *a, int i, double fill_value) { 65 | 66 | blasfeo_dvecin1(fill_value, a->bvec, i); 67 | } 68 | 69 | double c_pmt_pvec_get_el(struct pvec *a, int i) { 70 | 71 | blasfeo_dvecex1(a->bvec, i); 72 | } 73 | 74 | void c_pmt_pvec_copy(struct pvec *a, struct pvec *b) { 75 | int m = a->bvec->m; 76 | double value; 77 | 78 | for(int i = 0; i < m; i++) { 79 | value = blasfeo_dvecex1(a->bvec, i); 80 | blasfeo_dvecin1(value, b->bvec, i); 81 | } 82 | } 83 | 84 | void c_pmt_pvec_print(struct pvec *a) { 85 | int m = a->bvec->m; 86 | 87 | blasfeo_print_dvec(m, a->bvec, 0); 88 | } 89 | 
-------------------------------------------------------------------------------- /experimental/experimental_examples/test_blasfeo_ctypes.py: -------------------------------------------------------------------------------- 1 | from ctypes import * 2 | from os import * 3 | from blasfeo_wrapper import * 4 | 5 | print('\n') 6 | print('###############################################################') 7 | print(' Testing BLASFEO wrapper: dmat creation and call to dgemm_nt') 8 | print('###############################################################') 9 | print('\n') 10 | 11 | n = 5 12 | 13 | size_strmat = 3*bw.blasfeo_memsize_dmat(n, n) 14 | memory_strmat = c_void_p() 15 | bw.v_zeros_align(byref(memory_strmat), size_strmat) 16 | 17 | ptr_memory_strmat = cast(memory_strmat, c_char_p) 18 | 19 | A = (POINTER(c_double) * 1)() 20 | bw.d_zeros(byref(A), n, n) 21 | for i in range(n*n): 22 | A[0][i] = i 23 | 24 | sA = blasfeo_dmat() 25 | 26 | bw.blasfeo_allocate_dmat(n, n, byref(sA)) 27 | bw.blasfeo_create_dmat(n, n, byref(sA), ptr_memory_strmat) 28 | bw.blasfeo_pack_dmat(n, n, A[0], n, byref(sA), 0, 0) 29 | print('content of sA:\n') 30 | bw.blasfeo_print_dmat(n, n, byref(sA), 0, 0) 31 | 32 | ptr_memory_strmat = cast(ptr_memory_strmat, c_void_p) 33 | ptr_memory_strmat.value = ptr_memory_strmat.value + sA.memsize 34 | ptr_memory_strmat = cast(ptr_memory_strmat, c_char_p) 35 | 36 | D = (POINTER(c_double) * 1)() 37 | bw.d_zeros(byref(D), n, n) 38 | for i in range(n): 39 | D[0][i*(n + 1)] = 1.0 40 | 41 | sD = blasfeo_dmat() 42 | 43 | bw.blasfeo_allocate_dmat(n, n, byref(sD)) 44 | bw.blasfeo_create_dmat(n, n, byref(sD), ptr_memory_strmat) 45 | bw.blasfeo_pack_dmat(n, n, D[0], n, byref(sD), 0, 0); 46 | print('content of sD:\n') 47 | bw.blasfeo_print_dmat(n, n, byref(sD), 0, 0) 48 | 49 | ptr_memory_strmat = cast(ptr_memory_strmat, c_void_p) 50 | ptr_memory_strmat.value = ptr_memory_strmat.value + sD.memsize 51 | ptr_memory_strmat = cast(ptr_memory_strmat, c_char_p) 52 | 53 | B = 
(POINTER(c_double) * 1)() 54 | bw.d_zeros(byref(B), n, n) 55 | for i in range(n): 56 | B[0][i*(n + 1)] = 1.0 57 | 58 | sB = blasfeo_dmat() 59 | 60 | import pdb; pdb.set_trace() 61 | bw.blasfeo_allocate_dmat(n, n, byref(sB)) 62 | bw.blasfeo_create_dmat(n, n, byref(sB), ptr_memory_strmat) 63 | bw.blasfeo_pack_dmat(n, n, B[0], n, byref(sB), 0, 0); 64 | print('content of sB:\n') 65 | bw.blasfeo_print_dmat(n, n, byref(sB), 0, 0) 66 | 67 | ptr_memory_strmat = cast(ptr_memory_strmat, c_void_p) 68 | ptr_memory_strmat.value = ptr_memory_strmat.value + sB.memsize 69 | ptr_memory_strmat = cast(ptr_memory_strmat, c_char_p) 70 | 71 | # This call would require ctypes handling from the cgen engine 72 | # bw.blasfeo_dgemm_nt(n, n, n, 1.0, byref(sA), 0, 0, byref(sA), 0, 0, 1, byref(sB), 0, 0, byref(sD), 0, 0); 73 | 74 | # This (wrapped) call should make it easy to code-gen calls to blasfeo 75 | blasfeo_dgemm_nt(n, n, n, 1.0, sA, 0, 0, sA, 0, 0, 1, sB, 0, 0, sD, 0, 0) 76 | print('B + A*A (blasfeo_dgemm_nt):\n') 77 | bw.blasfeo_print_dmat(n, n, byref(sD), 0, 0) 78 | 79 | -------------------------------------------------------------------------------- /examples/riccati_example/riccati_mass_spring_2.py: -------------------------------------------------------------------------------- 1 | from prometeo import * 2 | 3 | nm: dims = 4 4 | nx: dims = 2*nm 5 | sizes: dimv = [[8,8], [8,8], [8,8], [8,8], [8,8]] 6 | nu: dims = nm 7 | nxu: dims = nx + nu 8 | N: dims = 5 9 | 10 | class qp_data: 11 | def __init__(self) -> None: 12 | self.A: List = plist(pmat, sizes) 13 | self.B: List = plist(pmat, sizes) 14 | self.Q: List = plist(pmat, sizes) 15 | self.R: List = plist(pmat, sizes) 16 | self.P: List = plist(pmat, sizes) 17 | return None 18 | 19 | def factorize(self) -> None: 20 | M: pmat = pmat(nxu, nxu) 21 | Mxx: pmat = pmat(nx, nx) 22 | L: pmat = pmat(nxu, nxu) 23 | Q: pmat = pmat(nx, nx) 24 | R: pmat = pmat(nu, nu) 25 | BA: pmat = pmat(nx, nxu) 26 | BAtP: pmat = pmat(nxu, nx) 27 | 
pmat_copy(self.Q[N-1], self.P[N-1]) 28 | 29 | pmat_hcat(self.B[N-1], self.A[N-1], BA) 30 | pmat_copy(self.Q[N-1], Q) 31 | pmat_copy(self.R[N-1], R) 32 | for i in range(1, N): 33 | pmat_fill(BAtP, 0.0) 34 | pmt_gemm_tn(BA, self.P[N-i], BAtP, BAtP) 35 | 36 | pmat_fill(M, 0.0) 37 | M[0:nu,0:nu] = R[0:nu,0:nu] 38 | M[nu:nu+nx,nu:nu+nx] = Q[0:nx,0:nx] 39 | 40 | pmt_gemm_nn(BAtP, BA, M, M) 41 | pmat_fill(L, 0.0) 42 | pmt_potrf(M, L) 43 | 44 | Mxx[0:nx, 0:nx] = L[nu:nu+nx, nu:nu+nx] 45 | 46 | pmt_gemm_nt(Mxx, Mxx, self.P[N-i-1], self.P[N-i-1]) 47 | pmat_print(self.P[N-i-1]) 48 | 49 | return None 50 | 51 | def main() -> int: 52 | 53 | A: pmat = pmat(nx, nx) 54 | Ac11 : pmat = pmat(nm,nm) 55 | Ac12 : pmat = pmat(nm,nm) 56 | for i in range(nm): 57 | Ac12[i,i] = 1.0 58 | 59 | Ac21 : pmat = pmat(nm,nm) 60 | for i in range(nm): 61 | Ac21[i,i] = -2.0 62 | 63 | for i in range(nm-1): 64 | Ac21[i+1,i] = 1.0 65 | Ac21[i,i+1] = 1.0 66 | 67 | Ac22 : pmat = pmat(nm,nm) 68 | 69 | for i in range(nm): 70 | for j in range(nm): 71 | A[i,j] = Ac11[i,j] 72 | 73 | for i in range(nm): 74 | for j in range(nm): 75 | A[i,nm+j] = Ac12[i,j] 76 | 77 | for i in range(nm): 78 | for j in range(nm): 79 | A[nm+i,j] = Ac21[i,j] 80 | 81 | for i in range(nm): 82 | for j in range(nm): 83 | A[nm+i,nm+j] = Ac22[i,j] 84 | 85 | tmp : float = 0.0 86 | for i in range(nx): 87 | tmp = A[i,i] 88 | tmp = tmp + 1.0 89 | A[i,i] = tmp 90 | 91 | B: pmat = pmat(nx, nu) 92 | 93 | for i in range(nu): 94 | B[nm+i,i] = 1.0 95 | 96 | Q: pmat = pmat(nx, nx) 97 | for i in range(nx): 98 | Q[i,i] = 1.0 99 | 100 | R: pmat = pmat(nu, nu) 101 | for i in range(nu): 102 | R[i,i] = 1.0 103 | 104 | qp : qp_data = qp_data() 105 | 106 | for i in range(N): 107 | qp.A[i] = A 108 | 109 | for i in range(N): 110 | qp.B[i] = B 111 | 112 | for i in range(N): 113 | qp.Q[i] = Q 114 | 115 | for i in range(N): 116 | qp.R[i] = R 117 | 118 | qp.factorize() 119 | 120 | return 0 121 | 
-------------------------------------------------------------------------------- /prometeo/cgen/string_repr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Part of the astor library for Python AST manipulation. 4 | 5 | License: 3-clause BSD 6 | 7 | Copyright (c) 2015 Patrick Maupin 8 | 9 | Pretty-print strings for the decompiler 10 | 11 | We either return the repr() of the string, 12 | or try to format it as a triple-quoted string. 13 | 14 | This is a lot harder than you would think. 15 | 16 | This has lots of Python 2 / Python 3 ugliness. 17 | 18 | """ 19 | 20 | import re 21 | 22 | try: 23 | special_unicode = unicode 24 | except NameError: 25 | class special_unicode(object): 26 | pass 27 | 28 | try: 29 | basestring = basestring 30 | except NameError: 31 | basestring = str 32 | 33 | 34 | def _properly_indented(s, line_indent): 35 | mylist = s.split('\n')[1:] 36 | mylist = [x.rstrip() for x in mylist] 37 | mylist = [x for x in mylist if x] 38 | if not s: 39 | return False 40 | counts = [(len(x) - len(x.lstrip())) for x in mylist] 41 | return counts and min(counts) >= line_indent 42 | 43 | 44 | mysplit = re.compile(r'(\\|\"\"\"|\"$)').split 45 | replacements = {'\\': '\\\\', '"""': '""\\"', '"': '\\"'} 46 | 47 | 48 | def _prep_triple_quotes(s, mysplit=mysplit, replacements=replacements): 49 | """ Split the string up and force-feed some replacements 50 | to make sure it will round-trip OK 51 | """ 52 | 53 | s = mysplit(s) 54 | s[1::2] = (replacements[x] for x in s[1::2]) 55 | return ''.join(s) 56 | 57 | 58 | def pretty_string(s, embedded, current_line, uni_lit=False, 59 | min_trip_str=20, max_line=100): 60 | """There are a lot of reasons why we might not want to or 61 | be able to return a triple-quoted string. We can always 62 | punt back to the default normal string. 
63 | """ 64 | 65 | default = repr(s) 66 | 67 | # Punt on abnormal strings 68 | if (isinstance(s, special_unicode) or not isinstance(s, basestring)): 69 | return default 70 | if uni_lit and isinstance(s, bytes): 71 | return 'b' + default 72 | 73 | len_s = len(default) 74 | 75 | if current_line.strip(): 76 | len_current = len(current_line) 77 | second_line_start = s.find('\n') + 1 78 | if embedded > 1 and not second_line_start: 79 | return default 80 | 81 | if len_s < min_trip_str: 82 | return default 83 | 84 | line_indent = len_current - len(current_line.lstrip()) 85 | 86 | # Could be on a line by itself... 87 | if embedded and not second_line_start: 88 | return default 89 | 90 | total_len = len_current + len_s 91 | if total_len < max_line and not _properly_indented(s, line_indent): 92 | return default 93 | 94 | fancy = '"""%s"""' % _prep_triple_quotes(s) 95 | 96 | # Sometimes this doesn't work. One reason is that 97 | # the AST has no understanding of whether \r\n was 98 | # entered that way in the string or was a cr/lf in the 99 | # file. So we punt just so we can round-trip properly. 100 | 101 | try: 102 | if eval(fancy) == s and '\r' not in fancy: 103 | return fancy 104 | except: 105 | pass 106 | return default 107 | -------------------------------------------------------------------------------- /prometeo/cgen/op_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Part of the astor library for Python AST manipulation. 4 | 5 | License: 3-clause BSD 6 | 7 | Copyright (c) 2015 Patrick Maupin 8 | 9 | This module provides data and functions for mapping 10 | AST nodes to symbols and precedences. 
11 | 12 | """ 13 | 14 | import ast 15 | 16 | op_data = """ 17 | GeneratorExp 1 18 | 19 | Assign 1 20 | AnnAssign 1 21 | AugAssign 0 22 | Expr 0 23 | Yield 1 24 | YieldFrom 0 25 | If 1 26 | For 0 27 | AsyncFor 0 28 | While 0 29 | Return 1 30 | 31 | Slice 1 32 | Subscript 0 33 | Index 1 34 | ExtSlice 1 35 | comprehension_target 1 36 | Tuple 0 37 | 38 | Comma 1 39 | Assert 0 40 | Raise 0 41 | call_one_arg 1 42 | 43 | Lambda 1 44 | IfExp 0 45 | 46 | comprehension 1 47 | Or or 1 48 | And and 1 49 | Not not 1 50 | 51 | Eq == 1 52 | Gt > 0 53 | GtE >= 0 54 | In in 0 55 | Is is 0 56 | NotEq != 0 57 | Lt < 0 58 | LtE <= 0 59 | NotIn not in 0 60 | IsNot is not 0 61 | 62 | BitOr | 1 63 | BitXor ^ 1 64 | BitAnd & 1 65 | LShift << 1 66 | RShift >> 0 67 | Add + 1 68 | Sub - 0 69 | Mult * 1 70 | Div / 0 71 | Mod % 0 72 | FloorDiv // 0 73 | MatMult @ 0 74 | PowRHS 1 75 | Invert ~ 1 76 | UAdd + 0 77 | USub - 0 78 | Pow ** 1 79 | Await 1 80 | Num 1 81 | Constant 1 82 | """ 83 | 84 | op_data = [x.split() for x in op_data.splitlines()] 85 | op_data = [[x[0], ' '.join(x[1:-1]), int(x[-1])] for x in op_data if x] 86 | for index in range(1, len(op_data)): 87 | op_data[index][2] *= 2 88 | op_data[index][2] += op_data[index - 1][2] 89 | 90 | precedence_data = dict((getattr(ast, x, None), z) for x, y, z in op_data) 91 | symbol_data = dict((getattr(ast, x, None), y) for x, y, z in op_data) 92 | 93 | 94 | def get_op_symbol(obj, fmt='%s', symbol_data=symbol_data, type=type): 95 | """Given an AST node object, returns a string containing the symbol. 96 | """ 97 | return fmt % symbol_data[type(obj)] 98 | 99 | 100 | def get_op_precedence(obj, precedence_data=precedence_data, type=type): 101 | """Given an AST node object, returns the precedence. 
102 | """ 103 | return precedence_data[type(obj)] 104 | 105 | 106 | class Precedence(object): 107 | vars().update((x, z) for x, y, z in op_data) 108 | highest = max(z for x, y, z in op_data) + 2 109 | -------------------------------------------------------------------------------- /docs/source/python_syntax/python_syntax.rst: -------------------------------------------------------------------------------- 1 | Python syntax 2 | ============= 3 | 4 | prometeo is an embedded domain specific language based on Python. Hence, its 5 | syntax is based on Python. Below you find details regarding the most common 6 | supported Python constructs that prometeo is able to transpile to C. 7 | 8 | variable declaration 9 | -------------------- 10 | 11 | A variable can be declared as follows 12 | 13 | .. code-block:: python 14 | 15 | <var_name> : <type> = <value> 16 | 17 | where `<var_name>` must be a valid identifier, `<type>` must be a valid 18 | prometeo built-in type or a user-defined type and `<value>` must be 19 | a valid expression of type `<type>`. 20 | 21 | Example: 22 | 23 | .. code-block:: python 24 | 25 | a : int = 1 26 | 27 | Notice that, unlike in Python, type hints are strictly mandatory as they will instruct 28 | prometeo's parser regarding the type of the variables being defined. 29 | 30 | `if` statement 31 | -------------- 32 | 33 | An `if` statement takes the form 34 | 35 | 36 | .. code-block:: python 37 | 38 | if <condition>: 39 | ... 40 | 41 | 42 | `for` loop 43 | ------------ 44 | 45 | A `for` loop takes the form 46 | 47 | 48 | .. code-block:: python 49 | 50 | for i in range([<start>], <end>): 51 | ... 52 | 53 | where the optional parameter `<start>` must be an expression of type `int` (default value 0) and defines the starting value of the loop's index and `<end>` must be an expression of type `int` which defines its final value. 54 | 55 | function definition 56 | ------------------- 57 | 58 | Functions can be defined as follows 59 | 60 | 61 | .. code-block:: python 62 | 63 | def <function_name>(<arg1> : <type1>, ...) -> <ret_type>: 64 | 65 | ...
66 | 67 | return <ret_value> 68 | 69 | 70 | 71 | class definition 72 | ---------------- 73 | 74 | prometeo supports basic classes of the following form 75 | 76 | .. code-block:: python3 77 | 78 | class <class_name>: 79 | def __init__(self, <arg1> : <type1>, ...) -> None: 80 | self.<attribute1> : <type1> = <value1> 81 | ... 82 | 83 | def <method_name>(self, <arg1> : <type1>, ...) -> <ret_type>: 84 | ... 85 | 86 | return <ret_value> 87 | 88 | main function 89 | ------------- 90 | 91 | For consistency all main functions need to be defined as follows 92 | 93 | 94 | .. code-block:: python 95 | 96 | def main() -> int: 97 | 98 | ... 99 | 100 | return 0 101 | 102 | pure Python blocks 103 | ------------------- 104 | 105 | In order to be able to use the full potential of the Python language and 106 | its vast pool of libraries, it is possible to write *pure Python* blocks 107 | that are run only when prometeo code is executed directly from the Python interpreter (when --cgen is set to false). In particular, any line that is enclosed within `# pure >` and `# pure <` will be run only by the Python interpreter, but completely discarded by prometeo's parser. 108 | 109 | 110 | .. code-block:: python 111 | 112 | # some prometeo code 113 | A : pmat = pmat(n,n) 114 | ... 115 | 116 | # pure > 117 | 118 | # this is only run by the Python interpreter (--cgen=False) 119 | # and will not be transpiled 120 | 121 | # some Python code 122 | 123 | import numpy as np 124 | 125 | M = np.array([[1.0, 2.0],[0.0, 0.5]]) 126 | print(np.linalg.eigvals(M)) 127 | ... 128 | 129 | # pure < 130 | 131 | # some more prometeo code 132 | for i in range(n): 133 | for j in range(n): 134 | A[i, j] = 1.0 135 | ...
136 | 137 | -------------------------------------------------------------------------------- /prometeo/linalg/blasfeo_wrapper.py: -------------------------------------------------------------------------------- 1 | from ctypes import * 2 | import os 3 | 4 | bw = CDLL(os.path.dirname(__file__) + '/../lib/blasfeo/libblasfeo.so') 5 | 6 | class blasfeo_dmat(Structure): 7 | _fields_ = [ ("m", c_int), 8 | ("n", c_int), 9 | ("pm", c_int), 10 | ("cn", c_int), 11 | ("pA", POINTER(c_double)), 12 | ("dA", POINTER(c_double)), 13 | ("use_dA", c_int), 14 | ("memsize", c_int)] 15 | 16 | class blasfeo_dvec(Structure): 17 | _fields_ = [ ("m", c_int), 18 | ("pm", c_int), 19 | ("pa", POINTER(c_double)), 20 | ("memsize", c_int)] 21 | 22 | bw.blasfeo_dgemm_nn.argtypes = [c_int, c_int, c_int, c_double, 23 | POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int, 24 | c_double, POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int] 25 | 26 | bw.blasfeo_dgemm_nt.argtypes = [c_int, c_int, c_int, c_double, 27 | POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int, 28 | c_double, POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int] 29 | 30 | bw.blasfeo_dgemm_tn.argtypes = [c_int, c_int, c_int, c_double, 31 | POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int, 32 | c_double, POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int] 33 | 34 | bw.blasfeo_dgemm_tt.argtypes = [c_int, c_int, c_int, c_double, 35 | POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int, 36 | c_double, POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int] 37 | 38 | bw.blasfeo_dgead.argtypes = [c_int, c_int, c_double, POINTER(blasfeo_dmat), c_int, c_int, 39 | POINTER(blasfeo_dmat), c_int, c_int] 40 | 41 | bw.blasfeo_dgein1.argtypes = [c_double, POINTER(blasfeo_dmat), c_int, c_int] 42 | 43 | bw.blasfeo_dgeex1.argtypes = [POINTER(blasfeo_dmat), c_int, c_int] 44 | 45 | 
bw.blasfeo_drowpe.argtypes = [c_int, POINTER(c_int), POINTER(blasfeo_dmat)] 46 | 47 | bw.blasfeo_dgetrf_rp.argtypes = [c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int, 48 | POINTER(blasfeo_dmat), c_int, c_int, POINTER(c_int)] 49 | 50 | bw.blasfeo_dtrsm_llnn.argtypes = [c_int, c_int, c_double, POINTER(blasfeo_dmat), c_int, c_int, 51 | POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int] 52 | 53 | bw.blasfeo_dtrsm_llnu.argtypes = [c_int, c_int, c_double, POINTER(blasfeo_dmat), c_int, c_int, 54 | POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int] 55 | 56 | bw.blasfeo_dtrsm_lunn.argtypes = [c_int, c_int, c_double, POINTER(blasfeo_dmat), c_int, c_int, 57 | POINTER(blasfeo_dmat), c_int, c_int, POINTER(blasfeo_dmat), c_int, c_int] 58 | 59 | bw.blasfeo_dtrsv_lnn.argtypes = [c_int, c_double, POINTER(blasfeo_dmat), c_int, c_int, 60 | POINTER(blasfeo_dvec), c_int, POINTER(blasfeo_dvec), c_int] 61 | 62 | bw.blasfeo_dtrsv_lnn.argtypes = [c_int, c_double, POINTER(blasfeo_dmat), c_int, c_int, 63 | POINTER(blasfeo_dvec), c_int, POINTER(blasfeo_dvec), c_int] 64 | 65 | bw.blasfeo_dgemv_n.argtypes = [c_int, c_int, c_double, POINTER(blasfeo_dmat), c_int, c_int, 66 | POINTER(blasfeo_dvec), c_int, c_double, POINTER(blasfeo_dvec), c_int, POINTER(blasfeo_dvec), c_int] 67 | 68 | 69 | bw.blasfeo_dgese.argtypes = [c_int, c_int, c_double, POINTER(blasfeo_dmat), c_int, c_int] 70 | bw.blasfeo_dvecse.argtypes = [c_int, c_double, POINTER(blasfeo_dvec), c_int] 71 | # def blasfeo_dgemm_nt(m: c_int, n: c_int, k: c_int, alpha: c_double, 72 | # sA: POINTER(blasfeo_dmat), ai: c_int, aj: c_int, 73 | # sB: POINTER(blasfeo_dmat), bi: c_int, bj: c_int, 74 | # beta: c_double, sC: POINTER(blasfeo_dmat), 75 | # ci: c_int, cj: c_int, sD: POINTER(blasfeo_dmat), 76 | # di: c_int, dj: c_int): 77 | 78 | # bw.blasfeo_dgemm_nt(m, n, k, alpha, byref(sA), ai, aj, byref(sB), bi, bj, beta, 79 | # byref(sC), ci, cj, byref(sD), di, dj) 80 | 81 | # blasfeo_dvec 82 | 83 | 
bw.blasfeo_dvecin1.argtypes = [c_double, POINTER(blasfeo_dvec), c_int] 84 | bw.blasfeo_dvecex1.restype = c_double 85 | -------------------------------------------------------------------------------- /examples/test/test.py: -------------------------------------------------------------------------------- 1 | from prometeo import * 2 | 3 | sizes: dimv = [[2,2], [2,2], [2,2], [2,2], [2,2]] 4 | nx: dims = 2 5 | nxu: dims = 4 6 | nu: dims = 2 7 | N: dims = 5 8 | n: dims = 10 9 | 10 | class qp_data: 11 | def __init__(self) -> None: 12 | self.A: List = plist('pmat', sizes) 13 | self.B: List = plist('pmat', sizes) 14 | self.Q: List = plist('pmat', sizes) 15 | self.R: List = plist('pmat', sizes) 16 | self.P: List = plist('pmat', sizes) 17 | 18 | self.fact: List = plist('pmat', sizes) 19 | 20 | def factorize(self) -> None: 21 | M: pmat = pmat(nxu, nxu) 22 | Mu: pmat = pmat(nu, nu) 23 | Mxut: pmat = pmat(nu, nxu) 24 | Mxx: pmat = pmat(nx, nx) 25 | Mxu: pmat = pmat(nxu, nu) 26 | Lu: pmat = pmat(nu, nu) 27 | Lxu: pmat = pmat(nxu, nxu) 28 | Q: pmat = pmat(nx, nx) 29 | R: pmat = pmat(nu, nu) 30 | BA: pmat = pmat(nx, nxu) 31 | BAtP: pmat = pmat(nxu, nx) 32 | pmat_copy(self.Q[N-1], self.P[N-1]) 33 | for i in range(1, N): 34 | pmat_hcat(self.B[N-i], self.A[N-i], BA) 35 | pmt_gemm_tn(BA, self.P[N-i], BAtP, BAtP) 36 | 37 | pmat_copy(self.Q[N-i], Q) 38 | pmat_copy(self.R[N-i], R) 39 | # M[0:nu,0:nu] = R[0:nu,0:nu] 40 | M[0:nu,0:nu] = R[0:nu,0:nu] 41 | M[nu:nu+nx,nu:nu+nx] = Q[0:nx,0:nx] 42 | 43 | # this is still not implemented! 
44 | # R = M[0:nu,0:nu] 45 | 46 | for j in range(nu): 47 | for k in range(nu): 48 | M[j,k] = R[j,k] 49 | for j in range(nx): 50 | for k in range(nx): 51 | M[nu+j,nu+k] = Q[j,k] 52 | 53 | pmt_gemm_nn(BAtP, BA, M, M) 54 | for j in range(nu): 55 | for k in range(nu): 56 | Mu[j,k] = M[j,k] 57 | pmt_potrf(Mu, Lu) 58 | 59 | for j in range(nx): 60 | for k in range(nx): 61 | Mxut[k,nu+j] = M[j,k] 62 | 63 | for j in range(nx): 64 | for k in range(nx): 65 | Mxx[k,j] = M[nu+j,nu+k] 66 | 67 | pmt_potrsm(Lu, Mxut) 68 | pmat_tran(Mxut, Mxu) 69 | pmt_gemm_nn(Mxut, Mxu, self.P[N-i-1], self.P[N-i-1]) 70 | pmt_gead(-1.0, self.P[N-i-1], Mxx) 71 | pmat_copy(Mxx, self.P[N-i-1]) 72 | pmat_print(self.P[N-i-1]) 73 | 74 | def main() -> int: 75 | 76 | # test assignments 77 | M: pmat = pmat(n, n) 78 | 79 | a : pvec = pvec(n) 80 | a[1] = 3.0 81 | 82 | d : float = 10.0 83 | 84 | # float to pmat 85 | M[0,1] = d 86 | 87 | # float (const) to pmat 88 | M[0,1] = 1.0 89 | 90 | # pmat to float 91 | d = M[0, 1] 92 | 93 | # float to pvec 94 | a[0] = d 95 | 96 | # float (const) to pvec 97 | a[0] = 1.0 98 | 99 | # pvec to float 100 | d = a[0] 101 | 102 | # subscripted pmat to pmat 103 | for i in range(2): 104 | M[0,i] = M[0,i] 105 | 106 | # subscripted pvec to pvec 107 | a[0] = a[1] 108 | 109 | # subscripted pmat to pvec 110 | a[1] = M[0, 2] 111 | 112 | # subscripted pvec to pmat 113 | M[0, 2] = a[1] 114 | 115 | # run Riccati code 116 | A: pmat = pmat(nx, nx) 117 | A[0,0] = 0.8 118 | A[0,1] = 0.1 119 | A[1,0] = 0.0 120 | A[1,1] = 0.8 121 | 122 | B: pmat = pmat(nx, nu) 123 | B[0,0] = 1.0 124 | B[0,1] = 0.0 125 | B[1,0] = 0.0 126 | B[1,1] = 1.0 127 | 128 | Q: pmat = pmat(nx, nx) 129 | Q[0,0] = 1.0 130 | Q[0,1] = 0.0 131 | Q[1,0] = 0.0 132 | Q[1,1] = 1.0 133 | 134 | R: pmat = pmat(nu, nu) 135 | R[0,0] = 1.0 136 | R[0,1] = 0.0 137 | R[1,0] = 0.0 138 | R[1,1] = 1.0 139 | 140 | qp : qp_data = qp_data() 141 | 142 | for i in range(N): 143 | qp.A[i] = A 144 | 145 | for i in range(N): 146 | qp.B[i] = B 147 | 148 | 
for i in range(N): 149 | qp.Q[i] = Q 150 | 151 | for i in range(N): 152 | qp.R[i] = R 153 | 154 | qp.factorize() 155 | return 0 156 | -------------------------------------------------------------------------------- /benchmarks/run_benchmark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import subprocess 3 | import json 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import matplotlib as mpl 7 | plt.rcParams['text.usetex'] = True 8 | plt.rcParams['text.latex.preamble'] = [r'\usepackage{lmodern}'] 9 | font = {'family':'serif'} 10 | plt.rc('font',**font) 11 | 12 | NM = range(2,150,4) 13 | # NM = range(2,20,2) 14 | NREP_small = 10000 15 | NREP_medium = 100 16 | NREP_large = 10 17 | AVG_CPU_TIME = [] 18 | res_file = 'riccati_benchmark_prometeo.json' 19 | RUN = False 20 | UPDATE_res = False 21 | UPDATE_FIGURE = True 22 | figname = 'riccati_benchmark' 23 | 24 | blasfeo_res_file = 'riccati_benchmark_blasfeo_api.json' 25 | LOAD_BLASFEO_RES = True 26 | numpy_res_file = 'riccati_benchmark_numpy.json' 27 | LOAD_NUMPY_RES = True 28 | numpy_blasfeo_res_file = 'riccati_benchmark_numpy_blasfeo.json' 29 | LOAD_NUMPY_BLASFEO_RES = True 30 | julia_res_file = 'riccati_benchmark_julia.json' 31 | LOAD_JULIA_RES = True 32 | 33 | if not UPDATE_res: 34 | print('Warning: not updating result file! 
This will just ' 35 | 'plot the results at the end of the benchmark.') 36 | 37 | if RUN: 38 | for i in range(len(NM)): 39 | print('running Riccati benchmark for case NM = {}'.format(NM[i])) 40 | code = "" 41 | if NM[i] < 30: 42 | NREP = NREP_small 43 | elif NM[i] < 100: 44 | NREP = NREP_medium 45 | else: 46 | NREP = NREP_large 47 | 48 | with open('riccati_mass_spring.py.in') as template: 49 | code = template.read() 50 | code = code.replace('NM', str(NM[i])) 51 | code = code.replace('NREP', str(NREP)) 52 | 53 | with open('riccati_mass_spring.py', 'w+') as bench_file: 54 | bench_file.write(code) 55 | 56 | cmd = 'pmt riccati_mass_spring.py --cgen=True' 57 | proc = subprocess.Popen([cmd], shell=True, stdout=subprocess.PIPE) 58 | 59 | try: 60 | outs, errs = proc.communicate() 61 | except TimeOutExpired: 62 | proc.kill() 63 | print('Exception raised at NM = {}'.format(NM[i])) 64 | outs, errs = proc.communicate() 65 | 66 | AVG_CPU_TIME.append([float(outs.decode())/NREP, NM[i]]) 67 | 68 | if UPDATE_res: 69 | with open(res_file, 'w+') as res: 70 | json.dump(AVG_CPU_TIME, res) 71 | 72 | else: 73 | with open(res_file) as res: 74 | AVG_CPU_TIME = json.load(res) 75 | 76 | 77 | AVG_CPU_TIME = np.array(AVG_CPU_TIME) 78 | 79 | plt.figure() 80 | plt.semilogy(2*AVG_CPU_TIME[:,1], AVG_CPU_TIME[:,0]) 81 | 82 | legend = [r'\texttt{prometeo}'] 83 | if LOAD_BLASFEO_RES: 84 | with open(blasfeo_res_file) as res: 85 | AVG_CPU_TIME_BLASFEO = json.load(res) 86 | AVG_CPU_TIME_BLASFEO = np.array(AVG_CPU_TIME_BLASFEO) 87 | plt.semilogy(2*AVG_CPU_TIME_BLASFEO[:,1], AVG_CPU_TIME_BLASFEO[:,0], 'o') 88 | legend.append(r'\texttt{BLASFEO}') 89 | 90 | if LOAD_NUMPY_RES: 91 | with open(numpy_res_file) as res: 92 | AVG_CPU_TIME_BLASFEO = json.load(res) 93 | AVG_CPU_TIME_BLASFEO = np.array(AVG_CPU_TIME_BLASFEO) 94 | plt.semilogy(2*AVG_CPU_TIME_BLASFEO[:,1], AVG_CPU_TIME_BLASFEO[:,0], '--', alpha=0.7) 95 | legend.append(r'\texttt{NumPy}') 96 | 97 | if LOAD_JULIA_RES: 98 | with open(julia_res_file) as res: 
99 | AVG_CPU_TIME_BLASFEO = json.load(res) 100 | AVG_CPU_TIME_BLASFEO = np.array(AVG_CPU_TIME_BLASFEO) 101 | plt.semilogy(2*AVG_CPU_TIME_BLASFEO[:,1], AVG_CPU_TIME_BLASFEO[:,0], '--',alpha=0.7) 102 | legend.append(r'\texttt{Julia}') 103 | 104 | if LOAD_NUMPY_BLASFEO_RES: 105 | with open(numpy_blasfeo_res_file) as res: 106 | AVG_CPU_TIME_BLASFEO = json.load(res) 107 | AVG_CPU_TIME_BLASFEO = np.array(AVG_CPU_TIME_BLASFEO) 108 | plt.semilogy(2*AVG_CPU_TIME_BLASFEO[:,1], AVG_CPU_TIME_BLASFEO[:,0]) 109 | legend.append(r'\texttt{NumPy + BLASFEO}') 110 | 111 | 112 | plt.legend(legend) 113 | plt.grid() 114 | plt.xlabel(r'matrix size ($n_x$)') 115 | plt.ylabel(r'CPU time [s]') 116 | plt.title(r'Riccati factorization') 117 | if UPDATE_FIGURE: 118 | plt.savefig(figname + '.png', dpi=300, bbox_inches="tight") 119 | plt.show() 120 | 121 | 122 | -------------------------------------------------------------------------------- /benchmarks/test_riccati_numpy.py.in: -------------------------------------------------------------------------------- 1 | # Riccati recursion 2 | # author: Gianluca Frison 3 | 4 | import numpy as np 5 | import scipy as sp 6 | import scipy.linalg 7 | import time 8 | 9 | nmass = NM 10 | nrep = NREP 11 | 12 | nx = 2*nmass 13 | nu = nmass 14 | N = 5 15 | 16 | # data 17 | 18 | # mass spring system 19 | Ts = 0.5 # sampling time 20 | 21 | Ac = np.vstack(( \ 22 | np.hstack((np.zeros((nmass, nmass)), np.eye(nmass))), \ 23 | np.hstack((np.diag(-2*np.ones(nmass)) + np.diag(np.ones(nmass-1),-1) + \ 24 | np.diag(np.ones(nmass-1),1), np.zeros((nmass, nmass)))) \ 25 | )) 26 | 27 | Bc = np.vstack((np.zeros((nmass,nu)), np.eye(nu), np.zeros((nmass-nu, nu)))) 28 | 29 | M = sp.linalg.expm(np.vstack((np.hstack((Ts*Ac, Ts*Bc)), np.zeros((nu, 2*nmass+nu))))) 30 | 31 | # dynamical system 32 | A = M[0:nx, 0:nx] 33 | B = M[0:nx, nx:] 34 | 35 | # cost function 36 | Q = np.eye(nx) 37 | R = 2*np.eye(nu) 38 | 39 | # initial state 40 | x0 = np.zeros((nx, 1)) 41 | x0[0] = 3.5 42 | x0[1] 
= 3.5 43 | 44 | # work matrices 45 | BAt = np.vstack((B.transpose(), A.transpose())) 46 | BAt = np.array(BAt, dtype=np.float64, order='f') 47 | 48 | RSQ = np.zeros((nu+nx, nu+nx)) 49 | RSQ[0:nu, 0:nu] = R 50 | RSQ[nu:,nu:] = Q 51 | RSQ = np.array(RSQ, dtype=np.float64, order='f') 52 | 53 | BAtP = np.zeros((nu+nx, nx)) 54 | Lu = [] 55 | for ii in range(N): 56 | Lu.append(np.zeros((nu, nu))) 57 | 58 | Lxu = [] 59 | for ii in range(N): 60 | Lxu.append(np.zeros((nx, nu))) 61 | 62 | P = [] 63 | for ii in range(N+1): 64 | P.append(np.zeros((nx, nx))) 65 | 66 | M = np.zeros((nu+nx, nu+nx)) 67 | 68 | BAtL = np.zeros((nu+nx, nx)) 69 | BAtL = np.array(BAtL, dtype=np.float64, order='f') 70 | 71 | L = [] 72 | for ii in range(N): 73 | L.append(np.zeros((nu+nx, nu+nx))) 74 | L[ii] = np.array(L[ii], dtype=np.float64, order='f') 75 | L.append(np.zeros((nx, nx))) 76 | L[N] = np.array(L[N], dtype=np.float64, order='f') 77 | 78 | Lx = np.zeros((nx, nx)) 79 | Lx = np.array(Lx, dtype=np.float64, order='f') 80 | 81 | 82 | # Riccati recursion, classical algorithm 83 | 84 | tic = time.time() 85 | 86 | # for rep in range(nrep): 87 | 88 | # P[N] = Q 89 | 90 | # for ii in range(N): 91 | # BAtP = np.dot(BAt, P[N-ii]) 92 | # M = RSQ + np.dot(BAtP, BAt.transpose()) 93 | # Lu[N-1-ii] = np.linalg.cholesky(M[0:nu,0:nu]) 94 | # Lxu[N-1-ii] = sp.linalg.blas.dtrsm(1.0, Lu[N-1-ii].transpose(), M[nu:,0:nu].transpose()) 95 | # P[N-1-ii] = M[nu:,nu:] - np.dot(Lxu[N-1-ii].transpose(), Lxu[N-1-ii]) 96 | 97 | # toc = time.time() - tic 98 | 99 | # time_ric_classic = toc/nrep 100 | 101 | #Lu 102 | #Lxu 103 | #P 104 | #return 105 | 106 | # Riccati recursion, square root algorithm 107 | 108 | # tic = time.time() 109 | 110 | # for rep in range(nrep): 111 | 112 | # L[N] = np.linalg.cholesky(Q) 113 | # Lx = L[N][0] 114 | 115 | # for ii in range(N): 116 | # BAtL = np.dot(BAt, Lx) 117 | # M = RSQ + np.dot(BAtL, BAtL.transpose()) 118 | # L[N-1-ii] = np.linalg.cholesky(M) 119 | # Lx = L[N-1-ii][nu:,nu:] 120 | 121 | # 
def c_pmt_set_blasfeo_dmat(M, data: POINTER(c_double)):
    """Pack the values pointed to by ``data`` into the BLASFEO matrix ``M``.

    The full ``M.m`` x ``M.n`` extent is written, starting at element (0, 0)
    of ``M``. ``M.n`` is also what gets forwarded as the leading-dimension
    argument of ``blasfeo_pack_dmat``.
    """
    rows, cols = M.m, M.n
    bw.blasfeo_pack_dmat(rows, cols, data, cols, byref(M), 0, 0)
def c_pmt_dgemm_nt(A, B, C, D):
    """Compute ``D = A * B^T + C`` with BLASFEO's ``blasfeo_dgemm_nt``.

    ``A``, ``B``, ``C`` and ``D`` are objects exposing a ``blasfeo_dmat``
    attribute. The result is written into ``D``; nothing is returned.

    BLASFEO's gemm routines take their sizes as ``(m, n, k)``. For the
    "nt" variant ``A`` is m-by-k and ``B`` is n-by-k, so ``n`` is the number
    of *rows* of ``B`` and ``k`` the number of columns of ``A``. The previous
    code passed ``(A.m, A.n, B.m)``, i.e. ``n`` and ``k`` swapped, which only
    works when the operands are square.
    """
    bA = A.blasfeo_dmat
    bB = B.blasfeo_dmat
    bC = C.blasfeo_dmat
    bD = D.blasfeo_dmat

    m = bA.m  # rows of A (and of the result)
    n = bB.m  # rows of B == columns of B^T (and of the result)
    k = bA.n  # shared inner dimension

    bw.blasfeo_dgemm_nt(m, n, k, 1.0, byref(bA), 0, 0, byref(bB), 0, 0, 1, byref(bC), 0, 0, byref(bD), 0, 0)
    return
def c_pmt_trsv_llnu(A, b):
    """Solve ``A x = b`` in place for lower-triangular, unit-diagonal ``A``.

    ``A`` exposes a ``blasfeo_dmat`` and ``b`` a ``blasfeo_dvec``; the
    solution overwrites ``b``.

    The wrapper previously called ``blasfeo_dtrsv_lnn`` (non-unit diagonal),
    which disagrees with the ``...nu`` suffix of this function; the
    unit-diagonal routine ``blasfeo_dtrsv_lnu`` is used instead. BLASFEO's
    ``dtrsv`` routines take no scaling factor, so the stray ``1.0`` argument
    was dropped as well.
    """
    bA = A.blasfeo_dmat
    bb = b.blasfeo_dvec

    bw.blasfeo_dtrsv_lnu(bb.m, byref(bA), 0, 0, byref(bb), 0, byref(bb), 0)
    return

def c_pmt_trsv_lunn(A, b):
    """Solve ``A x = b`` in place for upper-triangular, non-unit ``A``.

    ``A`` exposes a ``blasfeo_dmat`` and ``b`` a ``blasfeo_dvec``; the
    solution overwrites ``b``.

    The wrapper previously called the *lower*-triangular routine
    ``blasfeo_dtrsv_lnn`` (a copy of ``c_pmt_trsv_llnu``); the upper-triangular
    routine ``blasfeo_dtrsv_unn`` is used instead, without the stray ``1.0``
    argument that BLASFEO's ``dtrsv`` signature does not have.
    """
    bA = A.blasfeo_dmat
    bb = b.blasfeo_dvec

    bw.blasfeo_dtrsv_unn(bb.m, byref(bA), 0, 0, byref(bb), 0, byref(bb), 0)
    return
# auxiliary functions
def c_pmt_print_blasfeo_dmat(A):
    """Print all ``A.m`` x ``A.n`` entries of ``A.blasfeo_dmat``.

    Printing is delegated to ``blasfeo_print_dmat`` starting at element
    (0, 0); nothing is returned.
    """
    n_rows, n_cols = A.m, A.n
    bw.blasfeo_print_dmat(n_rows, n_cols, byref(A.blasfeo_dmat), 0, 0)
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR * 25 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * 26 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * 27 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * 28 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * 29 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * 30 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * 31 | * * 32 | * Author: Gianluca Frison, gianluca.frison (at) imtek.uni-freiburg.de * 33 | * * 34 | **************************************************************************************************/ 35 | 36 | #ifndef PROMETEO_TIMING_H_ 37 | #define PROMETEO_TIMING_H_ 38 | 39 | //#include 40 | 41 | #if (defined _WIN32 || defined _WIN64) && !(defined __MINGW32__ || defined __MINGW64__) 42 | 43 | /* Use Windows QueryPerformanceCounter for timing. */ 44 | #include 45 | 46 | /** A structure for keeping internal timer data. */ 47 | typedef struct prometeo_timer_ { 48 | LARGE_INTEGER tic; 49 | LARGE_INTEGER toc; 50 | LARGE_INTEGER freq; 51 | } prometeo_timer; 52 | 53 | #elif(defined __APPLE__) 54 | 55 | #include 56 | 57 | /** A structure for keeping internal timer data. */ 58 | typedef struct prometeo_timer_ { 59 | uint64_t tic; 60 | uint64_t toc; 61 | mach_timebase_info_data_t tinfo; 62 | } prometeo_timer; 63 | 64 | #elif(defined __DSPACE__) 65 | 66 | #include 67 | 68 | typedef struct prometeo_timer_ { 69 | double time; 70 | } prometeo_timer; 71 | 72 | #elif(defined __XILINX_NONE_ELF__) 73 | 74 | #include "xtime_l.h" 75 | 76 | typedef struct prometeo_timer_ { 77 | uint64_t tic; 78 | uint64_t toc; 79 | } prometeo_timer; 80 | 81 | #else 82 | 83 | /* Use POSIX clock_gettime() for timing on non-Windows machines. 
class pfun:
    """Callable wrapper around a CasADi function built from a text expression.

    On construction the expression is turned into a ``ca.Function``, which is
    serialized, exported as C code and wrapped by two rendered templates
    (``casadi_wrapper.c.in`` / ``casadi_wrapper.h.in``). All generated files
    are written to the ``__pmt_cache__`` directory below the current working
    directory.
    """

    def __init__(self, fun_name, expr, variables):
        """Build, serialize and code-generate the CasADi function.

        Args:
            fun_name: base name of the function; it is prefixed with the
                names of the calling functions to form the generated name.
            expr: Python source of the expression to evaluate. It is
                executed with ``eval`` and must therefore be trusted.
            variables: mapping from the names used in ``expr`` to their
                values. ``pmat``/``pvec`` values are converted to NumPy
                arrays; ``ca.SX``/``ca.MX`` values become the symbolic
                arguments of the generated function.

        Raises:
            Exception: if a variable is of none of the supported types.
        """
        # Build a scope prefix out of the names of the calling functions.
        stack = inspect.stack()

        prefix = 'global'
        prefix_at = 'global'
        # NOTE(review): the offset of 5 presumably skips launcher frames at
        # the bottom of the stack -- confirm against the intended entry point.
        for i in range(len(stack)-5, 0, -1):
            prefix = prefix + '_' + stack[i].function
            prefix_at = prefix_at + '@' + stack[i].function

        tokens = tokenize(BytesIO(expr.encode('utf-8')).readline)

        # Explicit evaluation namespace. Previously the generated names were
        # created with exec() in the method's local scope and read back via
        # locals(), which no longer works on Python >= 3.13 (PEP 667).
        namespace = dict(globals())

        np_var_names = []
        ca_var_names = []
        ca_fun_args = []
        fun_descriptor = dict()
        fun_descriptor['args'] = []
        for var_name, var in variables.items():
            if isinstance(var, pmat):
                namespace['np_' + var_name] = pmat_to_numpy(var)
                np_var_names.append(var_name)
            elif isinstance(var, pvec):
                namespace['np_' + var_name] = pvec_to_numpy(var)
                np_var_names.append(var_name)
            elif isinstance(var, (ca.SX, ca.MX)):
                ca_sym = ca.SX.sym('ca_' + var_name, var.shape[0], var.shape[1])
                namespace['ca_' + var_name] = ca_sym
                ca_var_names.append(var_name)
                ca_fun_args.append(ca_sym)
                fun_descriptor['args'].append({'name' : var_name, 'size' : (var.shape[0],var.shape[1])})
            else:
                raise Exception('Variable {} of unknown type {}'.format(var, type(var)))

        # Rename every known variable in the expression to its np_/ca_ twin.
        result = []
        for toknum, tokval, _, _, _ in tokens:
            if toknum == NAME and tokval in np_var_names:
                result.append((toknum, 'np_' + tokval))
            elif toknum == NAME and tokval in ca_var_names:
                result.append((toknum, 'ca_' + tokval))
            else:
                result.append((toknum, tokval))

        ca_expr = untokenize(result).decode('utf-8').replace(" ", "")
        scoped_fun_name = prefix + '_' + fun_name
        scoped_fun_name_at = prefix_at + '@' + fun_name
        fun_descriptor['name'] = scoped_fun_name

        # NOTE: the expression is executed as Python code; never pass
        # untrusted input as `expr`.
        outputs = eval('[' + ca_expr + ']', namespace)
        self._ca_fun = ca.Function(scoped_fun_name, ca_fun_args, outputs)

        cache_dir = '__pmt_cache__'
        os.makedirs(cache_dir, exist_ok=True)

        # The files are written with relative names from inside the cache
        # directory; always restore the caller's working directory, even
        # when serialization or code generation fails.
        previous_cwd = os.getcwd()
        os.chdir(cache_dir)
        try:
            # serialize CasADi object
            self._ca_fun.save(scoped_fun_name_at + '.casadi')

            # TODO(andrea): dump the function signature to a json file (same
            # format as function_record). CasADi functions are global (for
            # now), e.g.
            # function_signature = {'global': {
            #     '_Z4pmatdimsdims' : {
            #         'arg_types' : ["dims", "dims"],
            #         'ret_type': "pmat"
            #     }}}

            # generate C code
            self._ca_fun.generate(scoped_fun_name + '.c')

            # render templated wrapper
            env = Environment(loader=FileSystemLoader(os.path.dirname(os.path.abspath(__file__))))
            tmpl = env.get_template("casadi_wrapper.c.in")
            code = tmpl.render(fun_descriptor = fun_descriptor)
            with open('casadi_wrapper_' + scoped_fun_name + '.c', "w") as f:
                f.write(code)

            tmpl = env.get_template("casadi_wrapper.h.in")
            code = tmpl.render(fun_descriptor = fun_descriptor)
            with open('casadi_wrapper_' + scoped_fun_name + '.h', "w") as f:
                f.write(code)
        finally:
            os.chdir(previous_cwd)

    def __call__(self, args):
        """Evaluate the CasADi function on a single ``pmat`` or ``pvec``.

        The argument is converted to a NumPy array first; the result of the
        CasADi call is returned via its ``full()`` method.

        Raises:
            Exception: if ``args`` is neither a ``pmat`` nor a ``pvec``.
        """
        if isinstance(args, pmat):
            args_ = pmat_to_numpy(args)
        elif isinstance(args, pvec):
            args_ = pvec_to_numpy(args)
        else:
            raise Exception('Invalid argument to CasADi function of type {}'.format(type(args)))
        return self._ca_fun(args_).full()
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))


import sphinx_rtd_theme

# -- Project information -----------------------------------------------------

project = 'prometeo'
copyright = '2020, Andrea Zanelli'
author = 'Andrea Zanelli'

# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '0.0.5'


# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
# NOTE: this list used to be assigned twice (an empty list here and the
# theme extension further down); the single definition below is equivalent.
extensions = [
    "sphinx_rtd_theme",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
# NOTE: Sphinx >= 5 warns about `language = None`; 'en' is the value it
# falls back to and is accepted by older releases as well.
language = 'en'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"


# Theme options are theme-specific and customize the look and feel of a theme
# further.  For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself.  Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}


# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'prometeodoc'


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'prometeo.tex', 'prometeo Documentation',
     'Andrea Zanelli', 'manual'),
]


# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'prometeo', 'prometeo Documentation',
     [author], 1)
]


# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'prometeo', 'prometeo Documentation',
     author, 'prometeo', 'One line description of project.',
     'Miscellaneous'),
]


# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
epub_title = project

# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
174 | # 175 | # epub_identifier = '' 176 | 177 | # A unique identification for the text. 178 | # 179 | # epub_uid = '' 180 | 181 | # A list of files that should not be packed into the epub file. 182 | epub_exclude_files = ['search.html'] 183 | -------------------------------------------------------------------------------- /prometeo/cpmt/timing.c: -------------------------------------------------------------------------------- 1 | /************************************************************************************************** 2 | * * 3 | * This file is part of BLASFEO. * 4 | * * 5 | * BLASFEO -- BLAS For Embedded Optimization. * 6 | * Copyright (C) 2019 by Gianluca Frison. * 7 | * Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. * 8 | * All rights reserved. * 9 | * * 10 | * The 2-Clause BSD License * 11 | * * 12 | * Redistribution and use in source and binary forms, with or without * 13 | * modification, are permitted provided that the following conditions are met: * 14 | * * 15 | * 1. Redistributions of source code must retain the above copyright notice, this * 16 | * list of conditions and the following disclaimer. * 17 | * 2. Redistributions in binary form must reproduce the above copyright notice, * 18 | * this list of conditions and the following disclaimer in the documentation * 19 | * and/or other materials provided with the distribution. * 20 | * * 21 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * 22 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * 23 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * 24 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR * 25 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * 26 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * 27 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * 28 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * 29 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * 30 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * 31 | * * 32 | * Author: Gianluca Frison, gianluca.frison (at) imtek.uni-freiburg.de * 33 | * * 34 | **************************************************************************************************/ 35 | 36 | #include "timing.h" 37 | 38 | #if (defined _WIN32 || defined _WIN64) && !(defined __MINGW32__ || defined __MINGW64__) 39 | 40 | void prometeo_tic(prometeo_timer* t) { 41 | QueryPerformanceFrequency(&t->freq); 42 | QueryPerformanceCounter(&t->tic); 43 | } 44 | 45 | double prometeo_toc(prometeo_timer* t) { 46 | QueryPerformanceCounter(&t->toc); 47 | return ((t->toc.QuadPart - t->tic.QuadPart) / (double)t->freq.QuadPart); 48 | } 49 | 50 | #elif(defined __APPLE__) 51 | void prometeo_tic(prometeo_timer* t) { 52 | /* read current clock cycles */ 53 | t->tic = mach_absolute_time(); 54 | } 55 | 56 | double prometeo_toc(prometeo_timer* t) { 57 | uint64_t duration; /* elapsed time in clock cycles*/ 58 | 59 | t->toc = mach_absolute_time(); 60 | duration = t->toc - t->tic; 61 | 62 | /*conversion from clock cycles to nanoseconds*/ 63 | mach_timebase_info(&(t->tinfo)); 64 | duration *= t->tinfo.numer; 65 | duration /= t->tinfo.denom; 66 | 67 | return (double)duration / 1e9; 68 | } 69 | 70 | #elif(defined __DSPACE__) 71 | 72 | void prometeo_tic(prometeo_timer* t) { 73 | ds1401_tic_start(); 74 | t->time = ds1401_tic_read(); 75 | } 76 | 77 | double prometeo_toc(prometeo_timer* t) { 78 | return ds1401_tic_read() - 
t->time; 79 | } 80 | 81 | #elif defined(__XILINX_NONE_ELF__) 82 | 83 | void prometeo_tic(prometeo_timer* t) { 84 | XTime_GetTime(&(t->tic)); 85 | } 86 | 87 | double prometeo_toc(prometeo_timer* t) { 88 | uint64_t toc; 89 | XTime_GetTime(&toc); 90 | t->toc = toc; 91 | 92 | /* time in s */ 93 | return (double) (toc - t->tic) / (COUNTS_PER_SECOND); 94 | } 95 | #else 96 | 97 | #if __STDC_VERSION__ >= 199901L // C99 Mode 98 | 99 | /* read current time */ 100 | void prometeo_tic(prometeo_timer* t) { 101 | gettimeofday(&t->tic, 0); 102 | } 103 | 104 | /* return time passed since last call to tic on this timer */ 105 | double prometeo_toc(prometeo_timer* t) { 106 | struct timeval temp; 107 | 108 | gettimeofday(&t->toc, 0); 109 | 110 | if ((t->toc.tv_usec - t->tic.tv_usec) < 0) { 111 | temp.tv_sec = t->toc.tv_sec - t->tic.tv_sec - 1; 112 | temp.tv_usec = 1000000 + t->toc.tv_usec - t->tic.tv_usec; 113 | } else { 114 | temp.tv_sec = t->toc.tv_sec - t->tic.tv_sec; 115 | temp.tv_usec = t->toc.tv_usec - t->tic.tv_usec; 116 | } 117 | 118 | return (double)temp.tv_sec + (double)temp.tv_usec / 1e6; 119 | } 120 | 121 | #else // ANSI C Mode 122 | 123 | /* read current time */ 124 | void prometeo_tic(prometeo_timer* t) { 125 | clock_gettime(CLOCK_MONOTONIC, &t->tic); 126 | } 127 | 128 | 129 | /* return time passed since last call to tic on this timer */ 130 | double prometeo_toc(prometeo_timer* t) { 131 | struct timespec temp; 132 | 133 | clock_gettime(CLOCK_MONOTONIC, &t->toc); 134 | 135 | if ((t->toc.tv_nsec - t->tic.tv_nsec) < 0) { 136 | temp.tv_sec = t->toc.tv_sec - t->tic.tv_sec - 1; 137 | temp.tv_nsec = 1000000000+t->toc.tv_nsec - t->tic.tv_nsec; 138 | } else { 139 | temp.tv_sec = t->toc.tv_sec - t->tic.tv_sec; 140 | temp.tv_nsec = t->toc.tv_nsec - t->tic.tv_nsec; 141 | } 142 | 143 | return (double)temp.tv_sec + (double)temp.tv_nsec / 1e9; 144 | } 145 | 146 | #endif // __STDC_VERSION__ >= 199901L 147 | 148 | #endif // (defined _WIN32 || _WIN64) 149 | 
class NonExistent(object):
    """Sentinel type marking an attribute that is absent from a node."""
    pass


def iter_node(node, name='', unknown=None,
              # Runtime optimization
              list=list, getattr=getattr, isinstance=isinstance,
              enumerate=enumerate, missing=NonExistent):
    """Yield ``(value, name)`` pairs describing the children of *node*.

    - If *node* has a ``_fields`` attribute, one pair is yielded per field
      that is actually set on the node, in ``_fields`` order, with the field
      name as the second element. Afterwards, if *unknown* is not None, it
      is updated with the instance attributes that are not listed in
      ``_fields``.

    - Otherwise, if *node* is a list, one pair is yielded per item, all
      carrying the *name* passed into this function (blank by default).

    - Anything else yields nothing.
    """
    field_names = getattr(node, '_fields', None)
    if field_names is None:
        if isinstance(node, list):
            for element in node:
                yield element, name
        return

    for field in field_names:
        attr = getattr(node, field, missing)
        if attr is missing:
            continue
        yield attr, field

    if unknown is not None:
        unknown.update(set(vars(node)) - set(field_names))
class ExplicitNodeVisitor(ast.NodeVisitor):
    """A ``NodeVisitor`` without implicit visits.

    A node whose class has no matching ``visit_<ClassName>`` method is
    reported as an error.
    """

    def abort_visit(node):  # XXX: self?
        # Used as the plain-function default of ``visit`` below, hence it
        # receives only the node.
        raise AttributeError(
            'No defined handler for node of type %s' % node.__class__.__name__)

    def visit(self, node, abort=abort_visit):
        """Visit a node."""
        handler_name = 'visit_' + node.__class__.__name__
        return getattr(self, handler_name, abort)(node)
def fast_compare(tree1, tree2):
    """Return True when two AST trees are structurally equal.

    The trees are walked iteratively with an explicit work list.  Nodes
    are compared through their ``_fields`` (ignoring ``ctx``); lists are
    compared element by element; anything else is compared with ``==``.
    Node ``_attributes`` are not examined.
    """
    fetch = ast.AST.__getattribute__
    pending = [(tree1, tree2)]

    while pending:
        left, right = pending.pop()
        try:
            left_fields = fetch(left, '_fields')
            right_fields = fetch(right, '_fields')
        except (TypeError, AttributeError):
            # Not an AST node (TypeError in CPython, AttributeError in PyPy):
            # either a list of children or a plain leaf value.
            if type(left) is list:
                pending.extend(zip_longest(left, right))
            elif not (left == right):
                return False
            continue

        names = [field for field in left_fields if field != 'ctx']
        if names != [field for field in right_fields if field != 'ctx']:
            return False
        pending.extend((fetch(left, field), fetch(right, field))
                       for field in names)

    return True
def split_lines(source, maxline=79):
    """Regroup a token stream into output lines, wrapping the long ones.

    ``source`` is an iterable of string fragments.  A fragment that starts
    with a newline terminates the logical line collected so far.  That
    line is copied to the result unchanged when it is at most ``maxline``
    characters long, or when its most recent fragment contains an embedded
    newline; otherwise it is handed to ``wrap_line``.

    Fragments left over after the last line terminator are flushed as
    well, so input that does not end with a newline is no longer lost.

    Returns the list of output fragments.
    """
    result = []
    line = []
    multiline = False
    width = 0

    def flush_line():
        # Emit the pending logical line, wrapping only when it is both too
        # long and safe to split.
        if width <= maxline or multiline:
            result.extend(line)
        else:
            wrap_line(line, maxline, result)

    for item in source:
        index = item.find('\n')
        if index:
            # -1: no newline at all; > 0: newline embedded inside the
            # fragment (for instance a multi-line string literal).
            line.append(item)
            multiline = index > 0
            width += len(item)
        else:
            # The fragment starts with a newline: the logical line is over.
            if line:
                flush_line()
                width = 0
                multiline = False
                line = []
            result.append(item)

    # Without this, everything after the final newline was silently dropped.
    if line:
        flush_line()
    return result
66 | """ 67 | 68 | # Extract the indentation 69 | 70 | append = result.append 71 | extend = result.extend 72 | 73 | indentation = line[0] 74 | lenfirst = len(indentation) 75 | indent = lenfirst - len(indentation.strip()) 76 | assert indent in (0, lenfirst) 77 | indentation = line.pop(0) if indent else '' 78 | 79 | # Get splittable/non-splittable groups 80 | 81 | dgroups = list(delimiter_groups(line)) 82 | unsplittable = dgroups[::2] 83 | splittable = dgroups[1::2] 84 | 85 | # If the largest non-splittable group won't fit 86 | # on a line, try to add parentheses to the line. 87 | 88 | if max(count(x) for x in unsplittable) > maxline - indent: 89 | line = add_parens(line, maxline, indent) 90 | dgroups = list(delimiter_groups(line)) 91 | unsplittable = dgroups[::2] 92 | splittable = dgroups[1::2] 93 | 94 | # Deal with the first (always unsplittable) group, and 95 | # then set up to deal with the remainder in pairs. 96 | 97 | first = unsplittable[0] 98 | append(indentation) 99 | extend(first) 100 | if not splittable: 101 | return result 102 | pos = indent + count(first) 103 | indentation += ' ' 104 | indent += 4 105 | if indent >= maxline/2: 106 | maxline = maxline/2 + indent 107 | 108 | for sg, nsg in zip(splittable, unsplittable[1:]): 109 | 110 | if sg: 111 | # If we already have stuff on the line and even 112 | # the very first item won't fit, start a new line 113 | if pos > indent and pos + len(sg[0]) > maxline: 114 | append('\n') 115 | append(indentation) 116 | pos = indent 117 | 118 | # Dump lines out of the splittable group 119 | # until the entire thing fits 120 | csg = count(sg) 121 | while pos + csg > maxline: 122 | ready, sg = split_group(sg, pos, maxline) 123 | if ready[-1].endswith(' '): 124 | ready[-1] = ready[-1][:-1] 125 | extend(ready) 126 | append('\n') 127 | append(indentation) 128 | pos = indent 129 | csg = count(sg) 130 | 131 | # Dump the remainder of the splittable group 132 | if sg: 133 | extend(sg) 134 | pos += csg 135 | 136 | # Dump the 
def split_group(source, pos, maxline):
    """Cut ``source`` into the part that stays on the current line and the rest.

    Tokens are taken from the front of ``source`` until adding the next
    one would pass the allowed width, starting from column ``pos``.  The
    first token is always taken.  A following token that ends in a space
    may reach ``maxline + 1``; any other token must stay within
    ``maxline - 4``.

    ``source`` is modified in place: the taken tokens are removed from it
    and the same list object is returned as the second element.
    """
    taken = 0
    total = len(source)
    while taken < total:
        pos += len(source[taken])
        taken += 1
        if taken < total:
            upcoming = source[taken]
            if upcoming.endswith(' '):
                limit = maxline + 1
            else:
                limit = maxline - 4
            if pos + len(upcoming) > limit:
                break

    head = source[:taken]
    del source[:taken]
    return head, source
def add_parens(line, maxline, indent, statements=statements, count=count):
    """Try to make an over-long line splittable by parenthesising it.

    When the line starts with one of ``statements``, its expression is
    wrapped in parentheses, with the closing one placed before a trailing
    ``':'``.  This step edits ``line`` in place.  The line is then split
    into assignment groups: a wide left-hand side has each multi-token
    target wrapped, and the right-hand side is wrapped unless wrapping
    the last target was enough.

    Returns ``line`` itself when it contains no assignment, otherwise a
    new flat token list.
    """
    keyword = line[0]
    if keyword in statements:
        if keyword.endswith(' '):
            open_at = 1
        else:
            # A keyword stored without its trailing blank is followed by a
            # separate ' ' token; the parenthesis goes after that token.
            open_at = 2
            assert line[1] == ' '
        line.insert(open_at, '(')
        if line[-1] == ':':
            line.insert(-1, ')')
        else:
            line.append(')')

    # The statement keywords are done; now look at assignments.
    pieces = list(get_assign_groups(line))
    if len(pieces) == 1:
        return line

    widths = [count(piece) for piece in pieces]
    targets, value = pieces[:-1], pieces[-1]
    lhs_wrapped = False

    # A wide left-hand side gets wrapped first.
    if sum(widths[:-1]) >= maxline - indent - 4:
        for target in targets:
            # Only the outcome for the last target is remembered.
            lhs_wrapped = len(target) > 1
            if lhs_wrapped:
                target.insert(0, '(')
                # The last token of a target is the assignment operator,
                # so the parenthesis closes just before it.
                target.insert(-1, ')')

    # The right-hand side may not need wrapping once the LHS is wrapped.
    if not lhs_wrapped or widths[-1] > maxline - indent - 10:
        value.insert(0, '(')
        value.append(')')

    flat = []
    for piece in pieces:
        flat.extend(piece)
    return flat
def _pushFirst(str, loc, toks):
    """Parse action: push the first matched token onto ``exprStack``.

    Takes the ``(str, loc, toks)`` arguments of a parse action.  Only
    ``toks[0]`` is stored; ``str`` is used solely for the debug trace
    printed when ``debug_flag`` is set.
    """
    head = toks[0]
    if debug_flag:
        print("pushing ", head, "str is ", str)
    exprStack.append(head)
class LAParser():
    """Translate linear-algebra assignments into prometeo C calls.

    The three constructor arguments are paths of JSON files.  Their decoded
    contents are stored in ``self.records`` under the keys 0 (variable name
    to type tag), 1 (variable name to dimensions) and 2 (dimension record).

    ``self.opn`` maps each supported operator symbol to a function that
    takes two ``Operand`` objects and returns the ``Operand`` describing the
    result, whose ``expr`` attribute is the generated C expression.
    """

    def __init__(self, typed_record_json, var_dim_record_json, dim_record_json):
        """Load the three JSON records and build the operator table."""
        with open(typed_record_json, 'r') as f:
            typed_record = json.load(f)

        with open(var_dim_record_json, 'r') as f:
            var_dim_record = json.load(f)

        with open(dim_record_json, 'r') as f:
            dim_record = json.load(f)

        self.records = dict()
        self.records[0] = typed_record
        self.records[1] = var_dim_record
        self.records[2] = dim_record

        self.exprStack = []

        def _ismat(op):
            return op.type == 'pmat'

        def _addfunc(a, b):
            if not (_ismat(a) and _ismat(b)):
                raise TypeError
            return Operand(a.name + '_+_' + b.name, 'pmat', [a.size[0], a.size[1]],
                           '_c_pmt_gead(1.0, %s, %s)' % (a.expr, b.expr))

        def _subfunc(a, b):
            if not (_ismat(a) and _ismat(b)):
                raise TypeError
            # _c_pmt_gead(alpha, X, Y) returns Y + alpha * X, so "a - b" needs
            # X = b and Y = a.  The operands used to be passed the other way
            # round, which generated code computing b - a.
            return Operand(a.name + '_-_' + b.name, 'pmat', [a.size[0], a.size[1]],
                           '_c_pmt_gead(-1.0, %s,%s)' % (b.expr, a.expr))

        def _mulfunc(a, b):
            if not (_ismat(a) and _ismat(b)):
                raise TypeError
            return Operand(a.name + '_*_' + b.name, 'pmat', [a.size[0], b.size[1]],
                           '_c_pmt_gemm_nn(%s,%s)' % (a.expr, b.expr))

        def _solvefunc(a, b):
            if not (_ismat(a) and _ismat(b)):
                raise TypeError
            return Operand(a.name + '_/_' + b.name, 'pmat', [a.size[0], b.size[1]],
                           '_c_pmt_getrsm(%s,%s)' % (a.expr, b.expr))

        def _expfunc(a, b):
            # The only supported "exponent" is the transpose tag T.
            if not (_ismat(a) and b.name == "T"):
                raise TypeError
            return Operand(a.name + '_T', 'pmat', [a.size[1], a.size[0]],
                           '_c_pmt_pmat_tran(%s)' % (a.expr))

        def _assignfunc(a, b):
            if not (_ismat(a) and _ismat(b)):
                raise TypeError
            # The right-hand side b is copied into the target a.
            return Operand(a.name + '_+_' + b.name, 'pmat', [a.size[0], a.size[1]],
                           '_c_pmt_pmat_copy(%s,%s)' % (b.expr, a.expr))

        self.opn = {
            "+": (_addfunc),
            "-": (_subfunc),
            "*": (_mulfunc),
            ".": (_expfunc),
            "\\": (_solvefunc),
            "=": (_assignfunc),
        }

    # recursive function that evaluates the expression stack
    def _evaluateStack(self, s):
        """Pop one expression off the stack ``s`` and return its Operand.

        ``s`` holds tokens in postfix order and is consumed destructively.
        An identifier that is missing from the type or dimension record
        raises ``KeyError``.
        """
        typed_record = self.records[0]
        var_dim_record = self.records[1]
        token = s.pop()
        if token == '.':
            # The right operand of '.' is a bare tag, not a declared variable.
            token2 = s.pop()
            op2 = Operand(token2, [''], [0, 0], '')
            op1 = self._evaluateStack(s)
        elif token in ['+', '-', '*', '/', '@', '\\', '=']:
            # The right operand was pushed last, so it is popped first.
            op2 = self._evaluateStack(s)
            op1 = self._evaluateStack(s)
        else:
            return Operand(token, typed_record[token], var_dim_record[token], token)

        # NOTE(review): '/' and '@' are accepted above but have no entry in
        # self.opn, so they currently end in a KeyError here -- confirm
        # whether handlers are still to be written.
        result = self.opn[token](op1, op2)
        if debug_flag:
            print(result)
        return result

    # the parse function that invokes all of the above.
    def parse(self, expr):
        """Translate the assignment ``expr`` into a C statement.

        Returns the generated code as ``"\\n<statement>;\\n"``.

        Raises ``ValueError`` for an empty string, ``ParseException`` when
        the input does not match the grammar, ``TypeError`` when an
        operation is applied to unsupported operand types and ``KeyError``
        when a variable is missing from the records.  Parse and type errors
        are also reported on stderr before being re-raised.
        """
        if expr == "":
            # Used to surface as an UnboundLocalError at the very end.
            raise ValueError("cannot parse an empty expression")

        typed_record = self.records[0]
        var_dim_record = self.records[1]

        # The parse actions push onto the module-level exprStack.  Empty it
        # in place so that tokens left behind by an earlier, failed parse
        # cannot leak into this one.
        del exprStack[:]

        # try parsing the input string
        try:
            L = equation.parseString(expr)
        except ParseException as err:
            print("Parse Failure", file=sys.stderr)
            print(err.line, file=sys.stderr)
            print(" " * (err.column - 1) + "^", file=sys.stderr)
            print(err, file=sys.stderr)
            raise
        targetvar = Operand(L[0], typed_record[L[0]], var_dim_record[L[0]], L[0])

        # show result of parsing the input string
        if debug_flag:
            print(expr, "->", L)
            print("exprStack=", exprStack)

        # evaluate the stack of parsed operands, emitting C code.
        try:
            result = self._evaluateStack(exprStack)
        except TypeError:
            print(
                "Unsupported operation on right side of '%s'.\nCheck for missing or incorrect tags on non-scalar operands."
                % expr,
                file=sys.stderr,
            )
            raise

        # create final assignment and print it.
        if debug_flag:
            print("var=", targetvar)
        try:
            result = self.opn['='](targetvar, result)
        except TypeError:
            print(
                "Left side tag does not match right side of '%s'" % expr,
                file=sys.stderr,
            )
            raise

        ccode = result.expr
        return "\n%s;\n" % (ccode)
pmem_ptr; 37 | assert((size_t) ___c_pmt_64_heap % 64 == 0 && "dmat not 64-byte aligned!"); 38 | blasfeo_create_dmat(m, n, *bmat, ___c_pmt_64_heap); 39 | 40 | // advance global heap address 41 | int memsize = (*bmat)->memsize; 42 | make_int_multiple_of(64, &memsize); 43 | ___c_pmt_64_heap += memsize; 44 | 45 | // zero allocated memory 46 | int i; 47 | double *dA = (*bmat)->dA; 48 | int size = (*bmat)->memsize; 49 | for(i=0; ibmat->m; 60 | int nA = A->bmat->n; 61 | int nB = B->bmat->n; 62 | struct blasfeo_dmat *bA = A->bmat; 63 | struct blasfeo_dmat *bB = B->bmat; 64 | struct blasfeo_dmat *bC = C->bmat; 65 | struct blasfeo_dmat *bD = D->bmat; 66 | 67 | // printf("In dgemm\n"); 68 | // blasfeo_print_dmat(mA, nA, A->bmat, 0, 0); 69 | // blasfeo_print_dmat(mA, nA, B->bmat, 0, 0); 70 | // blasfeo_print_dmat(mA, nA, C->bmat, 0, 0); 71 | // blasfeo_print_dmat(mA, nA, D->bmat, 0, 0); 72 | 73 | blasfeo_dgemm_nn(mA, nB, nA, 1.0, bA, 0, 0, bB, 0, 0, 1, bC, 0, 0, bD, 0, 0); 74 | } 75 | 76 | struct pmat * _c_pmt_gemm_nn(struct pmat *A, struct pmat *B) { 77 | int mA = A->bmat->m; 78 | int nA = A->bmat->n; 79 | int nB = B->bmat->n; 80 | struct blasfeo_dmat *bA = A->bmat; 81 | struct blasfeo_dmat *bB = B->bmat; 82 | 83 | 84 | struct pmat *C = c_pmt_create_pmat(nA, nB); 85 | struct blasfeo_dmat *bC = C->bmat; 86 | 87 | // printf("In dgemm\n"); 88 | // blasfeo_print_dmat(mA, nA, A->bmat, 0, 0); 89 | // blasfeo_print_dmat(mA, nA, B->bmat, 0, 0); 90 | // blasfeo_print_dmat(mA, nA, C->bmat, 0, 0); 91 | // blasfeo_print_dmat(mA, nA, D->bmat, 0, 0); 92 | 93 | blasfeo_dgemm_nn(mA, nB, nA, 1.0, bA, 0, 0, bB, 0, 0, 1, bC, 0, 0, bC, 0, 0); 94 | return C; 95 | } 96 | 97 | void c_pmt_gemm_tn(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D) { 98 | int mA = A->bmat->m; 99 | int nA = A->bmat->n; 100 | int nB = B->bmat->n; 101 | struct blasfeo_dmat *bA = A->bmat; 102 | struct blasfeo_dmat *bB = B->bmat; 103 | struct blasfeo_dmat *bC = C->bmat; 104 | struct blasfeo_dmat *bD = D->bmat; 
105 | 106 | // printf("In dgemm\n"); 107 | // blasfeo_print_dmat(mA, nA, A->bmat, 0, 0); 108 | // blasfeo_print_dmat(mA, nA, B->bmat, 0, 0); 109 | // blasfeo_print_dmat(mA, nA, C->bmat, 0, 0); 110 | // blasfeo_print_dmat(mA, nA, D->bmat, 0, 0); 111 | 112 | blasfeo_dgemm_tn(nA, nB, mA, 1.0, bA, 0, 0, bB, 0, 0, 1, bC, 0, 0, bD, 0, 0); 113 | } 114 | 115 | void c_pmt_gemm_nt(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D) { 116 | int mA = A->bmat->m; 117 | int nA = A->bmat->n; 118 | int mB = B->bmat->m; 119 | struct blasfeo_dmat *bA = A->bmat; 120 | struct blasfeo_dmat *bB = B->bmat; 121 | struct blasfeo_dmat *bC = C->bmat; 122 | struct blasfeo_dmat *bD = D->bmat; 123 | 124 | // printf("In dgemm\n"); 125 | // blasfeo_print_dmat(mA, nA, A->bmat, 0, 0); 126 | // blasfeo_print_dmat(mA, nA, B->bmat, 0, 0); 127 | // blasfeo_print_dmat(mA, nA, C->bmat, 0, 0); 128 | // blasfeo_print_dmat(mA, nA, D->bmat, 0, 0); 129 | 130 | blasfeo_dgemm_nt(mA, mB, nA, 1.0, bA, 0, 0, bB, 0, 0, 1, bC, 0, 0, bD, 0, 0); 131 | } 132 | 133 | 134 | void c_pmt_trmm_rlnn(struct pmat *A, struct pmat *B, struct pmat *D) { 135 | int nB = B->bmat->n; 136 | int mB = B->bmat->m; 137 | struct blasfeo_dmat *bA = A->bmat; 138 | struct blasfeo_dmat *bB = B->bmat; 139 | struct blasfeo_dmat *bD = D->bmat; 140 | 141 | // printf("In dgemm\n"); 142 | // blasfeo_print_dmat(mA, nA, A->bmat, 0, 0); 143 | // blasfeo_print_dmat(mA, nA, B->bmat, 0, 0); 144 | // blasfeo_print_dmat(mA, nA, C->bmat, 0, 0); 145 | // blasfeo_print_dmat(mA, nA, D->bmat, 0, 0); 146 | 147 | blasfeo_dtrmm_rlnn(mB, nB, 1.0, bA, 0, 0, bB, 0, 0, bD, 0, 0); 148 | } 149 | 150 | void c_pmt_syrk_ln(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D) { 151 | int nA = A->bmat->n; 152 | int mA = A->bmat->m; 153 | struct blasfeo_dmat *bA = A->bmat; 154 | struct blasfeo_dmat *bB = B->bmat; 155 | struct blasfeo_dmat *bD = D->bmat; 156 | 157 | // printf("In dgemm\n"); 158 | // blasfeo_print_dmat(mA, nA, A->bmat, 0, 0); 159 | // 
blasfeo_print_dmat(mA, nA, B->bmat, 0, 0); 160 | // blasfeo_print_dmat(mA, nA, C->bmat, 0, 0); 161 | // blasfeo_print_dmat(mA, nA, D->bmat, 0, 0); 162 | 163 | blasfeo_dsyrk_ln(mA, nA, 1.0, A->bmat, 0, 0, B->bmat, 0, 0, 1.0, C->bmat, 0, 0, D->bmat, 0, 0); 164 | } 165 | 166 | void c_pmt_getrf(struct pmat *A, struct pmat *fact, int *ipiv) { 167 | int mA = A->bmat->m; 168 | struct blasfeo_dmat *bA = A->bmat; 169 | struct blasfeo_dmat *bfact = fact->bmat; 170 | 171 | // factorization 172 | blasfeo_dgetrf_rp(mA, mA, bA, 0, 0, bfact, 0, 0, ipiv); 173 | } 174 | 175 | void c_pmt_potrf(struct pmat *A, struct pmat *fact) { 176 | int mA = A->bmat->m; 177 | struct blasfeo_dmat *bA = A->bmat; 178 | struct blasfeo_dmat *bfact = fact->bmat; 179 | 180 | // factorization 181 | blasfeo_dpotrf_l(mA, bA, 0, 0, bfact, 0, 0); 182 | } 183 | 184 | void c_pmt_getrsv(struct pmat *fact, int *ipiv, struct pvec *rhs) { 185 | int mfact = fact->bmat->m; 186 | struct blasfeo_dmat *bfact = fact->bmat; 187 | struct blasfeo_dvec *brhs = rhs->bvec; 188 | 189 | // permute the r.h.s 190 | blasfeo_dvecpe(mfact, ipiv, brhs, 0); 191 | // triangular solves 192 | blasfeo_dtrsv_lnu(mfact, bfact, 0, 0, brhs, 0, brhs, 0); 193 | blasfeo_dtrsv_unn(mfact, bfact, 0, 0, brhs, 0, brhs, 0); 194 | } 195 | 196 | void c_pmt_getrsm(struct pmat *fact, int *ipiv, struct pmat *rhs) { 197 | int mfact = fact->bmat->m; 198 | struct blasfeo_dmat *bfact = fact->bmat; 199 | struct blasfeo_dmat *brhs = rhs->bmat; 200 | 201 | // permute the r.h.s 202 | blasfeo_drowpe(mfact, ipiv, brhs); 203 | // triangular solves 204 | blasfeo_dtrsm_llnu(mfact, mfact, 1.0, bfact, 0, 0, brhs, 0, 0, brhs, 0, 0); 205 | blasfeo_dtrsm_lunn(mfact, mfact, 1.0, bfact, 0, 0, brhs, 0, 0, brhs, 0, 0); 206 | } 207 | 208 | struct pmat * _c_pmt_getrsm(struct pmat *A, struct pmat *rhs) { 209 | int mA = A->bmat->m; 210 | struct blasfeo_dmat *brhs = rhs->bmat; 211 | struct blasfeo_dmat *bA = A->bmat; 212 | 213 | struct pmat *fact = c_pmt_create_pmat(mA, mA); 214 | 
struct blasfeo_dmat *bfact = fact->bmat; 215 | 216 | // permutation indeces 217 | int *ipiv; int_zeros(&ipiv, mA, 1); 218 | // factorization 219 | blasfeo_dgetrf_rp(mA, mA, bA, 0, 0, bfact, 0, 0, ipiv); 220 | // permute the r.h.s 221 | blasfeo_drowpe(mA, ipiv, brhs); 222 | // triangular solves 223 | blasfeo_dtrsm_llnu(mA, mA, 1.0, bfact, 0, 0, brhs, 0, 0, brhs, 0, 0); 224 | blasfeo_dtrsm_lunn(mA, mA, 1.0, bfact, 0, 0, brhs, 0, 0, brhs, 0, 0); 225 | return rhs; 226 | } 227 | 228 | void c_pmt_potrsm(struct pmat *fact, struct pmat *rhs) { 229 | int mrhs = rhs->bmat->m; 230 | int nrhs = rhs->bmat->n; 231 | struct blasfeo_dmat *bfact = fact->bmat; 232 | struct blasfeo_dmat *brhs = rhs->bmat; 233 | // struct blasfeo_dmat *bout = out->bmat; 234 | 235 | // triangular solves 236 | blasfeo_dtrsm_llnn(mrhs, nrhs, 1.0, bfact, 0, 0, brhs, 0, 0, brhs, 0, 0); 237 | 238 | // struct pmat * fact_tran = c_pmt_create_pmat(nrhs, nrhs); 239 | // struct blasfeo_dmat *bfact_tran = fact_tran->bmat; 240 | blasfeo_dtrtr_l(nrhs, bfact, 0, 0, bfact, 0, 0); 241 | // blasfeo_dgese(1, 1, 0.0, bfact, 1, 0); 242 | blasfeo_dtrsm_lunn(mrhs, nrhs, 1.0, bfact, 0, 0, brhs, 0, 0, brhs, 0, 0); 243 | // blasfeo_dtrsm_lltn(mrhs, nrhs, 1.0, bfact_tran, 0, 0, brhs, 0, 0, brhs, 0, 0); 244 | 245 | } 246 | 247 | void c_pmt_potrsv(struct pmat *fact, struct pvec *rhs) { 248 | int mfact = fact->bmat->m; 249 | struct blasfeo_dmat *bfact = fact->bmat; 250 | struct blasfeo_dvec *brhs = rhs->bvec; 251 | // struct blasfeo_dvec *bout = out->bvec; 252 | 253 | // triangular solves 254 | blasfeo_dtrsv_lnu(mfact, bfact, 0, 0, brhs, 0, brhs, 0); 255 | blasfeo_dtrsv_unn(mfact, bfact, 0, 0, brhs, 0, brhs, 0); 256 | } 257 | 258 | void c_pmt_gead(double alpha, struct pmat *A, struct pmat *B) { 259 | int mA = A->bmat->m; 260 | int nA = A->bmat->n; 261 | struct blasfeo_dmat *bA = A->bmat; 262 | struct blasfeo_dmat *bB = B->bmat; 263 | 264 | blasfeo_dgead(mA, nA, alpha, bA, 0, 0, bB, 0, 0); 265 | } 266 | 267 | 268 | struct pmat * 
_c_pmt_gead(double alpha, struct pmat *A, struct pmat *B) { 269 | int mA = A->bmat->m; 270 | int nA = A->bmat->n; 271 | struct blasfeo_dmat *bA = A->bmat; 272 | struct blasfeo_dmat *bB = B->bmat; 273 | 274 | struct pmat *C = c_pmt_create_pmat(mA, nA); 275 | 276 | c_pmt_pmat_copy(B, C); 277 | struct blasfeo_dmat *bC = C->bmat; 278 | 279 | blasfeo_dgead(mA, nA, alpha, bA, 0, 0, bC, 0, 0); 280 | return C; 281 | } 282 | 283 | void c_pmt_gemv_n(struct pmat *A, struct pvec *b, struct pvec *c, struct pvec *d) { 284 | int mA = A->bmat->m; 285 | int nA = A->bmat->n; 286 | int mb = b->bvec->m; 287 | struct blasfeo_dmat *bA = A->bmat; 288 | struct blasfeo_dvec *bb = b->bvec; 289 | struct blasfeo_dvec *bc = c->bvec; 290 | struct blasfeo_dvec *bd = d->bvec; 291 | 292 | // printf("In dgemm\n"); 293 | // blasfeo_print_dmat(mA, nA, A->bmat, 0, 0); 294 | // blasfeo_print_dmat(mA, nA, B->bmat, 0, 0); 295 | // blasfeo_print_dmat(mA, nA, C->bmat, 0, 0); 296 | // blasfeo_print_dmat(mA, nA, D->bmat, 0, 0); 297 | 298 | blasfeo_dgemv_n(mA, nA, 1.0, bA, 0, 0, bb, 0, 1.0, bc, 299 | 0, bd, 0); 300 | } 301 | // auxiliary 302 | void c_pmt_pmat_fill(struct pmat *A, double fill_value) { 303 | int m = A->bmat->m; 304 | int n = A->bmat->n; 305 | 306 | blasfeo_dgese(m, n, fill_value, A->bmat, 0, 0); 307 | } 308 | 309 | void c_pmt_pmat_set_el(struct pmat *A, int i, int j, double fill_value) { 310 | 311 | blasfeo_dgein1(fill_value, A->bmat, i, j); 312 | } 313 | 314 | void c_pmt_gecp(int m, int n, struct pmat *A, int ai, int aj, struct pmat *B, int bi, int bj) { 315 | struct blasfeo_dmat *bA = A->bmat; 316 | struct blasfeo_dmat *bB = B->bmat; 317 | blasfeo_dgecp(m, n, bA, ai, aj, bB, bi, bj); 318 | } 319 | 320 | double c_pmt_pmat_get_el(struct pmat *A, int i, int j) { 321 | 322 | blasfeo_dgeex1(A->bmat, i, j); 323 | } 324 | 325 | void c_pmt_pmat_copy(struct pmat *A, struct pmat *B) { 326 | int m = A->bmat->m; 327 | int n = A->bmat->n; 328 | double value; 329 | 330 | 331 | blasfeo_dgecp(m, n, A->bmat, 
0, 0, B->bmat, 0, 0); 332 | // for(int i = 0; i < m; i++) 333 | // for(int j = 0; j < n; j++) { 334 | // value = blasfeo_dgeex1(A->bmat, i, j); 335 | // blasfeo_dgein1(value, B->bmat, i, j); 336 | // } 337 | } 338 | 339 | struct pmat * _c_pmt_pmat_copy(struct pmat *A, struct pmat *B) { 340 | int m = A->bmat->m; 341 | int n = A->bmat->n; 342 | double value; 343 | 344 | blasfeo_dgecp(m, n, A->bmat, 0, 0, B->bmat, 0, 0); 345 | // for(int i = 0; i < m; i++) 346 | // for(int j = 0; j < n; j++) { 347 | // value = blasfeo_dgeex1(A->bmat, i, j); 348 | // blasfeo_dgein1(value, B->bmat, i, j); 349 | // } 350 | 351 | return B; 352 | } 353 | 354 | void c_pmt_pmat_tran(struct pmat *A, struct pmat *B) { 355 | int m = A->bmat->m; 356 | int n = A->bmat->n; 357 | double value; 358 | 359 | blasfeo_dgetr(m, n, A->bmat, 0, 0, B->bmat, 0, 0); 360 | // for(int i = 0; i < m; i++) 361 | // for(int j = 0; j < n; j++) { 362 | // value = blasfeo_dgeex1(A->bmat, i, j); 363 | // blasfeo_dgein1(value, B->bmat, j, i); 364 | // } 365 | } 366 | 367 | struct pmat * _c_pmt_pmat_tran(struct pmat *A) { 368 | int m = A->bmat->m; 369 | int n = A->bmat->n; 370 | struct pmat *B = c_pmt_create_pmat(n, m); 371 | struct blasfeo_dmat *bB = B->bmat; 372 | double value; 373 | 374 | blasfeo_dgetr(m, n, A->bmat, 0, 0, B->bmat, 0, 0); 375 | // for(int i = 0; i < m; i++) 376 | // for(int j = 0; j < n; j++) { 377 | // value = blasfeo_dgeex1(A->bmat, i, j); 378 | // blasfeo_dgein1(value, B->bmat, j, i); 379 | // } 380 | return B; 381 | } 382 | 383 | void c_pmt_pmat_vcat(struct pmat *A, struct pmat *B, struct pmat *res) { 384 | int mA = A->bmat->m; 385 | int nA = A->bmat->n; 386 | int mB = B->bmat->m; 387 | int nB = B->bmat->n; 388 | double value; 389 | 390 | for(int i = 0; i < mA; i++) 391 | for(int j = 0; j < nA; j++) { 392 | value = blasfeo_dgeex1(A->bmat, i, j); 393 | blasfeo_dgein1(value, res->bmat, i, j); 394 | } 395 | for(int i = 0; i < mB; i++) 396 | for(int j = 0; j < nB; j++) { 397 | value = 
def __getitem__(self, index):
    """Return one entry or a copied sub-matrix.

    ``A[i, j]`` with two integers returns the entry read by ``pmat_get``.
    ``A[r0:r1, c0:c1]`` with two slices returns a new ``pmat`` that holds a
    copy of the selected block; an omitted slice bound defaults to the
    start or the end of that dimension.  Slice steps are ignored.

    Raises ``Exception`` when the subscript is not a 2-tuple of two
    integers or two slices, or when it reaches outside the matrix.
    """
    if not isinstance(index, tuple) or len(index) != 2:
        raise Exception('pmat subscript should be a 2-dimensional tuples, '
                        'you have: {}\n. Exiting'.format(index))

    row, col = index

    if isinstance(row, int) and isinstance(col, int):
        # Valid indices are 0 .. m-1 and 0 .. n-1.  The old test used '>'
        # and therefore let index == m (or n) through.
        if not (0 <= row < self.m and 0 <= col < self.n):
            raise Exception('Invalid subscripting values. Exiting. \n')
        return pmat_get(self, row, col)

    if isinstance(row, slice) and isinstance(col, slice):
        # Omitted bounds are None and used to fail in the comparisons below.
        row_start = 0 if row.start is None else row.start
        row_stop = self.m if row.stop is None else row.stop
        col_start = 0 if col.start is None else col.start
        col_stop = self.n if col.stop is None else col.stop
        if row_start < 0 or row_stop > self.m or \
                col_start < 0 or col_stop > self.n:
            raise Exception('Invalid subscripting values. Exiting. \n')
        m_value = row_stop - row_start
        n_value = col_stop - col_start
        submatrix = pmat(m_value, n_value)
        # TODO(andrea): there might be better performing implementations of this.
        for i in range(m_value):
            for j in range(n_value):
                submatrix[i, j] = self[row_start + i, col_start + j]
        return submatrix

    # A mix of an integer and a slice (or any other types) is not supported;
    # report it instead of silently returning None.
    raise Exception('pmat subscript should be a 2-dimensional tuples, '
                    'you have: {}\n. Exiting'.format(index))
\n') 78 | # TODO(andrea): there might be better performing implementations of this. 79 | for i in range(m_target): 80 | for j in range(n_target): 81 | self[index[0].start+i,index[1].start+j] = value[i,j] 82 | else: 83 | raise Exception ('pmat subscripts must be 2-dimensional tuples, \ 84 | you have: {}\n. Exiting'.format(index)) 85 | else: 86 | raise Exception ('pmat subscripts must be 2-dimensional tuples, \ 87 | you have: {}\n. Exiting'.format(index)) 88 | 89 | # class pmat(pmat_): 90 | 91 | # blasfeo_dmat = None 92 | # _i = None 93 | # _j = None 94 | 95 | # def __init__(self, m: int, n: int): 96 | # self.blasfeo_dmat = c_pmt_create_blasfeo_dmat(m, n) 97 | 98 | # def __getitem__(self, index): 99 | # if self._i is not None: 100 | # self._j = index 101 | # el = self.my_get_item() 102 | # return el 103 | 104 | # self._i = index 105 | # return self 106 | 107 | # def __setitem__(self, index, value): 108 | # self._j = index 109 | # self.my_set_item(value) 110 | # return 111 | 112 | 113 | # def my_set_item(self, value): 114 | # pmat_set(self, value, self._i, self._j) 115 | # self._i = None 116 | # self._j = None 117 | # return 118 | 119 | # def my_get_item(self): 120 | # el = pmat_get(self, self._i, self._j) 121 | # self._i = None 122 | # self._j = None 123 | # return el 124 | 125 | # TODO(andrea): ideally one would have three levels: 126 | # 1) high-level 127 | # 2) intermediate-level 128 | # 3) low-level (BLASFEO wrapper) 129 | 130 | # high-level linear algebra 131 | @dispatch(pmat_) 132 | def __mul__(self, other): 133 | if self.n != other.m: 134 | raise Exception('__mul__: mismatching dimensions:' 135 | ' ({}, {}) x ({}, {})'.format(self.m, self.n, other.m, other.n)) 136 | 137 | res = pmat(self.m, other.n) 138 | pmat_fill(res, 0.0) 139 | zero_mat = pmat(self.m, other.n) 140 | pmat_fill(zero_mat, 0.0) 141 | pmt_gemm_nn(self, other, zero_mat, res) 142 | return res 143 | 144 | @dispatch(pvec_) 145 | def __mul__(self, other): 146 | if self.n != other.blasfeo_dvec.m: 147 
| raise Exception('__mul__: mismatching dimensions:' 148 | ' ({}, {}) x ({},)'.format(self.m, self.n, other.blasfeo_dvec.m)) 149 | 150 | res = pvec(self.m) 151 | res.fill(0.0) 152 | zero_vec = pvec(self.m) 153 | zero_vec.fill(0.0) 154 | pmt_gemv_n(self, other, zero_vec, res) 155 | return res 156 | 157 | @dispatch(pmat_) 158 | def __add__(self, other): 159 | if self.m != other.m or self.n != other.n: 160 | raise Exception('__add__: mismatching dimensions:' 161 | ' ({}, {}) + ({}, {})'.format(self.m, self.n, other.m, other.n)) 162 | res = pmat(self.m, self.n) 163 | pmat_copy(other, res) 164 | pmt_gead(1.0, self, res) 165 | return res 166 | 167 | def __sub__(self, other): 168 | if self.m != other.m or self.n != other.n: 169 | raise Exception('__sub__: mismatching dimensions:' 170 | ' ({}, {}) + ({}, {})'.format(self.m, self.n, other.m, other.n)) 171 | res = pmat(self.m, self.n) 172 | pmat_copy(self, res) 173 | pmt_gead(-1.0, other, res) 174 | return res 175 | 176 | def pmat_fill(A: pmat, value: float): 177 | for i in range(A.m): 178 | for j in range(A.n): 179 | A[i,j] = value 180 | return 181 | 182 | def pmat_copy(A: pmat, B: pmat): 183 | if A.m != B.m or A.n != B.n: 184 | raise Exception('__copy__: mismatching dimensions:' 185 | ' ({}, {}) -> ({}, {})'.format(A.m, A.n, B.m, B.n)) 186 | for i in range(A.m): 187 | for j in range(A.n): 188 | B[i,j] = A[i,j] 189 | return 190 | 191 | def pmat_tran(A: pmat, B: pmat): 192 | if A.m != B.n or A.n != B.m: 193 | raise Exception('__tran__: mismatching dimensions:' 194 | ' ({}, {}) -> ({}, {})'.format(A.m, A.n, B.m, B.n)) 195 | for i in range(A.m): 196 | for j in range(A.n): 197 | B[j,i] = A[i,j] 198 | 199 | def pmat_vcat(A: pmat, B: pmat, res: pmat): 200 | if A.n != B.n or A.n != res.n or A.m + B.m != res.m: 201 | raise Exception('__vcat__: mismatching dimensions:' 202 | ' ({}, {}) ; ({}, {})'.format(A.m, A.n, B.m, B.n)) 203 | for i in range(A.m): 204 | for j in range(A.n): 205 | res[i,j] = A[i,j] 206 | for i in range(B.m): 207 
| for j in range(B.n): 208 | res[A.m + i,j] = B[i,j] 209 | 210 | def pmat_hcat(A: pmat, B: pmat, res: pmat): 211 | if A.m != B.m or A.m != res.m or A.n + B.n != res.n: 212 | raise Exception('__hcat__: mismatching dimensions:' 213 | ' ({}, {}) , ({}, {})'.format(A.m, A.n, B.m, B.n)) 214 | for i in range(A.m): 215 | for j in range(A.n): 216 | res[i,j] = A[i,j] 217 | for i in range(B.m): 218 | for j in range(B.n): 219 | res[i,A.n + j] = B[i,j] 220 | 221 | # def pmt_getrsm(fact: pmat, ipiv: list, rhs: pmat): 222 | # # create permutation vector 223 | # c_ipiv = cast(create_string_buffer(sizeof(c_int)*A.m), POINTER(c_int)) 224 | # for i in range(A.n): 225 | # c_ipiv[i] = ipiv[i] 226 | # res = pmat(A.m, B.n) 227 | # # create permuted rhs 228 | # # pB = pmat(B.m, B.n) 229 | # pmat_copy(B, res) 230 | # pmt_rowpe(B.m, c_ipiv, res) 231 | # # solve 232 | # pmt_trsm_llnu(A, res) 233 | # pmt_trsm_lunu(A, res) 234 | # return res 235 | 236 | # def pmt_getrsv(fact: pmat, ipiv: list, rhs: pvec): 237 | # # create permutation vector 238 | # c_ipiv = cast(create_string_buffer(sizeof(c_int)*fact.m), POINTER(c_int)) 239 | # for i in range(fact.n): 240 | # c_ipiv[i] = ipiv[i] 241 | # # permuted rhs 242 | # pvec_copy(b, res) 243 | # pmt_vecpe(b.blasfeo_dvec.m, c_ipiv, res) 244 | # # solve 245 | # pmt_trsv_llnu(fact, rhs) 246 | # pmt_trsv_lunn(fact, rhs) 247 | # return 248 | 249 | def pmt_potrsm(fact: pmat, rhs: pmat): 250 | # solve 251 | pmt_trsm_llnn(fact, rhs) 252 | fact_tran = pmat(fact.m, fact.n) 253 | pmat_tran(fact, fact_tran) 254 | pmt_trsm_lunn(fact_tran, rhs) 255 | return 256 | 257 | def pmt_potrsv(fact: pmat, rhs: pvec): 258 | # solve 259 | pmt_trsv_llnu(fact, rhs) 260 | pmt_trsv_lunn(fact, rhs) 261 | return 262 | 263 | # intermediate-level linear algebra 264 | # def pmt_gemm(A: pmat, B: pmat, C: pmat, D: pmat): 265 | def pmt_gemm(*argv): 266 | if len(argv) < 3: 267 | raise Exception('Invalid number of arguments') 268 | A = argv[0] 269 | B = argv[1] 270 | if len(argv) == 4: 271 | 
C = argv[2] 272 | D = argv[3] 273 | else: 274 | C = argv[2] 275 | D = argv[2] 276 | 277 | if A.n != B.m or A.m != C.m or B.n != C.n or C.m != D.m or C.n != D.n: 278 | raise Exception('pmt_gemm: mismatching dimensions:' 279 | ' ({}, {}) <- ({},{}) + ({}, {}) x ({}, {})'.format(\ 280 | D.m, D.n, C.m, C.n, A.m, A.n, B.m, B.n)) 281 | 282 | c_pmt_dgemm_nn(A, B, C, D) 283 | return 284 | 285 | def pmt_gemm_nn(*argv): 286 | if len(argv) < 3: 287 | raise Exception('Invalid number of arguments') 288 | A = argv[0] 289 | B = argv[1] 290 | if len(argv) == 4: 291 | C = argv[2] 292 | D = argv[3] 293 | else: 294 | C = argv[2] 295 | D = argv[2] 296 | if A.n != B.m or A.m != C.m or B.n != C.n or C.m != D.m or C.n != D.n: 297 | raise Exception('pmt_gemm_nn: mismatching dimensions:' 298 | ' ({}, {}) <- ({},{}) + ({}, {}) x ({}, {})'.format(\ 299 | D.m, D.n, C.m, C.n, A.m, A.n, B.m, B.n)) 300 | 301 | c_pmt_dgemm_nn(A, B, C, D) 302 | return 303 | 304 | def pmt_gemm_nt(*argv): 305 | if len(argv) < 3: 306 | raise Exception('Invalid number of arguments') 307 | A = argv[0] 308 | B = argv[1] 309 | if len(argv) == 4: 310 | C = argv[2] 311 | D = argv[3] 312 | else: 313 | C = argv[2] 314 | D = argv[2] 315 | if A.n != B.n or A.m != C.m or B.m != C.n or C.m != D.m or C.n != D.n: 316 | raise Exception('pmt_gemm_nt: mismatching dimensions:' 317 | ' ({}, {}) <- ({},{}) + ({}, {}) x ({}, {})^T'.format(\ 318 | D.m, D.n, C.m, C.n, A.m, A.n, B.m, B.n)) 319 | c_pmt_dgemm_nt(A, B, C, D) 320 | return 321 | 322 | def pmt_gemm_tn(*argv): 323 | if len(argv) < 3: 324 | raise Exception('Invalid number of arguments') 325 | A = argv[0] 326 | B = argv[1] 327 | if len(argv) == 4: 328 | C = argv[2] 329 | D = argv[3] 330 | else: 331 | C = argv[2] 332 | D = argv[2] 333 | if A.m != B.m or A.n != C.m or B.n != C.n or C.m != D.m or C.n != D.n: 334 | raise Exception('pmt_gemm_tn: mismatching dimensions:' 335 | ' ({}, {}) <- ({},{}) + ({}, {})^T x ({}, {})'.format(\ 336 | D.m, D.n, C.m, C.n, A.m, A.n, B.m, B.n)) 337 | 338 
| c_pmt_dgemm_tn(A, B, C, D) 339 | return 340 | 341 | def pmt_gemm_tt(*argv): 342 | if len(argv) < 3: 343 | raise Exception('Invalid number of arguments') 344 | A = argv[0] 345 | B = argv[1] 346 | if len(argv) == 4: 347 | C = argv[2] 348 | D = argv[3] 349 | else: 350 | C = argv[2] 351 | D = argv[2] 352 | if A.m != B.n or A.n != C.m or B.m != C.n or C.m != D.m or C.n != D.n: 353 | raise Exception('pmt_gemm_tt: mismatching dimensions:' 354 | ' ({}, {}) <- ({},{}) + ({}, {})^T x ({}, {})^T'.format(\ 355 | D.m, D.n, C.m, C.n, A.m, A.n, B.m, B.n)) 356 | c_pmt_dgemm_tt(A, B, C, D) 357 | return 358 | 359 | # B <= B + alpha*A 360 | def pmt_gead(alpha: float, A: pmat, B: pmat): 361 | if A.m != B.m or A.n != B.n: 362 | raise Exception('pmt_dgead: mismatching dimensions:' 363 | '({},{}) + ({}, {})'.format(A.m, A.n, B.m, B.n)) 364 | c_pmt_dgead(alpha, A, B) 365 | return 366 | 367 | def pmt_rowpe(m: int, ipiv: POINTER(c_int), A: pmat): 368 | c_pmt_drowpe(m, ipiv, A) 369 | return 370 | 371 | def pmt_trsm_llnu(A: pmat, B: pmat): 372 | c_pmt_trsm_llnu(A, B) 373 | return 374 | 375 | def pmt_trsm_lunn(A: pmat, B: pmat): 376 | c_pmt_trsm_lunn(A, B) 377 | return 378 | 379 | def pmt_trsm_llnn(A: pmat, B: pmat): 380 | c_pmt_trsm_llnn(A, B) 381 | return 382 | 383 | def pmt_trsv_llnu(A: pmat, b: pvec): 384 | c_pmt_trsv_llnu(A, b) 385 | return 386 | 387 | def pmt_trsv_lunn(A: pmat, b: pvec): 388 | c_pmt_trsv_lunn(A, b) 389 | return 390 | 391 | def pmt_getrf(A: pmat, fact: pmat, ipiv: list): 392 | # create permutation vector 393 | c_ipiv = cast(create_string_buffer(sizeof(c_int)*A.m), POINTER(c_int)) 394 | # factorize 395 | c_pmt_getrf(A, fact, c_ipiv) 396 | for i in range(A.n): 397 | ipiv[i] = c_ipiv[i] 398 | return 399 | 400 | def pmt_potrf(A: pmat, fact: pmat): 401 | # factorize 402 | c_pmt_potrf(A, fact) 403 | return 404 | 405 | def pmt_gemv_n(A: pmat, b: pvec, c: pvec, d: pvec): 406 | c_pmt_dgemv_n(A, b, c, d) 407 | return 408 | 409 | # auxiliary functions 410 | def pmt_set_data(M: 
pmat, data: POINTER(c_double)): 411 | c_pmt_set_blasfeo_dmat(M.blasfeo_dmat, data) 412 | return 413 | 414 | def pmat_set(M: pmat, value, i, j): 415 | c_pmt_set_blasfeo_dmat_el(value, M.blasfeo_dmat, i, j) 416 | return 417 | 418 | def pmat_get(M: pmat, i, j): 419 | el = c_pmt_get_blasfeo_dmat_el(M.blasfeo_dmat, i, j) 420 | return el 421 | 422 | def pmat_print(M: pmat): 423 | c_pmt_print_blasfeo_dmat(M) 424 | return 425 | 426 | 427 | --------------------------------------------------------------------------------