├── src └── vegasflow │ ├── tests │ ├── __init__.py │ ├── test_config.py │ ├── test_gradients.py │ ├── test_utils.py │ ├── test_misc.py │ └── test_algs.py │ ├── __init__.py │ ├── plain.py │ ├── configflow.py │ ├── utils.py │ ├── vflowplus.py │ └── vflow.py ├── .github ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── pythonpublish.yml │ └── pytest.yml ├── examples ├── cuda │ ├── integrand.h │ ├── cuda_example.py │ ├── makefile │ ├── integrand.cu.cpp │ └── integrand.cpp ├── cluster_dask.py ├── simgauss_tf.py ├── basic_example.py ├── example_eager.py ├── multidimensional_integral.py ├── retracing.py ├── asian_options_tf.py ├── simgauss_cffi.py ├── histogram_ex.py ├── multiple_integrals.py ├── example_pineappl.py ├── drellyan_lo_tf.py └── singletop_lo_tf.py ├── .readthedocs.yml ├── doc ├── Makefile └── source │ ├── apisrc │ └── vegasflow.rst │ ├── conf.py │ ├── intalg.rst │ ├── index.rst │ ├── examples.rst │ └── how_to.rst ├── PKGBUILD ├── pyproject.toml ├── .gitignore ├── README.md ├── LICENSE └── .pylintrc /src/vegasflow/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### The problem 11 | 12 | Brief description of the problem you want to solve. 13 | 14 | ### Proposed solution 15 | 16 | Please share any possible solutions for the problem you are thinking of. 17 | 18 | ### Are you available/want to contribute? 
19 | 20 | Yes/No 21 | -------------------------------------------------------------------------------- /src/vegasflow/__init__.py: -------------------------------------------------------------------------------- 1 | """Monte Carlo integration with Tensorflow""" 2 | 3 | from vegasflow.configflow import DTYPE, DTYPEINT, float_me, int_me, run_eager 4 | from vegasflow.plain import PlainFlow, plain_sampler, plain_wrapper 5 | 6 | # Expose the main interfaces 7 | from vegasflow.vflow import VegasFlow, vegas_sampler, vegas_wrapper 8 | from vegasflow.vflowplus import VegasFlowPlus, vegasflowplus_sampler, vegasflowplus_wrapper 9 | 10 | __version__ = "1.4.0" 11 | -------------------------------------------------------------------------------- /examples/cuda/integrand.h: -------------------------------------------------------------------------------- 1 | #ifndef KERNEL_INTEGRAND_ 2 | #define KERNEL_INTEGRAND_ 3 | 4 | namespace tensorflow { 5 | using Eigen::GpuDevice; 6 | 7 | template <typename Device, typename T> 8 | struct IntegrandOpFunctor { 9 | void operator()(const Device &d, const T *input, T *output, const int nevents, const int dims); 10 | }; 11 | 12 | #if KERNEL_CUDA 13 | template <typename T> 14 | struct IntegrandOpFunctor<Eigen::GpuDevice, T> { 15 | void operator()(const Eigen::GpuDevice &d, const T *input, T *output, const int nevents, const int dims); 16 | }; 17 | #endif 18 | 19 | } 20 | 21 | #endif 22 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the version of Python and other tools you might need 8 | build: 9 | os: ubuntu-20.04 # is required: see https://github.com/readthedocs/readthedocs.org/issues/8912 10 | tools: 11 | python: "3.10" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: doc/source/conf.py 16 | 17 | # Optionally set requirements required to build your docs 18 | python: 19 | install: 20 | - method: pip 21 | path: . 22 | extra_requirements: 23 | - docs 24 | system_packages: true 25 | -------------------------------------------------------------------------------- /.github/workflows/pythonpublish.yml: -------------------------------------------------------------------------------- 1 | name: Python publication 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v1 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: ${{ secrets.TWINE_USER }} 23 | TWINE_PASSWORD: ${{ secrets.TWINE_PASS }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two.
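# For example, "make html SPHINXOPTS=-W" runs a one-off build with Sphinx warnings treated as errors.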
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | view: html 23 | $(BROWSER) build/html/index.html 24 | -------------------------------------------------------------------------------- /examples/cuda/cuda_example.py: -------------------------------------------------------------------------------- 1 | from vegasflow.configflow import DTYPE, DTYPEINT 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | from vegasflow.plain import plain_wrapper 6 | 7 | # MC integration setup 8 | dim = 4 9 | ncalls = np.int32(1e6) 10 | n_iter = 5 11 | 12 | integrand_module = tf.load_op_library('./integrand.so') 13 | 14 | @tf.function 15 | def wrapper_integrand(xarr, **kwargs): 16 | return integrand_module.integrand_op(xarr) 17 | 18 | @tf.function 19 | def fully_python_integrand(xarr, **kwargs): 20 | return tf.reduce_sum(xarr, axis=1) 21 | 22 | if __name__ == "__main__": 23 | ncalls = 10*ncalls 24 | print(f"Plain MC, ncalls={ncalls}:") 25 | start = time.time() 26 | r = plain_wrapper(wrapper_integrand, dim, n_iter, ncalls) 27 | end = time.time() 28 | print(f"Plain MC took: time (s): {end-start}") 29 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | max-parallel: 3 11 | matrix: 12 | python-version: [3.9, '3.12'] 13 | 14 | steps: 15 | - uses: actions/checkout@v1 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v1 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install dependencies and package 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install .
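# Testing against the installed package (rather than the source tree) mirrors what users get from pip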
24 | - name: Lint with pylint 25 | run: | 26 | pip install pylint 27 | # Only error out on actual errors 28 | pylint src/*/*.py -E -d E1123,E1120 29 | pylint src/*/*.py --exit-zero 30 | - name: Test with pytest 31 | run: | 32 | pip install pytest 33 | pytest 34 | -------------------------------------------------------------------------------- /examples/cluster_dask.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example: cluster usage 3 | 4 | Basic example of running a VegasFlow job on a distributed system 5 | using dask and the SLURMCluster backend 6 | """ 7 | 8 | from dask_jobqueue import SLURMCluster # pylint: disable=import-error 9 | from vegasflow.vflow import VegasFlow 10 | import tensorflow as tf 11 | 12 | 13 | def integrand(xarr, **kwargs): 14 | return tf.reduce_sum(xarr, axis=1) 15 | 16 | 17 | if __name__ == "__main__": 18 | cluster = SLURMCluster( 19 | memory="2g", 20 | processes=1, 21 | cores=4, 22 | queue="", 23 | project="", 24 | job_extra=["--get-user-env"], 25 | ) 26 | 27 | mc_instance = VegasFlow(4, int(1e6), events_limit=int(1e5)) 28 | cluster.scale(jobs=10) 29 | mc_instance.set_distribute(cluster) 30 | mc_instance.compile(integrand) 31 | mc_instance.run_integration(5) 32 | -------------------------------------------------------------------------------- /doc/source/apisrc/vegasflow.rst: -------------------------------------------------------------------------------- 1 | vegasflow package 2 | ================= 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: vegasflow 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | Submodules 13 | ---------- 14 | 15 | vegasflow.vflow module 16 | ---------------------- 17 | 18 | .. automodule:: vegasflow.vflow 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | vegasflow.plain module 24 | ---------------------- 25 | 26 | .. automodule:: vegasflow.plain 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | vegasflow.monte_carlo module 32 | ---------------------------- 33 | 34 | .. automodule:: vegasflow.monte_carlo 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | vegasflow.utils module 40 | ---------------------------- 41 | 42 | ..
automodule:: vegasflow.utils 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /PKGBUILD: -------------------------------------------------------------------------------- 1 | # Maintainer: Juacrumar 2 | 3 | pkgname=python-vegasflow 4 | _name=vegasflow 5 | pkgver=1.0.2 6 | pkgrel=1 7 | pkgdesc='Monte Carlo integration library written in Python and based on the TensorFlow framework' 8 | arch=('any') 9 | url="https://vegasflow.readthedocs.io/" 10 | license=('GPL3') 11 | depends=("python>=3.6" 12 | "python-tensorflow" 13 | "python-joblib" 14 | "python-numpy") 15 | optdepends=("python-cffi: interfacing vegasflow with C code" 16 | "python-tensorflow-cuda: GPU support") 17 | # checkdepends=("python-pytest") 18 | provides=("vegasflow") 19 | changelog= 20 | source=("https://github.com/N3PDF/vegasflow/archive/v.${pkgver}.tar.gz") 21 | md5sums=("118fa9906f588ab7ecd320728c478ade") 22 | 23 | prepare() { 24 | cd "$_name-v.$pkgver" 25 | } 26 | 27 | # check() { 28 | # cd "$_name-v.$pkgver" 29 | # pytest 30 | # } 31 | 32 | build() { 33 | cd "$_name-v.$pkgver" 34 | python setup.py build 35 | } 36 | 37 | package() { 38 | cd "$_name-v.$pkgver" 39 | python setup.py install --root="$pkgdir" --optimize=2 --skip-build 40 | } 41 | -------------------------------------------------------------------------------- /examples/cuda/makefile: -------------------------------------------------------------------------------- 1 | target_lib=integrand.so 2 | 3 | TF_CFLAGS=`python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))' 2> /dev/null` 4 | TF_LFLAGS=`python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))' 2>/dev/null` 5 | 6 | CXX=g++ 7 | CXFLAGS=-std=c++17 -shared -fPIC -O2 8 | KERNEL_DEF=-D KERNEL_CUDA=1 9 | NCCFLAGS=-std=c++17 $(KERNEL_DEF) -x cu -Xcompiler -fPIC --disable-warnings 10 | 11 | # Check whether there's nvcc 12 | ifeq (,$(shell which nvcc 2>/dev/null)) 13 | else 14 | NCC:=nvcc 15 | NCCLIB:=$(subst bin/nvcc,lib64, $(shell which nvcc)) 16 | CXFLAGS+=$(KERNEL_DEF) -L$(NCCLIB) -lcudart 17 | kernel_comp=integrand.cu.o 18 | endif 19 | 20 | .PHONY: run clean 21 | 22 | run: $(target_lib) 23 | @python cuda_example.py 24 | 25 | %.cu.o: %.cu.cpp 26 | @echo "[$(NCC)] Integrating cuda kernel..." 27 | @$(NCC) $(NCCFLAGS) -c -o $@ $< $(TF_CFLAGS) 28 | 29 | %.so: %.cpp $(kernel_comp) 30 | @echo "[$(CXX)] Integrating operator..." 
31 | @$(CXX) $(CXFLAGS) $(KERNEL) -o $@ $^ $(TF_CFLAGS) $(TF_LFLAGS) 32 | 33 | clean: 34 | rm -f $(target_lib) $(kernel_comp) 35 | -------------------------------------------------------------------------------- /examples/simgauss_tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example: basic integration 3 | 4 | Basic example using the vegas_wrapper helper 5 | """ 6 | 7 | from vegasflow.configflow import DTYPE 8 | import time 9 | import numpy as np 10 | import tensorflow as tf 11 | from vegasflow.vflow import vegas_wrapper 12 | from vegasflow import PlainFlow 13 | from vegasflow.plain import plain_wrapper 14 | 15 | 16 | # MC integration setup 17 | dim = 4 18 | ncalls = np.int32(1e5) 19 | n_iter = 5 20 | 21 | 22 | @tf.function 23 | def symgauss(xarr, **kwargs): 24 | """symgauss test function""" 25 | n_dim = xarr.shape[-1] 26 | a = tf.constant(0.1, dtype=DTYPE) 27 | n100 = tf.cast(100 * n_dim, dtype=DTYPE) 28 | pref = tf.pow(1.0 / a / np.sqrt(np.pi), n_dim) 29 | coef = tf.reduce_sum(tf.range(n100 + 1)) 30 | coef += tf.reduce_sum(tf.square((xarr - 1.0 / 2.0) / a), axis=1) 31 | coef -= (n100 + 1) * n100 / 2.0 32 | return pref * tf.exp(-coef) 33 | 34 | 35 | if __name__ == "__main__": 36 | """Testing several different integrations""" 37 | print(f"VEGAS MC, ncalls={ncalls}:") 38 | start = time.time() 39 | r = vegas_wrapper(symgauss, dim, n_iter, ncalls) 40 | end = time.time() 41 | print(f"Vegas took: time (s): {end-start}") 42 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Description 11 | 12 | Please describe briefly what the issue is 13 | 14 | ### Code example 15 | 16 | If possible, write a minimum working example that reproduces the bug, 17 | e.g.: 18 | 19 | ```python 20 | import vegasflow 21 | vegasflow.broken_function() 22 | ``` 23 | 24 | ### Additional information 25 | 26 | Does the problem occur on CPU or on GPU? 27 | If GPU, how many? Which version of CUDA do you have? 28 | 29 | ```bash 30 | nvcc --version 31 | ``` 32 | 33 | Please include the version of Python, vegasflow and TensorFlow that you are running.
34 | Running the following Python script will produce useful information: 35 | 36 | ```python 37 | import tensorflow as tf 38 | import sys 39 | from tensorflow.python.framework import test_util 40 | import vegasflow 41 | 42 | print(f"Python version: {sys.version}") 43 | print(f"Vegasflow: {vegasflow.__version__}") 44 | print(f"Tensorflow: {tf.__version__}") 45 | print(f"tf-mkl: {test_util.IsMklEnabled()}") 46 | print(f"tf-cuda: {tf.test.is_built_with_cuda()}") 47 | print(f"tf-rocm: {tf.test.is_built_with_rocm()}") 48 | print(f"GPU available: {tf.test.is_gpu_available()}") 49 | ``` 50 | -------------------------------------------------------------------------------- /examples/basic_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example: basic integration 3 | 4 | Very basic example with a simple integrand 5 | """ 6 | 7 | from vegasflow import VegasFlow, float_me, run_eager 8 | import time 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | 13 | # MC integration setup 14 | dim = 4 15 | ncalls = int(1e5) 16 | n_iter = 5 17 | 18 | 19 | def symgauss(xarr): 20 | """symgauss test function""" 21 | n_dim = xarr.shape[-1] 22 | a = float_me(0.1) 23 | n100 = float_me(100 * n_dim) 24 | pref = tf.pow(1.0 / a / np.sqrt(np.pi), n_dim) 25 | coef = tf.reduce_sum(tf.range(n100 + 1)) 26 | coef += tf.reduce_sum(tf.square((xarr - 1.0 / 2.0) / a), axis=1) 27 | coef -= (n100 + 1) * n100 / 2.0 28 | return pref * tf.exp(-coef) 29 | 30 | 31 | if __name__ == "__main__": 32 | """Testing several different integrations""" 33 | print(f"VEGAS MC, ncalls={ncalls}:") 34 | start = time.time() 35 | vegas_instance = VegasFlow(dim, ncalls, events_limit=int(7e4)) 36 | vegas_instance.compile(symgauss) 37 | result = vegas_instance.run_integration(n_iter) 38 | end = time.time() 39 | print(f"Vegas took: time (s): {end-start}") 40 | print("Change the number of events and freeze the grid...") 41 | vegas_instance.freeze_grid() 42 | vegas_instance.run_integration(n_iter) 43 | -------------------------------------------------------------------------------- /examples/cuda/integrand.cu.cpp: -------------------------------------------------------------------------------- 1 | #if KERNEL_CUDA 2 | #define EIGEN_USE_GPU 3 | 4 | #include "tensorflow/core/framework/op_kernel.h" 5 | #include "integrand.h" 6 | 7 | using namespace tensorflow; 8 | using GPUDevice = Eigen::GpuDevice; 9 | 10 | // This is the kernel that does the actual computation on device 11 | template <typename T> 12 | __global__ void IntegrandOpKernel(const T *input, T *output, const int nevents, const int ndim) { 13 | const auto gid = blockIdx.x*blockDim.x + threadIdx.x; 14 | // note: this is an example of usage, not an example of a very optimal anything...
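// Grid-stride loop: thread gid handles events gid, gid + stride, gid + 2*stride, ...
// with stride = blockDim.x*gridDim.x, so the kernel stays correct for any nevents
// independently of the launch configuration chosen in the functor below.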
15 | for (int i = gid; i < nevents; i += blockDim.x*gridDim.x) { 16 | output[i] = 0.0; 17 | for (int j = 0; j < ndim; j++) { 18 | output[i] += input[i*ndim + j]; 19 | } 20 | } 21 | } 22 | 23 | // But it still needs to be launched from within C++ 24 | // this bit is to be compared with the functor at the top of integrand.cpp 25 | template <typename T> 26 | void IntegrandOpFunctor<GPUDevice, T>::operator()(const GPUDevice &d, const T *input, T *output, const int nevents, const int dims) { 27 | const int block_count = 1024; 28 | const int thread_per_block = 20; 29 | IntegrandOpKernel<T><<<block_count, thread_per_block, 0, d.stream()>>>(input, output, nevents, dims); 30 | } 31 | 32 | template struct IntegrandOpFunctor<GPUDevice, double>; 33 | 34 | 35 | #endif 36 | -------------------------------------------------------------------------------- /examples/example_eager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example: eager mode integrand 3 | 4 | Demonstrates how to run a non-tensorflow integrand using VegasFlow 5 | """ 6 | 7 | from vegasflow import run_eager, vegas_wrapper 8 | import time 9 | import numpy as np 10 | from scipy.special import expit 11 | import tensorflow as tf 12 | 13 | # Enable eager mode 14 | run_eager(True) 15 | 16 | # MC integration setup 17 | dim = 4 18 | ncalls = np.int32(1e5) 19 | n_iter = 5 20 | 21 | 22 | @tf.function 23 | def symgauss_sigmoid(xarr, **kwargs): 24 | """symgauss test function""" 25 | n_dim = xarr.shape[-1] 26 | a = 0.1 27 | pref = pow(1.0 / a / np.sqrt(np.pi), n_dim) 28 | coef = np.sum(np.arange(1, 101)) 29 | # Tensorflow tensors will be cast down by numpy 30 | # you can directly access their numpy representation with .numpy() 31 | xarr_sq = np.square((xarr - 1.0 / 2.0) / a) 32 | coef += np.sum(xarr_sq, axis=1) 33 | coef -= 100.0 * 101.0 / 2.0 34 | return expit(xarr[:, 0].numpy()) * (pref * np.exp(-coef)) 35 | 36 | 37 | if __name__ == "__main__": 38 | """Testing several different integrations""" 39 | ncalls = 10 * ncalls 40 | print(f"VEGAS MC, ncalls={ncalls}:") 41 | start = time.time() 42 | r = vegas_wrapper(symgauss_sigmoid, dim, n_iter, ncalls, compilable=True) 43 | end = time.time() 44 | print(f"Vegas took: time (s): {end-start}") 45 | -------------------------------------------------------------------------------- /examples/multidimensional_integral.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example: basic multidimensional integral 3 | 4 | The output of the function is not a scalar but a vector v.
5 | In this case it is necessary to tell Vegas which is the main dimension 6 | (i.e., the output dimension the grid should adapt to) 7 | 8 | Note that the integrand output should have the same shape as the tensor of random numbers: 9 | both the random-number tensor and the output are of shape (nevents, ndim) 10 | """ 11 | from vegasflow import VegasFlow, run_eager 12 | 13 | run_eager() 14 | import tensorflow as tf 15 | 16 | # MC integration setup 17 | dim = 3 18 | ncalls = int(1e4) 19 | n_iter = 5 20 | 21 | 22 | @tf.function 23 | def test_function(xarr): 24 | res = tf.square((xarr - 1.0) ** 2) 25 | return tf.exp(-res) 26 | 27 | 28 | if __name__ == "__main__": 29 | print("Testing a multidimensional integration") 30 | vegas = VegasFlow(dim, ncalls, main_dimension=1) 31 | vegas.compile(test_function) 32 | all_results, all_err = vegas.run_integration(2) 33 | try: 34 | for result, error in zip(all_results, all_err): 35 | print(f"{result = :.5} +- {error:.5}") 36 | except TypeError: 37 | # So that the example works also if the integrand is made scalar 38 | result = all_results 39 | error = all_err 40 | print(f"{result = :.5} +- {error:.5}") 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "vegasflow" 7 | version = "1.4.0" 8 | description = "Hardware agnostic Monte Carlo integration" 9 | authors = [ 10 | {name = "S. Carrazza", email = "stefano.carrazza@cern.ch"}, 11 | {name = "J. Cruz-Martinez", email = "juacrumar@lairen.eu"} 12 | ] 13 | classifiers = [ 14 | 'Operating System :: Unix', 15 | 'Programming Language :: Python', 16 | 'Programming Language :: Python :: 3', 17 | 'Topic :: Scientific/Engineering', 18 | 'Topic :: Scientific/Engineering :: Physics', 19 | ] 20 | dependencies = [ 21 | "joblib", 22 | "numpy", 23 | "tensorflow>2.2" 24 | ] 25 | 26 | [project.optional-dependencies] 27 | docs = [ 28 | 'sphinx', 29 | 'sphinx_rtd_theme', 30 | 'sphinxcontrib-bibtex', 31 | ] 32 | examples = [ 33 | 'cffi', 34 | 'pineappl', 35 | 'pdfflow', 36 | 'scipy' 37 | ] 38 | benchmark = [ 39 | 'vegas', # Lepage's Vegas for benchmarking 40 | ] 41 | distribute = [ 42 | 'dask', 43 | 'distributed', 44 | 'dask-jobqueue', 45 | ] 46 | 47 | [tool.black] 48 | line-length = 100 49 | skip_magic_trailing_comma = true 50 | 51 | [tool.isort] 52 | atomic = true 53 | line_length = 120 54 | profile = "black" # https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#custom-configuration 55 | skip_gitignore = true 56 | force_sort_within_sections = true 57 | -------------------------------------------------------------------------------- /examples/retracing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Retracing example in VegasFlowPlus 3 | """ 4 | 5 | from vegasflow import VegasFlowPlus, VegasFlow, PlainFlow 6 | from vegasflow.configflow import DTYPE, DTYPEINT, run_eager, float_me 7 | import time 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | # MC integration setup 12 | dim = 2 13 | ncalls = np.int32(1e3) 14 | n_iter = 5 15 | 16 | @tf.function(input_signature=[ 17 | tf.TensorSpec(shape=[None, dim], dtype=DTYPE), 18 | tf.TensorSpec(shape=[], dtype=DTYPEINT), 19 | tf.TensorSpec(shape=[None], dtype=DTYPE) 20 | ] 21 | ) 22 | def symgauss(xarr, n_dim=None, weight=None, **kwargs): 23 |
"""symgauss test function""" 24 | if n_dim is None: 25 | n_dim = xarr.shape[-1] 26 | a = tf.constant(0.1, dtype=DTYPE) 27 | n100 = tf.cast(100 * n_dim, dtype=DTYPE) 28 | pref = tf.pow(1.0 / a / np.sqrt(np.pi), float_me(n_dim)) 29 | coef = tf.reduce_sum(tf.range(n100 + 1)) 30 | coef += tf.reduce_sum(tf.square((xarr - 1.0 / 2.0) / a), axis=1) 31 | coef -= (n100 + 1) * n100 / 2.0 32 | return pref * tf.exp(-coef) 33 | 34 | 35 | 36 | if __name__ == "__main__": 37 | """Testing several different integrations""" 38 | 39 | # run_eager() 40 | vegas_instance = VegasFlowPlus(dim, ncalls,adaptive=True) 41 | vegas_instance.compile(symgauss) 42 | vegas_instance.run_integration(n_iter) 43 | -------------------------------------------------------------------------------- /examples/asian_options_tf.py: -------------------------------------------------------------------------------- 1 | # Asian option like integral 2 | # From https://doi.org/10.1016/S0885-064X(03)00003-7 3 | # Equation (14) 4 | 5 | from vegasflow.configflow import DTYPE, DTYPEINT 6 | from vegasflow import vegas_wrapper 7 | 8 | from tensorflow.python.ops.distributions.special_math import ndtri 9 | import time 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | # MC integration setup 14 | d = 16 15 | ncalls = np.int32(2e6) 16 | n_iter = 5 17 | 18 | # Asian Option setup 19 | T = tf.constant(1.0, dtype=DTYPE) 20 | r = tf.constant(1.0, dtype=DTYPE) 21 | sigma = tf.constant(1.0, dtype=DTYPE) 22 | sigma2 = tf.square(sigma) 23 | S0 = tf.constant(1.0, dtype=DTYPE) 24 | t = tf.constant(tf.ones(shape=(ncalls,), dtype=DTYPE)) 25 | sqrtdt = tf.constant(1.0, dtype=DTYPE) 26 | K = tf.constant(0.0, dtype=DTYPE) 27 | e = tf.exp(tf.constant(-1*r*T, dtype=DTYPE)) 28 | zero = tf.constant(0.0, dtype=DTYPE) 29 | 30 | 31 | @tf.function 32 | def example_integrand(xarr, **kwargs): 33 | """Asian options test function""" 34 | sum1 = tf.reduce_sum(ndtri(xarr), axis=1) 35 | a = S0 * tf.exp((r-sigma2/2) + sigma*sqrtdt*sum1) 36 | arg = 1 / d * tf.reduce_sum(a) 37 | return e*tf.maximum(zero, arg-K) 38 | 39 | 40 | if __name__ == "__main__": 41 | """Testing a basic integration""" 42 | print(f"VEGAS MC, ncalls={ncalls}:") 43 | start = time.time() 44 | r = vegas_wrapper(example_integrand, d, n_iter, ncalls) 45 | end = time.time() 46 | print(f"time (s): {end-start}") 47 | -------------------------------------------------------------------------------- /src/vegasflow/tests/test_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test that the configuration is consistent 3 | """ 4 | 5 | import importlib 6 | import os 7 | 8 | import numpy as np 9 | 10 | import vegasflow.configflow 11 | from vegasflow.configflow import DTYPE, DTYPEINT, float_me, int_me 12 | 13 | 14 | def test_int_me(): 15 | res = int_me(4) 16 | assert res.dtype == DTYPEINT 17 | 18 | 19 | def test_float_me(): 20 | res = float_me(4.0) 21 | assert res.dtype == DTYPE 22 | 23 | 24 | def test_float_env(): 25 | os.environ["VEGASFLOW_FLOAT"] = "32" 26 | importlib.reload(vegasflow.configflow) 27 | from vegasflow.configflow import DTYPE 28 | 29 | assert DTYPE.as_numpy_dtype == np.float32 30 | os.environ["VEGASFLOW_FLOAT"] = "64" 31 | importlib.reload(vegasflow.configflow) 32 | from vegasflow.configflow import DTYPE 33 | 34 | assert DTYPE.as_numpy_dtype == np.float64 35 | # Reset to default 36 | os.environ["VEGASFLOW_FLOAT"] = "64" 37 | importlib.reload(vegasflow.configflow) 38 | 39 | 40 | def test_int_env(): 41 | os.environ["VEGASFLOW_INT"] = "32" 42 | 
importlib.reload(vegasflow.configflow) 43 | from vegasflow.configflow import DTYPEINT 44 | 45 | assert DTYPEINT.as_numpy_dtype == np.int32 46 | os.environ["VEGASFLOW_INT"] = "64" 47 | importlib.reload(vegasflow.configflow) 48 | from vegasflow.configflow import DTYPEINT 49 | 50 | assert DTYPEINT.as_numpy_dtype == np.int64 51 | # Reset to default 52 | os.environ["VEGASFLOW_INT"] = "32" 53 | importlib.reload(vegasflow.configflow) 54 | -------------------------------------------------------------------------------- /src/vegasflow/plain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plain implementation of the plainest possible MonteCarlo 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | from vegasflow.configflow import fone, fzero 8 | from vegasflow.monte_carlo import MonteCarloFlow, sampler, wrapper 9 | 10 | 11 | class PlainFlow(MonteCarloFlow): 12 | """ 13 | Simple Monte Carlo integrator. 14 | """ 15 | 16 | _CAN_RUN_VECTORIAL = True 17 | 18 | def _run_event(self, integrand, ncalls=None): 19 | if ncalls is None: 20 | n_events = self.n_events 21 | else: 22 | n_events = ncalls 23 | 24 | # Generate all random number for this iteration 25 | rnds, xjac = self._generate_random_array(n_events) 26 | 27 | # Compute the integrand 28 | tmp = integrand(rnds, weight=xjac) * xjac 29 | tmp2 = tf.square(tmp) 30 | 31 | # Accommodate multidimensional output by ensuring that only the event axis is accumulated 32 | res = tf.reduce_sum(tmp, axis=0) 33 | res2 = tf.reduce_sum(tmp2, axis=0) 34 | 35 | return res, res2 36 | 37 | def _run_iteration(self): 38 | res, raw_res2 = self.run_event() 39 | res2 = raw_res2 * self.n_events 40 | # Compute the error 41 | err_tmp2 = (res2 - tf.square(res)) / (self.n_events - fone) 42 | sigma = tf.sqrt(tf.maximum(err_tmp2, fzero)) 43 | return res, sigma 44 | 45 | 46 | def plain_wrapper(*args, **kwargs): 47 | """Wrapper around PlainFlow""" 48 | return wrapper(PlainFlow, *args, **kwargs) 49 | 50 | 51 | def plain_sampler(*args, **kwargs): 52 | """Wrapper sampler around PlainFlow""" 53 | return sampler(PlainFlow, *args, **kwargs) 54 | -------------------------------------------------------------------------------- /examples/simgauss_cffi.py: -------------------------------------------------------------------------------- 1 | # Place your function here 2 | import time 3 | import numpy as np 4 | from vegasflow.configflow import DTYPE 5 | from vegasflow.vflow import VegasFlow 6 | import tensorflow as tf 7 | 8 | from cffi import FFI 9 | ffibuilder = FFI() 10 | 11 | 12 | # MC integration setup 13 | dim = 4 14 | ncalls = np.int32(1e5) 15 | n_iter = 5 16 | 17 | if DTYPE is tf.float64: 18 | C_type = "double" 19 | elif DTYPE is tf.float32: 20 | C_type = "float" 21 | else: 22 | raise TypeError(f"Datatype {DTYPE} not understood") 23 | 24 | 25 | ffibuilder.cdef(f""" 26 | void symgauss({C_type}*, int, int, {C_type}*); 27 | """) 28 | 29 | ffibuilder.set_source("_symgauss_cffi", f""" 30 | void symgauss({C_type} *x, int n, int evts, {C_type}* out) 31 | {{ 32 | for (int e = 0; e < evts; e++) 33 | {{ 34 | {C_type} a = 0.1; 35 | {C_type} pref = pow(1.0/a/sqrt(M_PI), n); 36 | {C_type} coef = 0.0; 37 | for (int i = 1; i <= 100*n; i++) {{ 38 | coef += ({C_type}) i; 39 | }} 40 | for (int i = 0; i < n; i++) {{ 41 | coef += pow((x[i+e*n] - 1.0/2.0)/a, 2); 42 | }} 43 | coef -= 100.0*n*(100.0*n+1.0)/2.0; 44 | out[e] = pref*exp(-coef); 45 | }} 46 | }} 47 | """) 48 | ffibuilder.compile(verbose=True) 49 | 50 | from _symgauss_cffi import ffi, lib 51 | 52 | def 
symgauss(xarr, **kwargs): 53 | n_dim = xarr.shape[-1] 54 | n_events = xarr.shape[0] 55 | 56 | res = np.empty(n_events, dtype = DTYPE.as_numpy_dtype) 57 | x_flat = xarr.numpy().flatten() 58 | 59 | pinput = ffi.cast(f'{C_type}*', ffi.from_buffer(x_flat)) 60 | pres = ffi.cast(f'{C_type}*', ffi.from_buffer(res)) 61 | lib.symgauss(pinput, n_dim, n_events, pres) 62 | return res 63 | 64 | if __name__ == "__main__": 65 | """Testing a basic integration""" 66 | 67 | print(f"VEGAS MC, ncalls={ncalls}:") 68 | start = time.time() 69 | vegas_instance = VegasFlow(dim, ncalls) 70 | vegas_instance.compile(symgauss, compilable = False) 71 | r = vegas_instance.run_integration(n_iter) 72 | end = time.time() 73 | print(f"time (s): {end-start}") 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Cffi 132 | _*.c 133 | *.o 134 | *.so 135 | 136 | # dask 137 | dask-worker-space 138 | *.out 139 | -------------------------------------------------------------------------------- /src/vegasflow/tests/test_gradients.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests the gradients of the different algorithms 3 | """ 4 | 5 | import numpy as np 6 | from pytest import mark 7 | import tensorflow as tf 8 | 9 | from vegasflow import PlainFlow, VegasFlow, VegasFlowPlus, float_me, run_eager 10 | 11 | 12 | def generate_integrand(variable): 13 | """Generate an integrand that depends on an input variable""" 14 | 15 | def example_integrand(x): 16 | y = tf.reduce_sum(x, axis=1) 17 | return y * variable 18 | 19 | return example_integrand 20 | 21 | 22 | def generate_differentiable_function(iclass, integrand, dims=3, n_calls=int(1e5), i_kwargs=None): 23 | """Generates a function that depends on the result of a Monte Carlo integral 24 | of ``integrand`` using the class iclass (in differentiable form) as integrator 25 | """ 26 | if i_kwargs is None: 27 | i_kwargs = {} 28 | integrator_instance = iclass(dims, n_calls, verbose=False, **i_kwargs) 29 | integrator_instance.compile(integrand) 30 | # Train 31 | _ = integrator_instance.run_integration(2) 32 | # Now make it differentiable/compilable 33 | runner = integrator_instance.make_differentiable() 34 | 35 | def some_complicated_function(x): 36 | integration_result, *_ = runner() 37 | return x * integration_result 38 | 39 | compiled_fun = tf.function(some_complicated_function) 40 | # Compile the function 41 | _ = compiled_fun(float_me(4.0)) 42 | return compiled_fun 43 | 44 | 45 | def wrapper_test(iclass, x_point=5.0, alpha=10, integrator_kwargs=None): 46 | """Wrapper for all integrators""" 47 | # Create a variable 48 | z = tf.Variable(float_me(1.0)) 49 | # Create an integrand that depends on this variable 50 | integrand = generate_integrand(z) 51 | # Now create a function that depends on its integration 52 | fun = generate_differentiable_function(iclass, integrand, i_kwargs=integrator_kwargs) 53 | 54 | x0 = float_me(x_point) 55 | with tf.GradientTape() as tape: 56 | tape.watch(x0) 57 | y1 = fun(x0) 58 | 59 | grad_1 = tape.gradient(y1, x0) 60 | 61 | # Change the value of the variable 62 | z.assign(z.numpy() * alpha) 63 | 64 | with tf.GradientTape() as tape: 65 | tape.watch(x0) 66 | y2 = fun(x0) 67 | 68 | grad_2 = tape.gradient(y2, x0) 69 | 70 | # Test that the gradient works as expected 71 | np.testing.assert_allclose(grad_1 * alpha, grad_2, rtol=1e-2) 72 | 73 | 74 | @mark.parametrize("algorithm", [VegasFlowPlus, VegasFlow, PlainFlow]) 75 | def test_gradient(algorithm): 76 | """Test one can compile and generate gradients with the different algorithms""" 77 | wrapper_test(algorithm) 78 | 79 | 80 | def test_gradient_VegasflowPlus_adaptive(): 81 | """Test one can compile and generate gradients with VegasFlowPlus""" 82 |
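# The adaptive mode of VegasFlowPlus goes through a separate code path, hence its own test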
wrapper_test(VegasFlowPlus, integrator_kwargs={"adaptive": True}) 83 | -------------------------------------------------------------------------------- /src/vegasflow/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """ Test the utilities """ 2 | 3 | import numpy as np 4 | import pytest 5 | import tensorflow as tf 6 | 7 | from vegasflow.configflow import DTYPEINT, int_me 8 | from vegasflow.utils import consume_array_into_indices, generate_condition_function, py_consume_array_into_indices 9 | 10 | 11 | def test_consume_array_into_indices(): 12 | # Select the size 13 | size_in = np.random.randint(5, 100) 14 | size_out = np.random.randint(1, size_in - 3) 15 | # Generate the input array and the indices 16 | input_array = np.random.rand(size_in) 17 | indices = np.random.randint(0, size_out, size=size_in) 18 | # Make them into TF 19 | tf_input = tf.constant(input_array) 20 | tf_indx = tf.constant(indices.reshape(-1, 1), dtype=DTYPEINT) 21 | result = consume_array_into_indices(tf_input, tf_indx, int_me(size_out)) 22 | py_result = py_consume_array_into_indices(tf_input, tf_indx, size_out) 23 | np.testing.assert_array_equal(result, py_result) 24 | # Check that no results were lost 25 | np.testing.assert_almost_equal(np.sum(input_array), np.sum(result)) 26 | # Check that the arrays in numpy produce the same in numpy 27 | check_result = np.zeros(size_out) 28 | for val, i in zip(input_array, indices): 29 | check_result[i] += val 30 | np.testing.assert_allclose(check_result, result) 31 | 32 | 33 | def util_check(np_mask, tf_mask, tf_ind): 34 | np.testing.assert_array_equal(np_mask, tf_mask) 35 | # Numpy returns things the other way around 36 | np_indices = np.array(np_mask.nonzero()).T 37 | np.testing.assert_array_equal(np_indices, tf_ind) 38 | 39 | 40 | def test_generate_condition_function(): 41 | """Tests generate_condition_function and its errors""" 42 | masks = 4 # Always > 2 43 | vals = 15 44 | np_masks = np.random.randint(2, size=(masks, vals), dtype=bool) 45 | tf_masks = [tf.constant(i, dtype=tf.bool) for i in np_masks] 46 | # Generate the functions for and and or 47 | f_and = generate_condition_function(masks, "and") 48 | f_or = generate_condition_function(masks, "or") 49 | # Get the numpy and tf results 50 | np_ands = np.all(np_masks, axis=0) 51 | np_ors = np.any(np_masks, axis=0) 52 | tf_ands, idx_ands = f_and(*tf_masks) 53 | tf_ors, idx_ors = f_or(*tf_masks) 54 | # Check the values are the same 55 | util_check(np_ands, tf_ands, idx_ands) 56 | util_check(np_ors, tf_ors, idx_ors) 57 | # Check a combination 58 | f_comb = generate_condition_function(3, ["and", "or"]) 59 | np_comb = np_masks[0] & np_masks[1] | np_masks[2] 60 | tf_comb, idx_comb = f_comb(*tf_masks[:3]) 61 | util_check(np_comb, tf_comb, idx_comb) 62 | # Check failures 63 | with pytest.raises(ValueError): 64 | generate_condition_function(1, "and") 65 | with pytest.raises(ValueError): 66 | generate_condition_function(5, "bad_condition") 67 | with pytest.raises(ValueError): 68 | generate_condition_function(5, ["or", "and"]) 69 | with pytest.raises(ValueError): 70 | generate_condition_function(3, ["or", "bad_condition"]) 71 | -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath('..')) 17 | import vegasflow 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'vegasflow' 23 | copyright = '2020, Stefano Carrazza and Juan Cruz-Martinez' 24 | author = 'Stefano Carrazza and Juan Cruz-Martinez' 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = vegasflow.__version__ 28 | autodoc_mock_imports = ['tensorflow'] 29 | 30 | 31 | # -- General configuration --------------------------------------------------- 32 | # 33 | # https://stackoverflow.com/questions/56336234/build-fail-sphinx-error-contents-rst-not-found 34 | master_doc = 'index' 35 | 36 | # Add any Sphinx extension module names here, as strings. They can be 37 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 38 | # ones. 39 | extensions = [ 40 | 'sphinx.ext.autodoc', 41 | 'sphinx.ext.doctest', 42 | 'sphinx.ext.coverage', 43 | 'sphinx.ext.napoleon', 44 | 'sphinx.ext.intersphinx', 45 | ] 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ['_templates'] 49 | 50 | # Markdown configuration 51 | 52 | # The suffix(es) of source filenames. 53 | # You can specify multiple suffix as a list of string: 54 | # 55 | source_suffix = { 56 | '.rst': 'restructuredtext', 57 | } 58 | 59 | autosectionlabel_prefix_document = True 60 | # Allow to embed rst syntax in markdown files. 61 | enable_eval_rst = True 62 | 63 | # List of patterns, relative to source directory, that match files and 64 | # directories to ignore when looking for source files. 65 | # This pattern also affects html_static_path and html_extra_path. 66 | exclude_patterns = [] 67 | 68 | 69 | # -- Options for HTML output ------------------------------------------------- 70 | 71 | # The theme to use for HTML and HTML Help pages. See the documentation for 72 | # a list of builtin themes. 73 | # 74 | html_theme = 'sphinx_rtd_theme' 75 | 76 | # Add any paths that contain custom static files (such as style sheets) here, 77 | # relative to this directory. They are copied after the builtin static files, 78 | # so a file named "default.css" will overwrite the builtin "default.css". 
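# (left commented out, presumably because these docs rely on the theme defaults and ship no custom static files)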
79 | #html_static_path = ['_static'] 80 | 81 | 82 | # -- Intersphinx ------------------------------------------------------------- 83 | 84 | intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} 85 | 86 | # -- Doctest ------------------------------------------------------------------ 87 | # 88 | 89 | doctest_path = [os.path.abspath('../examples')] 90 | 91 | # -- Autodoc ------------------------------------------------------------------ 92 | # 93 | autodoc_member_order = 'bysource' 94 | -------------------------------------------------------------------------------- /examples/cuda/integrand.cpp: -------------------------------------------------------------------------------- 1 | //#include "cuda_kernel.h" 2 | 3 | #include "tensorflow/core/framework/op.h" 4 | #include "tensorflow/core/framework/shape_inference.h" 5 | #include "tensorflow/core/framework/op_kernel.h" 6 | #include "integrand.h" 7 | 8 | /* 9 | * In this example we follow the TF guide for operation creation 10 | * https://www.tensorflow.org/guide/create_op 11 | * to create an integrand as a custom operator. 12 | * 13 | * To first approximation, these operators are functions that take 14 | * a tensor and return a tensor. 15 | */ 16 | 17 | using namespace tensorflow; 18 | 19 | using GPUDevice = Eigen::GpuDevice; 20 | using CPUDevice = Eigen::ThreadPoolDevice; 21 | 22 | // CPU 23 | template <typename T> 24 | struct IntegrandOpFunctor<CPUDevice, T> { 25 | void operator()(const CPUDevice &d, const T *input, T *output, const int nevents, const int dims) { 26 | for (int i = 0; i < nevents; i++) { 27 | output[i] = 0.0; 28 | for(int j = 0; j < dims; j++) { 29 | output[i] += input[i*dims + j]; 30 | } 31 | } 32 | } 33 | }; 34 | 35 | 36 | /* The input and output type must be coherent with the types used in tensorflow 37 | * at this moment we are using float64 as default for vegasflow. 38 | * 39 | * The output shape is set to be (input_shape[0], ), i.e., number of events 40 | */ 41 | REGISTER_OP("IntegrandOp") 42 | .Input("xarr: double") 43 | .Output("ret: double") 44 | .SetShapeFn([](shape_inference::InferenceContext* c) { 45 | c -> set_output(0, c -> MakeShape( { c -> Dim(c -> input(0), 0) } ) ); 46 | return Status::OK(); 47 | }); 48 | 49 | template <typename Device, typename T> 50 | class IntegrandOp: public OpKernel { 51 | public: 52 | explicit IntegrandOp(OpKernelConstruction* context): OpKernel(context) {} 53 | 54 | void Compute(OpKernelContext* context) override { 55 | // Grab input tensor, which is expected to be of shape (nevents, ndim) 56 | const Tensor& input_tensor = context->input(0); 57 | auto input = input_tensor.tensor<T, 2>().data(); 58 | auto input_shape = input_tensor.shape(); 59 | 60 | // Create an output tensor of shape (nevents,) 61 | Tensor* output_tensor = nullptr; 62 | TensorShape output_shape; 63 | const int N = input_shape.dim_size(0); 64 | const int dims = input_shape.dim_size(1); 65 | output_shape.AddDim(N); 66 | OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); 67 | 68 | auto output_flat = output_tensor->flat<T>().data(); 69 | 70 | // Perform the actual computation 71 | IntegrandOpFunctor<Device, T>()( 72 | context->eigen_device<Device>(), input, output_flat, N, dims 73 | ); 74 | } 75 | }; 76 | 77 | // Register the CPU version of the kernel 78 | #define REGISTER_CPU(T) \ 79 | REGISTER_KERNEL_BUILDER(Name("IntegrandOp").Device(DEVICE_CPU), IntegrandOp<CPUDevice, T>); 80 | REGISTER_CPU(double); 81 | 82 | // Register the GPU version 83 | #ifdef KERNEL_CUDA 84 | #define REGISTER_GPU(T) \ 85 | /* Declare explicit instantiations in kernel_example.cu.cc.
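(in this repository the explicit instantiation lives at the bottom of integrand.cu.cpp)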
*/ \ 86 | extern template class IntegrandOpFunctor<GPUDevice, T>; \ 87 | REGISTER_KERNEL_BUILDER(Name("IntegrandOp").Device(DEVICE_GPU), IntegrandOp<GPUDevice, T>); 88 | REGISTER_GPU(double); 89 | #endif 90 | -------------------------------------------------------------------------------- /src/vegasflow/tests/test_misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous tests that don't really fit anywhere else 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | import tensorflow as tf 8 | 9 | from vegasflow import PlainFlow, VegasFlow, VegasFlowPlus 10 | 11 | from .test_algs import check_is_one, instance_and_compile 12 | 13 | 14 | def _vector_integrand(xarr, weight=None): 15 | res = tf.square((xarr - 1.0) ** 2) 16 | return tf.exp(-res) / 0.845 17 | 18 | 19 | def _wrong_integrand(xarr): 20 | """Integrand with the wrong output shape""" 21 | return tf.reduce_sum(xarr) 22 | 23 | 24 | def _simple_integrand(xarr): 25 | """Integrand f(x) = x_1 * x_2 * ... * x_n""" 26 | return tf.reduce_prod(xarr, axis=1) 27 | 28 | 29 | def _simple_integral(xmin, xmax): 30 | """Integrated version of _simple_integrand""" 31 | xm = np.array(xmin) ** 2 / 2.0 32 | xp = np.array(xmax) ** 2 / 2.0 33 | return np.prod(xp - xm) 34 | 35 | 36 | def _wrong_vector_integrand(xarr): 37 | """Vector integrand with the wrong output shape""" 38 | return tf.transpose(xarr) 39 | 40 | 41 | @pytest.mark.parametrize("mode", range(4)) 42 | @pytest.mark.parametrize("alg", [VegasFlow, PlainFlow]) 43 | def test_working_vectorial(alg, mode): 44 | """Check that the algorithms that accept integrating vectorial functions can really do so""" 45 | inst = instance_and_compile(alg, mode=mode, integrand_function=_vector_integrand) 46 | result = inst.run_integration(2) 47 | check_is_one(result, sigmas=5) 48 | 49 | 50 | @pytest.mark.parametrize("alg", [VegasFlowPlus]) 51 | def test_notworking_vectorial(alg): 52 | """Check that the algorithms that do not accept vectorial functions fail appropriately""" 53 | with pytest.raises(NotImplementedError): 54 | _ = instance_and_compile(alg, integrand_function=_vector_integrand) 55 | 56 | 57 | def test_check_wrong_main_dimension(): 58 | """Check that an error is raised by VegasFlow 59 | if the main dimension is larger than the dimensionality of the integrand""" 60 | inst = VegasFlow(3, 100, main_dimension=5) 61 | with pytest.raises(ValueError): 62 | inst.compile(_vector_integrand) 63 | 64 | 65 | @pytest.mark.parametrize("wrong_fun", [_wrong_vector_integrand, _wrong_integrand]) 66 | def test_wrong_shape(wrong_fun): 67 | """Check that an error is raised by the compilation if the integrand has the wrong shape""" 68 | with pytest.raises(ValueError): 69 | _ = instance_and_compile(PlainFlow, integrand_function=wrong_fun) 70 | 71 | 72 | @pytest.mark.parametrize("alg", [PlainFlow, VegasFlow, VegasFlowPlus]) 73 | def test_integration_limits(alg, ncalls=int(1e4)): 74 | """Test an integration where the integration limits are modified""" 75 | dims = np.random.randint(1, 5) 76 | xmin = -1.0 + np.random.rand(dims) * 2.0 77 | xmax = 3.0 + np.random.rand(dims) 78 | inst = alg(dims, ncalls, xmin=xmin, xmax=xmax) 79 | inst.compile(_simple_integrand) 80 | result = inst.run_integration(5) 81 | expected_result = _simple_integral(xmin, xmax) 82 | check_is_one(result, target_result=expected_result) 83 | 84 | 85 | def test_integration_limits_checks(): 86 | """Test that the errors for wrong limits actually work""" 87 | # use hypothesis to check other corner cases 88 | with pytest.raises(ValueError): 89 | PlainFlow(1, 10,
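# xmin > xmax must be rejected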
xmin=[10], xmax=[1]) 90 | with pytest.raises(ValueError): 91 | PlainFlow(1, 10, xmin=[10]) 92 | with pytest.raises(ValueError): 93 | PlainFlow(1, 10, xmax=[10]) 94 | with pytest.raises(ValueError): 95 | PlainFlow(2, 10, xmin=[0], xmax=[1]) 96 | with pytest.raises(ValueError): 97 | PlainFlow(2, 10, xmin=[0, 1], xmax=[1]) 98 | -------------------------------------------------------------------------------- /src/vegasflow/configflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define some constants, header style 3 | """ 4 | 5 | import logging 6 | 7 | # Most of this can be moved to a yaml file without loss of generality 8 | import os 9 | 10 | import numpy as np 11 | 12 | # Some global parameters 13 | BINS_MAX = 50 14 | ALPHA = 1.5 15 | BETA = 0.75 # Vegas+ 16 | TECH_CUT = 1e-8 17 | 18 | # Set up the logistics of the integration 19 | # set the limits lower if hitting memory problems 20 | 21 | # Events Limit limits how many events are done in one single run of the event_loop 22 | MAX_EVENTS_LIMIT = int(1e6) 23 | # Maximum number of evaluation per hypercube for VegasFlowPlus 24 | MAX_NEVAL_HCUBE = int(1e4) 25 | 26 | # Select the list of devices to look for 27 | DEFAULT_ACTIVE_DEVICES = ["GPU"] # , 'CPU'] 28 | 29 | # Log levels 30 | LOG_DICT = {"0": logging.ERROR, "1": logging.WARNING, "2": logging.INFO, "3": logging.DEBUG} 31 | 32 | # Read the VEGASFLOW environment variables 33 | _log_level_idx = os.environ.get("VEGASFLOW_LOG_LEVEL") 34 | _data_path = os.environ.get("VEGASFLOW_DATA_PATH") 35 | _float_env = os.environ.get("VEGASFLOW_FLOAT", "64") 36 | _int_env = os.environ.get("VEGASFLOW_INT", "32") 37 | 38 | 39 | # Logging 40 | _bad_log_warning = None 41 | if _log_level_idx not in LOG_DICT: 42 | _bad_log_warning = _log_level_idx 43 | _log_level_idx = None 44 | 45 | if _log_level_idx is None: 46 | # If no log level is provided, set some defaults 47 | _log_level = LOG_DICT["2"] 48 | _tf_log_level = LOG_DICT["0"] 49 | else: 50 | _log_level = _tf_log_level = LOG_DICT[_log_level_idx] 51 | 52 | # Configure logging 53 | logger = logging.getLogger(__name__.split(".")[0]) 54 | logger.setLevel(_log_level) 55 | 56 | # Create and format the log handler 57 | _console_handler = logging.StreamHandler() 58 | _console_handler.setLevel(_log_level) 59 | _console_format = logging.Formatter("[%(levelname)s] (%(name)s) %(message)s") 60 | _console_handler.setFormatter(_console_format) 61 | logger.addHandler(_console_handler) 62 | 63 | os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "1") 64 | import tensorflow as tf 65 | 66 | tf.get_logger().setLevel(_tf_log_level) 67 | 68 | # set the precision type 69 | if _float_env == "64": 70 | DTYPE = tf.float64 71 | FMAX = tf.constant(np.finfo(np.float64).max, dtype=DTYPE) 72 | elif _float_env == "32": 73 | DTYPE = tf.float32 74 | FMAX = tf.constant(np.finfo(np.float32).max, dtype=DTYPE) 75 | else: 76 | DTYPE = tf.float64 77 | FMAX = tf.constant(np.finfo(np.float64).max, dtype=DTYPE) 78 | logger.warning(f"VEGASFLOW_FLOAT={_float_env} not understood, defaulting to 64 bits") 79 | 80 | if _int_env == "64": 81 | DTYPEINT = tf.int64 82 | elif _int_env == "32": 83 | DTYPEINT = tf.int32 84 | else: 85 | DTYPEINT = tf.int64 86 | logger.warning(f"VEGASFLOW_INT={_int_env} not understood, defaulting to 64 bits") 87 | 88 | 89 | def run_eager(flag=True): 90 | """Wrapper around `run_functions_eagerly` 91 | When used no function is compiled 92 | """ 93 | if tf.__version__ < "2.3.0": 94 | tf.config.experimental_run_functions_eagerly(flag) 95 | 
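# TF >= 2.3 renamed the experimental flag to tf.config.run_functions_eagerly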
else: 96 | tf.config.run_functions_eagerly(flag) 97 | 98 | 99 | 100 | 101 | 102 | # The wrappers below transform tensors and arrays to the correct type 103 | def int_me(i): 104 | """Cast the input to the `DTYPEINT` type""" 105 | return tf.cast(i, dtype=DTYPEINT) 106 | 107 | 108 | def float_me(i): 109 | """Cast the input to the `DTYPE` type""" 110 | return tf.cast(i, dtype=DTYPE) 111 | 112 | 113 | ione = int_me(1) 114 | izero = int_me(0) 115 | fone = float_me(1) 116 | fzero = float_me(0) 117 | -------------------------------------------------------------------------------- /examples/histogram_ex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of integration of a simple function 3 | with the filling of a histogram 4 | """ 5 | 6 | import time 7 | import numpy as np 8 | from vegasflow.configflow import DTYPE, fzero, fone 9 | import tensorflow as tf 10 | from vegasflow.plain import PlainFlow 11 | from vegasflow.utils import consume_array_into_indices 12 | 13 | # MC integration setup 14 | dim = 3 15 | ncalls = np.int32(1e5) 16 | n_iter = 5 17 | hst_dim = 2 18 | HISTO_BINS = 2 19 | 20 | 21 | def generate_integrand(cummulator_tensor): 22 | """ 23 | This function will generate an integrand function 24 | which will already hold a reference to the tensor to accumulate 25 | """ 26 | 27 | @tf.function 28 | def histogram_collector(results, variables): 29 | """This function will receive a tensor (result) 30 | and the variables corresponding to those integrand results 31 | In the example integrand below, these correspond to 32 | `final_result` and `histogram_values` respectively. 33 | `cummulator_tensor` (the `current_histograms` variable below) holds the 34 | current value of the histogram, which will be overwritten""" 35 | # Fill a histogram with HISTO_BINS (2) bins, (0 to 0.5, 0.5 to 1) 36 | # First generate the indices with TF 37 | indices = tf.histogram_fixed_width_bins(variables, [fzero, fone], nbins=HISTO_BINS) 38 | t_indices = tf.transpose(indices) 39 | # Then consume the results with the utility we provide 40 | partial_hist = consume_array_into_indices(results, t_indices, HISTO_BINS) 41 | # Then update the accumulated histogram 42 | new_histograms = partial_hist + cummulator_tensor 43 | cummulator_tensor.assign(new_histograms) 44 | 45 | def integrand_example(xarr, weight=fone): 46 | """Example of function which saves histograms""" 47 | n_dim = xarr.shape[-1] 48 | a = tf.constant(0.1, dtype=DTYPE) 49 | n100 = tf.cast(100 * dim, dtype=DTYPE) 50 | pref = tf.pow(1.0 / a / np.sqrt(np.pi), n_dim) 51 | coef = tf.reduce_sum(tf.range(n100 + 1)) 52 | coef += tf.reduce_sum(tf.square((xarr - 1.0 / 2.0) / a), axis=1) 53 | coef -= (n100 + 1) * n100 / 2.0 54 | final_result = pref * tf.exp(-coef) 55 | # Collect the value for the histogram collector in a tuple 56 | histogram_values = (xarr[:, hst_dim],) 57 | histogram_collector(final_result * weight, histogram_values) 58 | return final_result 59 | 60 | return integrand_example 61 | 62 | 63 | if __name__ == "__main__": 64 | """Testing histogram generation""" 65 | if dim <= hst_dim: 66 | raise ValueError( 67 | f"The number of dimensions has to be greater than {hst_dim} for this example" 68 | ) 69 | print(f"Plain MC, ncalls={ncalls}:") 70 | start = time.time() 71 | # First we create the tensor in which to accumulate the histogram 72 | # This part is completely free 73 | current_histograms = tf.Variable(tf.zeros(HISTO_BINS, dtype=DTYPE)) 74 | integrand_example =
75 |     mc_instance = PlainFlow(dim, ncalls)
76 |     mc_instance.compile(integrand_example, compilable=True)
77 |     # Pass the histogram variables to the integration
78 |     # so they can be filled only once per iteration
79 |     # This needs to be a tuple/list of tensor variables
80 |     # as they will be emptied at the end of each iteration
81 |     histogram_tuple = (current_histograms,)
82 |     results = mc_instance.run_integration(n_iter, histograms=histogram_tuple)
83 |     r = results[0]
84 |     s = results[1]
85 |     # At the end of the integration the variable `current_histograms` is filled
86 |     # with the weighted accumulation of the histograms per iteration
87 |     # while the result of the histogram for each iteration can be accessed through
88 |     # the history of the integrator
89 |     results_per_iteration = mc_instance.history
90 |     end = time.time()
91 |     print(f"Plain took: time (s): {end-start}")
92 |     print(f"Final result: {r} +/- {s}")
93 |     print(f"Final histogram: {current_histograms}")
94 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![DOI](https://zenodo.org/badge/226363558.svg)](https://zenodo.org/badge/latestdoi/226363558)
2 | [![cpc](https://img.shields.io/badge/j.%20Computer%20Physics%20Communication-2020%2F107376-blue)](https://inspirehep.net/literature/1783000)
3 | 
4 | [![Tests](https://github.com/N3PDF/vegasflow/workflows/pytest/badge.svg)](https://github.com/N3PDF/vegasflow/actions?query=workflow%3A%22pytest%22)
5 | [![Documentation Status](https://readthedocs.org/projects/vegasflow/badge/?version=latest)](https://vegasflow.readthedocs.io/en/latest/?badge=latest)
6 | 
7 | 
8 | # VegasFlow
9 | 
10 | VegasFlow is a Monte Carlo integration library written in Python and based on the [TensorFlow](https://www.tensorflow.org/) framework. It is developed with a focus on speed and efficiency, enabling researchers to perform very expensive calculations as quickly and easily as possible.
11 | 
12 | Some of the key features of VegasFlow are:
13 | - Efficient integration of high-dimensional functions on a single CPU (multi-threading), multiple CPUs, one or many GPUs, or clusters.
14 | 
15 | - Compatible with Python, C, C++ or Fortran.
16 | 
17 | - Implementation of different Monte Carlo algorithms.
18 | 
19 | ## Documentation
20 | The documentation for VegasFlow is available at [vegasflow.readthedocs.io](https://vegasflow.readthedocs.io/en/latest).
21 | 
22 | ## Installation
23 | [![Anaconda-Server Badge](https://anaconda.org/conda-forge/vegasflow/badges/installer/conda.svg)](https://anaconda.org/conda-forge/vegasflow)
24 | [![AUR](https://img.shields.io/badge/aur-vegasflow-blue)](https://aur.archlinux.org/packages/python-vegasflow/)
25 | 
26 | The package can be installed with pip:
27 | ```bash
28 | python3 -m pip install vegasflow
29 | ```
30 | 
31 | as well as `conda`, from the `conda-forge` channel:
32 | ```bash
33 | conda install vegasflow -c conda-forge
34 | ```
35 | 
36 | If you prefer a manual installation you can clone the repository and run:
37 | ```bash
38 | git clone https://github.com/N3PDF/vegasflow.git
39 | cd vegasflow
40 | python setup.py install
41 | ```
42 | or if you are planning to extend or develop the code just use:
43 | ```bash
44 | python setup.py develop
45 | ```
46 | 
47 | ## Examples
48 | A number of examples (basic integration, cuda, external tools integration) can be found in the [examples folder](https://github.com/N3PDF/vegasflow/tree/master/examples). A more detailed description can be found in the [documentation](https://vegasflow.readthedocs.io/en/latest/examples.html).
49 | 
50 | Below you can find a minimal workflow for using the examples provided with VegasFlow:
51 | 
52 | Firstly, one can install any extra dependencies required by the examples using:
53 | 
54 | ```bash
55 | pip install .[examples]
56 | ```
57 | 
58 | ### Minimal Working Example
59 | ```python
60 | from vegasflow import vegas_wrapper
61 | import tensorflow as tf
62 | 
63 | def integrand(x, **kwargs):
64 |     """ Function:
65 |     x_{1} * x_{2} ... * x_{n}
66 |     x: array of dimension (events, n)
67 |     """
68 |     return tf.reduce_prod(x, axis=1)
69 | 
70 | dimensions = 8
71 | iterations = 5
72 | events_per_iteration = int(1e5)
73 | vegas_wrapper(integrand, dimensions, iterations, events_per_iteration, compilable=True)
74 | ```
75 | 
76 | Please feel free to [open an issue](https://github.com/N3PDF/vegasflow/issues/new) if you would like
77 | a specific example or find any problem at all with the code or the documentation.
78 | 
79 | ## Citation policy
80 | 
81 | If you use the package please cite the following paper and zenodo references:
82 | - [https://doi.org/10.5281/zenodo.3691926](https://doi.org/10.5281/zenodo.3691926)
83 | - [https://arxiv.org/abs/2002.12921](https://arxiv.org/abs/2002.12921)
84 | 
85 | ```bibtex
86 | @article{Carrazza:2020rdn,
87 |     author = "Carrazza, Stefano and Cruz-Martinez, Juan M.",
88 |     title = "{VegasFlow: accelerating Monte Carlo simulation across multiple hardware platforms}",
89 |     eprint = "2002.12921",
90 |     archivePrefix = "arXiv",
91 |     primaryClass = "physics.comp-ph",
92 |     reportNumber = "TIF-UNIMI-2020-8",
93 |     doi = "10.1016/j.cpc.2020.107376",
94 |     journal = "Comput. Phys. Commun.",
Commun.", 95 | volume = "254", 96 | pages = "107376", 97 | year = "2020" 98 | } 99 | 100 | 101 | @software{vegasflow_package, 102 | author = {Juan Cruz-Martinez and 103 | Stefano Carrazza}, 104 | title = {N3PDF/vegasflow: vegasflow v1.0}, 105 | month = feb, 106 | year = 2020, 107 | publisher = {Zenodo}, 108 | version = {v1.0}, 109 | doi = {10.5281/zenodo.3691926}, 110 | url = {https://doi.org/10.5281/zenodo.3691926} 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /examples/multiple_integrals.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Example of using VegasFlow to compute an arbitrary number of integrals 4 | Note that it should only be used for similarly-behaved integrands. 5 | 6 | The example integrands are variations of the Genz functions definged in 7 | Novak et al, 1999 (J. of Comp and Applied Maths, 112 (1999) 215-228 and implemented from http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.123.8452&rep=rep1&type=pdf 8 | 9 | Note, when possible the ``multidimensional_integral.py`` features should be utilized 10 | as then the error computation is automatically taken into account by the algorithms 11 | in VegasFlow instead of having to implement it by hand. 12 | """ 13 | import vegasflow 14 | from vegasflow.configflow import DTYPE 15 | import tensorflow as tf 16 | import numpy as np 17 | 18 | NDIM = 4 19 | NPARAMS = 10 20 | a1 = 0.1 21 | a2 = 0.5 22 | 23 | RESULT_ARRAY = [] 24 | ERRORS_ARRAY = [] 25 | 26 | w = np.random.rand(NDIM) 27 | 28 | 29 | def generate_genz(a): 30 | """Generates the genz_oscillatory function for a given parameter a""" 31 | 32 | def genz_oscillatory(xrandr, **kwargs): 33 | """ 34 | The input xrandr has shape (nevents, ndim) 35 | the output of the function is (nevents,) 36 | """ 37 | res = tf.einsum("ij, j->i", xrandr, w) + 2.0 * np.pi * a 38 | return (tf.cos(res) + 1.0) / 2.0 39 | 40 | return genz_oscillatory 41 | 42 | 43 | def generate_multiple_genz(full_a): 44 | """Compute the genz oscillatory function for multiple values of a at once""" 45 | # Make sure that full_a is a tensor 46 | # _and_ add a dummy event axis 47 | a = tf.constant(np.reshape(full_a, (1, -1)), dtype=DTYPE) 48 | 49 | def genz_oscillatory_multiple(xrandr, **kwargs): 50 | """ 51 | The input xrandr has shape (nevents, ndim) 52 | the output of the function is (nevents, nparams) 53 | """ 54 | raw_res = tf.einsum("ij, j->i", xrandr, w) 55 | # Add a dummy parameter-axis 56 | res = tf.reshape(raw_res, (-1, 1)) + 2.0 * np.pi * a 57 | return (tf.cos(res) + 1.0) / 2.0 58 | 59 | return genz_oscillatory_multiple 60 | 61 | 62 | test1 = generate_genz(a1) 63 | test2 = generate_multiple_genz([a1, a2]) 64 | 65 | # Generate 5 events 66 | rval = np.random.rand(5, NDIM) 67 | # Check the shape of the two functions 68 | r1 = test1(rval) 69 | r2 = test2(rval) 70 | print(f"Shape of single-parameter call: {r1.numpy().shape}") 71 | print(f"Shape of multiparameter-parameter call: {r2.numpy().shape}") 72 | 73 | # Now we can generate an integrand for this 74 | generate_single_integrand = generate_genz 75 | 76 | # And run it with VegasFlow for a given value of a=0.1 77 | result_single_1 = vegasflow.vegas_wrapper( 78 | generate_single_integrand(a1), NDIM, n_iter=5, total_n_events=int(1e4) 79 | ) 80 | result_single_2 = vegasflow.vegas_wrapper( 81 | generate_single_integrand(a2), NDIM, n_iter=5, total_n_events=int(1e4) 82 | ) 83 | 84 | # Now, we cannot do the same for the multiple parameter 
85 | # So we need to save the results for each integrand somehow, let's use a "digest function"
86 | 
87 | 
88 | def digest_function(all_results, all_wgts):
89 |     # Receive the results from the computing device
90 |     # and the associated event weights
91 |     results = all_results.numpy()
92 |     wgts = all_wgts.numpy().reshape(-1)
93 |     # Now get the weighted average of the result (the MC estimate) per parameter
94 |     final_result = np.einsum("ij, i->j", results, wgts)
95 |     # And compute the errors
96 |     sq_results = np.einsum("ij, i->j", results**2, wgts**2) * len(wgts)
97 |     errors = np.abs(final_result**2 - sq_results) / (len(wgts) - 1)
98 |     print(f"{final_result} +- {np.sqrt(errors)}")
99 |     RESULT_ARRAY.append(final_result)
100 |     ERRORS_ARRAY.append(errors)
101 |     return 0.0
102 | 
103 | 
104 | def generate_multiple_integrand(full_a):
105 | 
106 |     genz_function = generate_multiple_genz(full_a)
107 | 
108 |     # Clean the previous results (if any)
109 |     while RESULT_ARRAY:
110 |         RESULT_ARRAY.pop()
111 |         ERRORS_ARRAY.pop()
112 | 
113 |     def integrand(xrandr, weight=1.0):
114 |         res = genz_function(xrandr)
115 | 
116 |         # Store the individual result
117 |         tf.py_function(digest_function, [res, weight], Tout=DTYPE)
118 | 
119 |         # Return the result for the first parameter
120 |         return res[:, 0]
121 | 
122 |     return integrand
123 | 
124 | 
125 | vegasflow.vegas_wrapper(
126 |     generate_multiple_integrand([a1, a2]), NDIM, n_iter=5, total_n_events=int(1e4)
127 | )
128 | 
129 | res_mul_1, res_mul_2 = np.average(RESULT_ARRAY, axis=0)
130 | 
131 | err_mul_1, err_mul_2 = np.sqrt(
132 |     1 / np.sum(1 / np.array(ERRORS_ARRAY) ** 2, axis=0) / len(ERRORS_ARRAY)
133 | )
134 | 
135 | print(
136 |     f"""
137 | The single calculation obtained:
138 | {result_single_1[0]:.4} +- {result_single_1[1]:.4} for a={a1} and
139 | {result_single_2[0]:.4} +- {result_single_2[1]:.4} for a={a2}
140 | 
141 | while the multiple integration found:
142 | {res_mul_1:.4} +- {err_mul_1:.4} for a={a1} and
143 | {res_mul_2:.4} +- {err_mul_2:.4} for a={a2}
144 | """
145 | )
146 | 
--------------------------------------------------------------------------------
/src/vegasflow/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains tensorflow-compiled utilities
3 | """
4 | 
5 | import tensorflow as tf
6 | 
7 | from vegasflow.configflow import DTYPE, DTYPEINT, float_me, fzero, int_me
8 | 
9 | 
10 | @tf.function(
11 |     input_signature=[
12 |         tf.TensorSpec(shape=[None], dtype=DTYPE),
13 |         tf.TensorSpec(shape=[None, None], dtype=DTYPEINT),
14 |         tf.TensorSpec(shape=[], dtype=DTYPEINT),
15 |     ]
16 | )
17 | def consume_array_into_indices(input_arr, indices, result_size):
18 |     """
19 |     Accumulate the input tensor `input_arr` into an output tensor of
20 |     size `result_size`. The accumulation occurs according to the array
21 |     of `indices`.
22 | 
23 |     For instance, `input_arr` = [a,b,c,d] and the column vector `indices` = [[0,1,0,0]].T
24 |     (with `result_size` = 2) will result in a final_result: (a+c+d, b)
25 | 
26 |     Parameters
27 |     ----------
28 |         input_arr
29 |             Array of results to be consumed
30 |         indices
31 |             Indices of the bins in which to accumulate the input array
32 |         result_size
33 |             size of the output array
34 | 
35 |     Returns
36 |     -------
37 |         `final_result`
38 |             Array of size `result_size`
39 |     """
40 |     all_bins = tf.range(result_size, dtype=DTYPEINT)
41 |     eq = tf.transpose(tf.equal(indices, all_bins))
42 |     res_tmp = tf.where(eq, input_arr, fzero)
43 |     final_result = tf.reduce_sum(res_tmp, axis=1)
44 |     return final_result
45 | 
46 | 
47 | def py_consume_array_into_indices(input_arr, indices, result_size):
48 |     """
49 |     Python interface wrapper for ``consume_array_into_indices``.
50 |     It casts the possible python-object input into the correct tensorflow types.
51 |     """
52 |     return consume_array_into_indices(float_me(input_arr), int_me(indices), int_me(result_size))
53 | 
54 | 
55 | def generate_condition_function(n_mask, condition="and"):
56 |     r"""Generates a function that takes a number of masks
57 |     and returns a combination of all n_mask masks for the given condition.
58 | 
59 |     It is possible to pass a list of allowed conditions; in that case
60 |     the length of the list should be n_mask - 1 and the conditions will be applied sequentially.
61 |     Note that for 2 masks you can directly use ``&`` and ``|``.
62 | 
63 |     Parameters
64 |     ----------
65 |         n_mask: int
66 |             Number of masks the function should accept
67 |         condition: str (default="and")
68 |             Condition to apply to all masks. Accepted values are: ``and``, ``or``
69 | 
70 |     Returns
71 |     -------
72 |         condition_to_idx: function
73 |             ``function(*masks)`` -> full_mask, true indices
74 | 
75 |     Example
76 |     -------
77 |     >>> from vegasflow.utils import generate_condition_function
78 |     >>> import tensorflow as tf
79 |     >>> f_cond = generate_condition_function(2, condition='or')
80 |     >>> t_1 = tf.constant([True, False, True])
81 |     >>> t_2 = tf.constant([False, False, True])
82 |     >>> full_mask, indices = f_cond(t_1, t_2)
83 |     >>> print(f"{full_mask=}\n{indices=}")
84 |     full_mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True, False,  True])>
85 |     indices=<tf.Tensor: shape=(2, 1), dtype=int32, numpy=array([[0], [2]], dtype=int32)>
86 | 
87 |     >>> f_cond = generate_condition_function(3, condition=['or', 'and'])
88 |     >>> t_1 = tf.constant([True, False, True])
89 |     >>> t_2 = tf.constant([False, False, True])
90 |     >>> t_3 = tf.constant([True, False, False])
91 |     >>> full_mask, indices = f_cond(t_1, t_2, t_3)
92 |     >>> print(f"{full_mask=}\n{indices=}")
93 |     full_mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True, False, False])>
94 |     indices=<tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[0]], dtype=int32)>
95 | 
96 |     """
97 |     allowed_conditions = {"and": tf.math.logical_and, "or": tf.math.logical_or}
98 |     allo = list(allowed_conditions.keys())
99 | 
100 |     # Check that the user is not asking for anything weird
101 |     if n_mask < 2:
102 |         raise ValueError("At least two masks are needed to generate a wrapper")
103 | 
104 |     if isinstance(condition, str):
105 |         if condition not in allowed_conditions:
106 |             raise ValueError(f"Wrong condition {condition}, allowed values are {allo}")
107 |         is_list = False
108 |     else:
109 |         if len(condition) != n_mask - 1:
110 |             raise ValueError(f"Wrong number of conditions for {n_mask} masks: {len(condition)}")
111 |         for cond in condition:
112 |             if cond not in allowed_conditions:
113 |                 raise ValueError(f"Wrong condition {cond}, allowed values are {allo}")
114 |         is_list = True
115 | 
116 |     def py_condition(*masks):
117 |         """Receives a list of conditions and returns a result mask
118 |         and the list of indices in which the result mask is True
119 | 
120 |         Returns
121 |         -------
122 |             res: tf.bool
123 |                 Mask that combines all masks
124 |             indices: tf.int
125 |                 Indices in which ``res`` is True
126 |         """
127 |         if is_list:
128 |             res = masks[0]
129 |             for i, cond in enumerate(condition):
130 |                 res = allowed_conditions[cond](res, masks[i + 1])
131 |         elif condition == "and":
132 |             res = tf.math.reduce_all(masks, axis=0)
133 |         elif condition == "or":
134 |             res = tf.math.reduce_any(masks, axis=0)
135 |         indices = int_me(tf.where(res))
136 |         return res, indices
137 | 
138 |     signature = n_mask * [tf.TensorSpec(shape=[None], dtype=tf.bool)]
139 | 
140 |     condition_to_idx = tf.function(py_condition, input_signature=signature)
141 |     return condition_to_idx
142 | 
--------------------------------------------------------------------------------
/doc/source/intalg.rst:
--------------------------------------------------------------------------------
1 | .. _intalg-label:
2 | 
3 | ======================
4 | Integration algorithms
5 | ======================
6 | 
7 | This page lists the integration algorithms currently implemented.
8 | 
9 | .. contents::
10 |     :local:
11 |     :depth: 1
12 | 
13 | .. _vegas-label:
14 | 
15 | VegasFlow
16 | =========
17 | 
18 | Overview
19 | ^^^^^^^^
20 | 
21 | This implementation of the Vegas algorithm closely follows the description of the importance sampling in the original `Vegas `_ paper.
22 | 
23 | An integration with the Vegas algorithm can be performed using the ``VegasFlow`` class.
24 | Initializing the integrator requires providing the number of dimensions with which to initialize the grid and a target number of calls per iteration.
25 | 
26 | .. code-block:: python
27 | 
28 |     from vegasflow import VegasFlow
29 |     dims = 4
30 |     n_calls = int(1e6)
31 |     vegas_instance = VegasFlow(dims, n_calls)
32 | 
33 | Once that is generated it is possible to register an integrand by calling the ``compile`` method.
34 | 
35 | .. code-block:: python
36 | 
37 |     def example_integrand(x, **kwargs):
38 |         y = 0.0
39 |         for d in range(dims):
40 |             y += x[:,d]
41 |         return y
42 | 
43 |     vegas_instance.compile(example_integrand)
44 | 
45 | Once this process has been performed we can start computing the result by simply calling the ``run_integration`` method, to which we need to provide the number of iterations.
46 | After each iteration the grid will be refined, producing more points (and hence reducing the error) in the regions where the integrand is larger.
47 | 
48 | .. code-block:: python
49 | 
50 |     n_iter = 3
51 |     result = vegas_instance.run_integration(n_iter)
52 | 
53 | The output variable, in this example named ``result``, is a tuple where the first element is the result of the integration and the second element is the error of the integration.
54 | 
55 | Integration Wrapper
56 | ^^^^^^^^^^^^^^^^^^^
57 | 
58 | Although manually instantiating the integrator allows for finer-grained control
59 | of the integration, it is also possible to use wrappers which automatically do most of the work
60 | behind the scenes.
61 | 
62 | .. code-block:: python
63 | 
64 |     from vegasflow import vegas_wrapper
65 |     n_iter = 5
66 |     result = vegas_wrapper(example_integrand, dims, n_iter, n_calls)
67 | 
68 | Grid freezing
69 | ^^^^^^^^^^^^^
70 | 
71 | It is often useful to freeze the grid in order to run the integration several times with the same grid; to do that we provide the ``freeze_grid`` method.
Note that freezing the grid forces a recompilation of the integrand, which means the first iteration after freezing can potentially be slow. After that it will run even faster than before, as the part of the graph dedicated to adjusting the grid is dropped.
72 | 
73 | .. code-block:: python
74 | 
75 |     vegas_instance.freeze_grid()
76 | 
77 | 
78 | Saving and loading a grid
79 | ^^^^^^^^^^^^^^^^^^^^^^^^^
80 | 
81 | On a related note, it is possible to save and load the grid to and from json files; to do that we can use the ``save_grid`` and ``load_grid`` methods at any point in the calculation.
82 | Note, however, that loading a new grid will destroy the current one.
83 | 
84 | .. code-block:: python
85 | 
86 |     json_file = "my_grid.json"
87 |     vegas_instance.save_grid(json_file)
88 |     vegas_instance.load_grid(json_file)
89 | 
90 | .. autoclass:: vegasflow.vflow.VegasFlow
91 |     :noindex:
92 |     :show-inheritance:
93 |     :members: freeze_grid, unfreeze_grid, save_grid, load_grid
94 | 
95 | 
96 | VegasFlowPlus
97 | =============
98 | 
99 | Overview
100 | ^^^^^^^^
101 | While ``VegasFlow`` is limited to the importance sampling algorithm,
102 | ``VegasFlowPlus`` includes the latest version of importance sampling plus adaptive stratified sampling
103 | from Lepage's `latest paper `_.
104 | 
105 | The usage and interfaces exposed by ``VegasFlowPlus`` are equivalent to those
106 | of ``VegasFlow``:
107 | 
108 | 
109 | .. code-block:: python
110 | 
111 |     from vegasflow import VegasFlowPlus
112 |     dims = 4
113 |     n_calls = int(1e6)
114 |     vegas_instance = VegasFlowPlus(dims, n_calls)
115 | 
116 |     def example_integrand(x, **kwargs):
117 |         y = 0.0
118 |         for d in range(dims):
119 |             y += x[:,d]
120 |         return y
121 | 
122 |     vegas_instance.compile(example_integrand)
123 | 
124 |     n_iter = 3
125 |     result = vegas_instance.run_integration(n_iter)
126 | 
127 | 
128 | As can be seen, the only change has been to substitute the ``VegasFlow`` class
129 | with ``VegasFlowPlus``.
130 | 
131 | .. note:: ``VegasFlowPlus`` does not support multi-device running, as it cannot break the integration into several pieces, an issue tracked at `#78 `_.
132 | 
133 | Integration Wrapper
134 | ^^^^^^^^^^^^^^^^^^^
135 | 
136 | A wrapper is also provided for simplicity:
137 | 
138 | .. code-block:: python
139 | 
140 |     from vegasflow import vegasflowplus_wrapper
141 |     n_iter = 5
142 |     result = vegasflowplus_wrapper(example_integrand, dims, n_iter, n_calls)
143 | 
144 | 
145 | PlainFlow
146 | =========
147 | 
148 | Overview
149 | ^^^^^^^^
150 | 
151 | We provide a very rudimentary Monte Carlo integrator which we name PlainFlow.
152 | It provides an easy example of how to implement a new integration algorithm.
153 | 
154 | The usage pattern is similar to :ref:`vegas-label`.
155 | 
156 | .. code-block:: python
157 | 
158 |     from vegasflow import PlainFlow
159 |     plain_instance = PlainFlow(dims, n_calls)
160 |     plain_instance.compile(example_integrand)
161 |     plain_instance.run_integration(n_iter)
162 | 
163 | Integration Wrapper
164 | ^^^^^^^^^^^^^^^^^^^
165 | 
166 | An integration wrapper is also provided as ``vegasflow.plain_wrapper``.
167 | 
168 | .. code-block:: python
169 | 
170 |     from vegasflow import plain_wrapper
171 |     result = plain_wrapper(example_integrand, dims, n_iter, n_calls)
172 | 
173 | 
174 | .. autoclass:: vegasflow.plain.PlainFlow
175 |     :noindex:
176 |     :show-inheritance:
177 | 
--------------------------------------------------------------------------------
/examples/example_pineappl.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | import random as rn
4 | from multiprocessing.pool import ThreadPool as Pool
5 | from functools import partial
6 | from vegasflow.configflow import DTYPE, MAX_EVENTS_LIMIT, run_eager
7 | 
8 | run_eager(True)
9 | from pdfflow.pflow import mkPDF
10 | import pineappl
11 | from vegasflow.utils import generate_condition_function
12 | from vegasflow.vflow import VegasFlow
13 | import time
14 | import numpy as np
15 | import tensorflow as tf
16 | 
17 | 
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--ncalls', default=np.int64(10000000), type=np.int64, help='Number of calls.')
20 | parser.add_argument('--pineappl', action="store_true", help='Enable pineappl grid filling.')
21 | args = parser.parse_args()
22 | 
23 | 
24 | # Seed everything seedable
25 | seed = 7
26 | np.random.seed(seed)
27 | rn.seed(seed + 1)
28 | tf.random.set_seed(seed + 2)
29 | 
30 | 
31 | # configuration
32 | dim = 3
33 | ncalls = args.ncalls
34 | n_iter = 3
35 | events_limit = MAX_EVENTS_LIMIT
36 | 
37 | # Constants in GeV^2 pbarn
38 | hbarc2 = tf.constant(389379372.1, dtype=DTYPE)
39 | alpha0 = tf.constant(1.0 / 137.03599911, dtype=DTYPE)
40 | cuts = generate_condition_function(6, condition='and')
41 | 
42 | 
43 | @tf.function
44 | def int_photo(s, t, u):
45 |     return alpha0 * alpha0 / 2.0 / s * (t / u + u / t)
46 | 
47 | 
48 | @tf.function
49 | def hadronic_pspgen(xarr, mmin, mmax):
50 |     smin = mmin * mmin
51 |     smax = mmax * mmax
52 | 
53 |     r1 = xarr[:, 0]
54 |     r2 = xarr[:, 1]
55 |     r3 = xarr[:, 2]
56 | 
57 |     tau0 = smin / smax
58 |     tau = tf.pow(tau0, r1)
59 |     y = tf.pow(tau, 1.0 - r2)
60 |     x1 = y
61 |     x2 = tau / y
62 |     s = tau * smax
63 | 
64 |     jacobian = np.log(tau0) * np.log(tau0) * tau * r1
65 | 
66 |     # theta integration (in the CMS)
67 |     cos_theta = 2.0 * r3 - 1.0
68 |     jacobian *= 2.0
69 | 
70 |     t = -0.5 * s * (1.0 - cos_theta)
71 |     u = -0.5 * s * (1.0 + cos_theta)
72 | 
73 |     # phi integration
74 |     jacobian *= 2.0 * np.arccos(-1.0)
75 | 
76 |     return s, t, u, x1, x2, jacobian
77 | 
78 | 
79 | def fill(grid, x1, x2, q2, yll, weight):
80 |     zeros = np.zeros(len(weight), dtype=np.uintp)
81 |     grid.fill_array(x1, x2, q2, zeros, yll, zeros, weight)
82 | 
83 | 
84 | def fill_grid(xarr, weight=1, **kwargs):
85 |     s, t, u, x1, x2, jacobian = hadronic_pspgen(xarr, 10.0, 7000.0)
86 | 
87 |     ptl = tf.sqrt((t * u / s))
88 |     mll = tf.sqrt(s)
89 |     yll = 0.5 * tf.math.log(x1 / x2)
90 |     ylp = tf.abs(yll + tf.math.acosh(0.5 * mll / ptl))
91 |     ylm = tf.abs(yll - tf.math.acosh(0.5 * mll / ptl))
92 | 
93 |     jacobian *= hbarc2
94 | 
95 |     # apply cuts
96 |     t_1 = ptl >= 14.0
97 |     t_2 = tf.abs(yll) <= 2.4
98 |     t_3 = ylp <= 2.4
99 |     t_4 = ylm <= 2.4
100 |     t_5 = mll >= 60.0
101 |     t_6 = mll <= 120.0
102 |     full_mask, indices = cuts(t_1, t_2, t_3, t_4, t_5, t_6)
103 | 
104 |     wgt = tf.boolean_mask(jacobian * int_photo(s, u, t), full_mask, axis=0)
105 |     x1 = tf.boolean_mask(x1, full_mask, axis=0)
106 |     x2 = tf.boolean_mask(x2, full_mask, axis=0)
107 |     yll = tf.boolean_mask(yll, full_mask, axis=0)
108 |     vweight = wgt * tf.boolean_mask(weight, full_mask, axis=0)
109 | 
110 |     # This is a very convoluted way of doing an operation on the data during a computation;
111 |     # another solution is to send the pool with `py_function` like in the
`multiple_integrals.py` example 112 | if kwargs.get('fill_pineappl'): 113 | q2 = 90.0 * 90.0 * tf.ones(weight.shape, dtype=tf.float64) 114 | kwargs.get('pool').apply_async(fill, [kwargs.get('grid'), x1.numpy(), x2.numpy(), 115 | q2.numpy(), tf.abs(yll).numpy(), vweight.numpy()]) 116 | 117 | return tf.scatter_nd(indices, wgt, shape=xarr.shape[0:1]) 118 | 119 | 120 | if __name__ == "__main__": 121 | start = time.time() 122 | 123 | grid = None 124 | pool = Pool(processes=1) 125 | 126 | if args.pineappl: 127 | lumi = [(22, 22, 1.0)] 128 | pine_lumi = [pineappl.lumi.LumiEntry(lumi)] 129 | pine_orders = [pineappl.grid.Order(0, 2, 0, 0)] 130 | pine_params = pineappl.subgrid.SubgridParams() 131 | bins = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 132 | 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4] 133 | # Initialize the grid 134 | # only LO $\alpha_\mathrm{s}^0 \alpha^2 \log^0(\xi_\mathrm{R}) \log^0(\xi_\mathrm{F})$ 135 | grid = pineappl.grid.Grid.create(pine_lumi, pine_orders, bins, pine_params) 136 | else: 137 | print("pineappl not active, use --pineappl") 138 | 139 | # fill the grid with phase-space points 140 | print('Generating events, please wait...') 141 | 142 | print(f"VEGAS MC, ncalls={ncalls}:") 143 | mc_instance = VegasFlow(dim, ncalls, events_limit=events_limit) 144 | mc_instance.compile(partial(fill_grid, fill_pineappl=False, grid=grid, pool=pool)) 145 | mc_instance.run_integration(n_iter) 146 | mc_instance.compile(partial(fill_grid, fill_pineappl=args.pineappl, grid=grid, pool=pool)) 147 | mc_instance.freeze_grid() 148 | mc_instance.run_integration(1) 149 | end = time.time() 150 | print(f"Vegas took: time (s): {end-start}") 151 | 152 | # wait until pineappl has filled the grids properly 153 | pool.close() 154 | pool.join() 155 | end = time.time() 156 | print(f"Pool took: time (s): {end-start}") 157 | 158 | if args.pineappl: 159 | # write the grid to disk 160 | filename = 'DY-LO-AA.pineappl' 161 | print(f'Writing PineAPPL grid to disk: {filename}') 162 | grid.write(filename) 163 | 164 | # check convolution 165 | # load pdf for testing 166 | pdf = mkPDF('NNPDF31_nlo_as_0118_luxqed/0') 167 | 168 | def xfx(id, x, q2, p): 169 | return pdf.py_xfxQ2([id], [x], [q2]) 170 | 171 | def alphas(q2, p): 172 | return pdf.py_alphasQ2([q2]) 173 | 174 | # perform convolution 175 | dxsec = grid.convolute_with_one(2212, xfx, alphas) 176 | for i in range(len(dxsec)): 177 | print(f'{bins[i]:.1f} {bins[i + 1]:.1f} {dxsec[i]:.3e}') 178 | 179 | end = time.time() 180 | print(f"Total time (s): {end-start}") 181 | -------------------------------------------------------------------------------- /doc/source/index.rst: -------------------------------------------------------------------------------- 1 | .. title:: 2 | vegasflow's documentation! 3 | 4 | 5 | ================================================================================= 6 | VegasFlow: accelerating Monte Carlo simulation across multiple hardware platforms 7 | ================================================================================= 8 | 9 | .. image:: https://img.shields.io/badge/j.%20Computer%20Physics%20Communication-2020%2F107376-blue 10 | :target: https://doi.org/10.1016/j.cpc.2020.107376 11 | 12 | .. image:: https://img.shields.io/badge/physics.comp--ph-arXiv%3A2002.12921-B31B1B 13 | :target: https://arxiv.org/abs/2002.12921 14 | 15 | .. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.3691926.svg 16 | :target: https://doi.org/10.5281/zenodo.3691926 17 | 18 | .. 
contents::
19 |     :local:
20 |     :depth: 1
21 | 
22 | VegasFlow is a `Monte Carlo integration `_ library written in Python and based on the `TensorFlow `_ framework.
23 | It is developed with a focus on speed and efficiency, enabling researchers to perform very expensive calculations as quickly and easily as possible.
24 | 
25 | Some of the key features of VegasFlow are:
26 | 
27 | * Efficient integration of high-dimensional functions on a single CPU (multi-threading), multiple CPUs, one or many GPUs, or clusters.
28 | * Compatible with Python, C, C++ or Fortran.
29 | * Implementation of different Monte Carlo algorithms.
30 | 
31 | How to obtain the code
32 | ======================
33 | 
34 | Open Source
35 | -----------
36 | The ``vegasflow`` package is open source and available at https://github.com/N3PDF/vegasflow
37 | 
38 | Installation
39 | ------------
40 | The package can be installed with pip:
41 | 
42 | .. code-block:: bash
43 | 
44 |     python3 -m pip install vegasflow
45 | 
46 | If you prefer a manual installation just use:
47 | 
48 | .. code-block:: bash
49 | 
50 |     git clone https://github.com/N3PDF/vegasflow
51 |     cd vegasflow
52 |     python3 setup.py install
53 | 
54 | or if you are planning to extend or develop the code just use:
55 | 
56 | .. code-block:: bash
57 | 
58 |     python3 setup.py develop
59 | 
60 | It is also possible to install the package from repositories such as `conda-forge `_ or the `Arch User Repository `_
61 | 
62 | .. code-block:: bash
63 | 
64 |     conda install vegasflow -c conda-forge
65 |     yay -S python-vegasflow
66 | 
67 | Motivation
68 | ==========
69 | 
70 | VegasFlow is developed within the Particle Physics group of the University of Milan.
71 | Theoretical calculations in particle physics are incredibly time-consuming operations, sometimes taking months in big clusters all around the world.
72 | 
73 | These expensive calculations are driven by the high-dimensional phase space that needs to be integrated, but also by a lack of expertise in new techniques for high performance computation.
74 | Indeed, while at the theoretical level these are some of the most complicated calculations performed by mankind, at the technical level most of these calculations are performed using very dated code and methodologies that are unable to make use of the available resources.
75 | 
76 | With VegasFlow we aim to fill this gap between theoretical calculations and technical performance by providing a framework which can automatically make the best of the machine in which it runs.
77 | To that end VegasFlow is based on two technologies that together will enable a new age of research.
78 | 
79 | 
80 | How to cite ``vegasflow``?
81 | ==========================
82 | 
83 | When using ``vegasflow`` in your research, please cite the following publications:
84 | 
85 | .. image:: https://img.shields.io/badge/j.%20Computer%20Physics%20Communication-2020%2F107376-blue
86 |     :target: https://doi.org/10.1016/j.cpc.2020.107376
87 | 
88 | 
89 | .. image:: https://img.shields.io/badge/arXiv-physics.comp--ph%2F%20%20%20%202002.12921-%23B31B1B
90 |     :target: https://arxiv.org/abs/2002.12921
91 | 
92 | 
93 | .. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.3691926.svg
94 |     :target: https://doi.org/10.5281/zenodo.3691926
95 | 
96 | Bibtex:
97 | 
98 | .. code-block:: latex
99 | 
100 |     @article{Carrazza:2020rdn,
101 |         author = "Carrazza, Stefano and Cruz-Martinez, Juan M.",
102 |         title = "{VegasFlow: accelerating Monte Carlo simulation across multiple hardware platforms}",
103 |         eprint = "2002.12921",
104 |         archivePrefix = "arXiv",
105 |         primaryClass = "physics.comp-ph",
106 |         reportNumber = "TIF-UNIMI-2020-8",
107 |         doi = "10.1016/j.cpc.2020.107376",
108 |         journal = "Comput. Phys. Commun.",
109 |         volume = "254",
110 |         pages = "107376",
111 |         year = "2020"
112 |     }
113 | 
114 | 
115 |     @software{vegasflow_package,
116 |         author = {Juan Cruz-Martinez and
117 |                   Stefano Carrazza},
118 |         title = {N3PDF/vegasflow: vegasflow v1.0},
119 |         month = feb,
120 |         year = 2020,
121 |         publisher = {Zenodo},
122 |         version = {v1.0},
123 |         doi = {10.5281/zenodo.3691926},
124 |         url = {https://doi.org/10.5281/zenodo.3691926}
125 |     }
126 | 
127 | FAQ
128 | ===
129 | 
130 | Why the name ``VegasFlow``?
131 | ---------------------------
132 | 
133 | It is a combination of the names `Vegas` and `Tensorflow`.
134 | 
135 | - **Vegas**: this integration algorithm, created originally by `G.P. Lepage `_, sits at the core of many of the most advanced calculations in High Energy Physics; it powers `Madgraph `_, `MCFM `_ or `Sherpa `_ among others. Lepage's own implementation is available on `github `_.
136 | 
137 | - **TensorFlow**: the `tensorflow `_ library is developed by Google and was made public in November 2015. It is a perfect combination of performance and usability. With a focus on Deep Learning, TensorFlow provides an algebra library able to easily run operations on many different devices (CPUs, GPUs, TPUs) with little input from the developer.
138 | 
139 | I have a problem I can't solve
140 | ------------------------------
141 | Please `open an issue `_ in the github repository
142 | or `check `_ whether someone has already asked the same question.
143 | We will be happy to help.
144 | 
145 | 
146 | Indices and tables
147 | ==================
148 | 
149 | .. toctree::
150 |     :maxdepth: 3
151 |     :glob:
152 |     :caption: Contents:
153 | 
154 |     VegasFlow
155 |     how_to
156 |     intalg
157 |     examples
158 |     apisrc/vegasflow
159 | 
160 | 
161 | * :ref:`genindex`
162 | * :ref:`modindex`
163 | * :ref:`search`
164 | 
--------------------------------------------------------------------------------
/examples/drellyan_lo_tf.py:
--------------------------------------------------------------------------------
1 | """
2 | Example: integration of the Drell-Yan process
3 | """
4 | 
5 | from vegasflow.configflow import DTYPE, DTYPEINT
6 | import time
7 | import numpy as np
8 | import tensorflow as tf
9 | from vegasflow.vflow import vegas_wrapper
10 | 
11 | # MC integration setup
12 | dim = 4
13 | ncalls = np.int32(1e6)
14 | n_iter = 5
15 | 
16 | # Physics setup
17 | # center of mass energy
18 | sqrts = tf.constant(14000, dtype=DTYPE)
19 | 
20 | 
21 | # auxiliary variables
22 | s = tf.square(sqrts)
23 | s2 = tf.square(s)
24 | conv = tf.constant(0.3893793e9, dtype=DTYPE) # GeV to pb conversion
25 | 
26 | 
27 | @tf.function
28 | def get_x1x2(xarr):
29 |     """Remapping [0,1] to kappa-y"""
30 |     kappa = xarr[:, 0]
31 |     y = xarr[:, 1]
32 |     logkappa = tf.math.log(kappa)
33 |     sqrtkappa = tf.sqrt(kappa)
34 |     Ycm = tf.exp(logkappa*(y - 0.5))
35 | 
36 |     shat = s
37 |     x1 = sqrtkappa * Ycm
38 |     x2 = sqrtkappa / Ycm
39 |     jac = tf.abs(logkappa)
40 | 
41 |     return shat, jac, x1, x2
42 | 
43 | 
44 | @tf.function
45 | def make_event(xarr):
46 |     """Generate event kinematics"""
47 |     shat, jac, x1, x2 = get_x1x2(xarr)
48 | 
49 |     mV = tf.sqrt(shat * x1 * x2)
50 |     mV2 = mV*mV
51 |     ecmo2 = mV/2
52 |     zeros = tf.zeros_like(ecmo2, dtype=DTYPE)
53 | 
54 |     p0 = tf.stack([ecmo2, zeros, zeros, ecmo2])
55 |     p1 = tf.stack([ecmo2, zeros, zeros,-ecmo2])
56 | 
57 |     pV = p0 + p1
58 |     YV = 0.5 * tf.math.log(tf.abs((pV[0] + pV[3])/(pV[0] - pV[3])))
59 |     pVt2 = tf.square(pV[1]) + tf.square(pV[2])
60 |     phi = 2 * np.pi * xarr[:, 3]
61 |     ptmax = 0.5 * mV2 / (tf.sqrt(mV2 + pVt2) - (pV[1]*tf.cos(phi) + pV[2]*tf.sin(phi)))
62 |     pta = ptmax * xarr[:, 2]
63 |     pt = tf.stack([zeros, pta*tf.cos(phi), pta*tf.sin(phi), zeros])
64 |     Delta = (mV2 + 2 * (pV[1]*pt[1] + pV[2]*pt[2]))/2.0/pta/tf.sqrt(mV2 + pVt2)
65 |     y = YV - tf.acosh(Delta)
66 |     kallenF = 2.0 * ptmax/tf.sqrt(mV2 + pVt2)/tf.abs(tf.sinh(YV-y))
67 | 
68 |     p2 = tf.stack([pta*tf.cosh(y), pta*tf.cos(phi), pta*tf.sin(phi),pta*tf.sinh(y)])
69 |     p3 = pV - p2
70 | 
71 |     psw = 1 / (8*np.pi)* kallenF # psw
72 |     psw *= jac # jac for tau, y
73 |     flux = 1 / (2 * mV2) # flux
74 | 
75 |     return psw, flux, p0, p1, p2, p3, x1, x2
76 | 
77 | 
78 | @tf.function
79 | def dot(p1, p2):
80 |     """Dot product 4-momenta"""
81 |     e = p1[0]*p2[0]
82 |     px = p1[1]*p2[1]
83 |     py = p1[2]*p2[2]
84 |     pz = p1[3]*p2[3]
85 |     return e - px - py - pz
86 | 
87 | 
88 | @tf.function
89 | def u0(p, i):
90 |     """Compute the dirac spinor u0"""
91 | 
92 |     zeros = tf.zeros_like(p[0], dtype=DTYPE)
93 |     czeros = tf.complex(zeros, zeros)
94 |     ones = tf.ones_like(p[0], dtype=DTYPE)
95 | 
96 |     # case 1) py == 0
97 |     rz = p[3]/p[0]
98 |     theta1 = tf.where(rz > 0, zeros, rz)
99 |     theta1 = tf.where(rz < 0, np.pi*ones, theta1)
100 |     phi1 = zeros
101 | 
102 |     # case 2) py != 0
103 |     rrr = rz
104 |     rrr = tf.where(rz < -1, -ones, rz)
105 |     rrr = tf.where(rz > 1, ones, rrr)
106 |     theta2 = tf.acos(rrr)
107 |     rx = p[1]/p[0]/tf.sin(theta2)
108 |     rrr = tf.where(rx < -1, -ones, rx)
109 |     rrr = tf.where(rx > 1, ones, rrr)
110 |     phi2 = tf.acos(rrr)
111 |     ry = p[2]/p[0]
112 |     phi2 = tf.where(ry < 0, -phi2, phi2)
113 | 
114 |     # combine
115 |     theta = tf.where(p[1] == 0,
theta1, theta2) 116 | phi = tf.where(p[1] == 0, phi1, phi2) 117 | 118 | prefact = tf.complex(np.sqrt(2), zeros)*tf.sqrt(tf.complex(p[0], zeros)) 119 | if i == 1: 120 | a = tf.complex(tf.cos(theta/2), zeros) 121 | b = tf.complex(tf.sin(theta/2), zeros) 122 | u0_0 = prefact*a 123 | u0_1 = prefact*b*tf.complex(tf.cos(phi), tf.sin(phi)) 124 | u0_2 = czeros 125 | u0_3 = czeros 126 | else: 127 | a = tf.complex(tf.sin(theta/2), zeros) 128 | b = tf.complex(tf.cos(theta/2), zeros) 129 | u0_0 = czeros 130 | u0_1 = czeros 131 | u0_2 = prefact*a*tf.complex(tf.cos(phi), -tf.sin(phi)) 132 | u0_3 = -prefact*b 133 | 134 | return tf.stack([u0_0, u0_1, u0_2, u0_3]) 135 | 136 | 137 | @tf.function 138 | def ubar0(p, i): 139 | """Compute the dirac spinor ubar0""" 140 | 141 | zeros = tf.zeros_like(p[0], dtype=DTYPE) 142 | czeros = tf.complex(zeros, zeros) 143 | ones = tf.ones_like(p[0], dtype=DTYPE) 144 | 145 | # case 1) py == 0 146 | rz = p[3]/p[0] 147 | theta1 = tf.where(rz > 0, zeros, rz) 148 | theta1 = tf.where(rz < 0, np.pi*ones, theta1) 149 | phi1 = zeros 150 | 151 | # case 2) py != 0 152 | rrr = rz 153 | rrr = tf.where(rz < -1, -ones, rrr) 154 | rrr = tf.where(rz > 1, ones, rrr) 155 | theta2 = tf.acos(rrr) 156 | rrr = p[1]/p[0]/tf.sin(theta2) 157 | rrr = tf.where(rrr < -1, -ones, rrr) 158 | rrr = tf.where(rrr > 1, ones, rrr) 159 | phi2 = tf.acos(rrr) 160 | ry = p[2]/p[0] 161 | phi2 = tf.where(ry < 0, -phi2, phi2) 162 | 163 | # combine 164 | theta = tf.where(p[1] == 0, theta1, theta2) 165 | phi = tf.where(p[1] == 0, phi1, phi2) 166 | 167 | prefact = tf.complex(np.sqrt(2), zeros)*tf.sqrt(tf.complex(p[0], zeros)) 168 | if i == -1: 169 | a = tf.complex(tf.sin(theta/2), zeros) 170 | b = tf.complex(tf.abs(tf.cos(theta/2)), zeros) 171 | ubar0_0 = prefact*a*tf.complex(tf.cos(phi), tf.sin(phi)) 172 | ubar0_1 = -prefact*b 173 | ubar0_2 = czeros 174 | ubar0_3 = czeros 175 | else: 176 | a = tf.complex(tf.cos(theta/2), zeros) 177 | b = tf.complex(tf.sin(theta/2), zeros) 178 | ubar0_0 = czeros 179 | ubar0_1 = czeros 180 | ubar0_2 = prefact*a 181 | ubar0_3 = prefact*b*tf.complex(tf.cos(phi), -tf.sin(phi)) 182 | 183 | return tf.stack([ubar0_0, ubar0_1, ubar0_2, ubar0_3]) 184 | 185 | 186 | @tf.function 187 | def za(p1, p2): 188 | ket = u0(p2, 1) 189 | bra = ubar0(p1, -1) 190 | return tf.reduce_sum(bra*ket, axis=0) 191 | 192 | 193 | @tf.function 194 | def zb(p1, p2): 195 | ket = u0(p2, -1) 196 | bra = ubar0(p1, 1) 197 | return tf.reduce_sum(bra*ket, axis=0) 198 | 199 | 200 | @tf.function 201 | def sprod(p1, p2): 202 | a = za(p1, p2) 203 | b = zb(p2, p1) 204 | return tf.math.real(a*b) 205 | 206 | 207 | @tf.function 208 | def qqxllx(p0, p1, p2, p3): 209 | """Evaluate 0 -> qqbarttbar""" 210 | lsprod = sprod(p0, p1) 211 | a = 2 * tf.abs(za(p0, p2) * zb(p3, p1)) / lsprod 212 | b = 2 * tf.abs(za(p0, p3) * zb(p2, p1)) / lsprod 213 | return 6.0 * (tf.square(a)+tf.square(b)) / 36.0 214 | 215 | 216 | @tf.function 217 | def evaluate_matrix_element_square(p0, p1, p2, p3): 218 | """Evaluate Matrix Element square""" 219 | 220 | # channels evaluation 221 | c1 = qqxllx(-p1,-p0, p2, p3) # QQBARLLBAR 222 | 223 | return c1 224 | 225 | 226 | @tf.function 227 | def pdf(fl1, fl2, x1, x2): 228 | """Dummy toy PDF""" 229 | return x1*x2 230 | 231 | 232 | @tf.function 233 | def build_luminosity(x1, x2): 234 | """Single-top t-channel luminosity""" 235 | lumi = ( 236 | pdf(1, -1, x1, x2) + pdf(2, -2, x1, x2) + 237 | pdf(3, -3, x1, x2) + pdf(4, -4, x1, x2) 238 | ) / x1 / x2 239 | return lumi 240 | 241 | 242 | @tf.function 243 | def drellyan(xarr, 
weight=None, **kwargs):
244 |     """Drell-Yan production at LO"""
245 |     psw, flux, p0, p1, p2, p3, x1, x2 = make_event(xarr)
246 |     wgts = evaluate_matrix_element_square(p0, p1, p2, p3)
247 |     lumis = build_luminosity(x1, x2)
248 |     lumi_me2 = 2*lumis*wgts
249 |     return lumi_me2*psw*flux*conv
250 | 
251 | 
252 | if __name__ == "__main__":
253 |     """Testing a basic integration"""
254 |     print(f"VEGAS MC, ncalls={ncalls}:")
255 |     start = time.time()
256 |     r = vegas_wrapper(drellyan, dim, n_iter, ncalls)
257 |     end = time.time()
258 |     print(f"time (s): {end-start}")
259 | 
--------------------------------------------------------------------------------
/doc/source/examples.rst:
--------------------------------------------------------------------------------
1 | .. _examples-label:
2 | 
3 | ==========
4 | Examples
5 | ==========
6 | 
7 | In the ``VegasFlow`` repository you can find `several examples `_
8 | of integrands which can hopefully help you quickstart your project.
9 | 
10 | On this page we explain some of these examples in more detail.
11 | You can find the full code in the repository alongside more complicated versions.
12 | 
13 | 
14 | .. contents::
15 |     :local:
16 |     :depth: 1
17 | 
18 | 
19 | Basic Integral
20 | ==============
21 | 
22 | The most general usage of ``VegasFlow`` is the integration of a tensorflow-based integrand.
23 | 
24 | .. code-block:: python
25 | 
26 |     from vegasflow import vegas_wrapper
27 |     import tensorflow as tf
28 | 
29 |     @tf.function
30 |     def my_integrand(xarr, **kwargs):
31 |         return tf.reduce_sum(xarr, axis=1)
32 | 
33 |     n_dim = 10
34 |     n_events = int(1e6)
35 |     n_iter = 5
36 |     result = vegas_wrapper(my_integrand, n_dim, n_iter, n_events)
37 | 
38 | 
39 | You can find a `runnable version of this basic example in the repository `_.
40 | 
41 | 
42 | Using VegasFlow as a clever Random Number Generator
43 | ===================================================
44 | 
45 | A possible use case for ``VegasFlow`` is to have a function that we don't necessarily
46 | want to integrate, but that we want to sample from.
47 | In general, sampling from a function can be very complicated and instead we just want
48 | to *approximately* sample from it. For that we can use the ``vegasflow`` package,
49 | which gives instant access to all integrator algorithms as function approximators.
50 | 
51 | In the example below we use importance sampling from ``VegasFlow`` to approximate
52 | the function and sample from it, but the same code will work for any of the
53 | other implemented integrators.
54 | 
55 | .. code-block:: python
56 | 
57 |     from vegasflow import VegasFlow, run_eager
58 |     import tensorflow as tf
59 | 
60 |     run_eager(True)
61 | 
62 |     def my_complicated_fun(xarr, **kwargs):
63 |         return tf.reduce_sum(xarr, axis=1)
64 | 
65 |     n_dim = 10
66 |     n_events = int(1e5)
67 |     sampler = VegasFlow(n_dim, n_events, verbose=False)
68 |     sampler.compile(my_complicated_fun)
69 | 
70 |     # Now let's train the integrator for 10 iterations
71 |     _ = sampler.run_integration(10)
72 | 
73 |     # Now we can use sampler to generate random numbers
74 |     rnds, px = sampler.generate_random_array(100)
75 | 
76 | The first object returned by ``generate_random_array`` contains the random points,
77 | in this example an array of shape ``(100, 10)``, i.e., the first axis
78 | is the number of requested events and the second axis the number of dimensions.
79 | 
80 | ``generate_random_array`` also returns the probability distribution
81 | of the random points (i.e., the weight they carry).
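As a quick cross-check, the points and weights can be used to reconstruct the integral estimate from the sample. The snippet below is a minimal sketch (not one of the repository examples) and assumes that ``px`` is normalized such that the weighted sum of the integrand reproduces the Monte Carlo estimate; compare against the output of ``run_integration`` before relying on it:

.. code-block:: python

    # Draw a larger sample together with its per-event weights
    rnds, px = sampler.generate_random_array(10000)

    # Weighted sum of the integrand over the sample: under the
    # normalization assumption above, this should be close to the
    # result reported by run_integration
    estimate = tf.reduce_sum(my_complicated_fun(rnds) * px)
    print(f"Sampled estimate of the integral: {estimate}")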
82 | 83 | For convenience we include sampler wrappers which directly return a trained 84 | reference to the ``generate_random_array`` method: 85 | 86 | .. code-block:: python 87 | 88 | from vegasflow import vegas_sampler 89 | 90 | sampler = vegas_sampler(my_complicated_fun, n_dim, n_events) 91 | rnds, px = sampler(100) 92 | 93 | 94 | It is possible to change the number of training steps (default 5) or to retrieve 95 | a reference to the sampler class instead to the sampler method by using keyword 96 | arguments. 97 | 98 | .. code-block:: python 99 | 100 | sampler_class = vegas_sampler(my_complicated_fun, n_dim, n_events, training_steps=1, return_class=True) 101 | rnds, px = sampler_class.generate_random_array(100) 102 | 103 | Integrating a numpy function 104 | ============================ 105 | 106 | ``VegasFlow`` admits also the integration of non-tensorflow python-based integrands. 107 | In this case, however, it is necessary to activate ``eager-mode``, see :ref:`eager-label`. 108 | 109 | .. code-block:: python 110 | 111 | import numpy as np 112 | from vegasflow import vegas_wrapper, run_eager 113 | run_eager() 114 | 115 | def my_integrand(xarr, **kwargs): 116 | return np.sum(xarr) 117 | 118 | n_dim = 10 119 | n_events = int(1e6) 120 | n_iter = 5 121 | result = vegas_wrapper(my_integrand, n_dim, n_iter, n_events) 122 | 123 | Note, however, that in this case the integrand will always be run on CPU, while the internal steps of the integration will be run on GPU by ``VegasFlow``. 124 | In order to run the integration exclusively on GPU the environment variable ``CUDA_VISIBLE_DEVICES`` should be set to ``''``: 125 | 126 | .. code-block:: bash 127 | 128 | export CUDA_VISIBLE_DEVICES="" 129 | 130 | Interfacing C code: CFFI 131 | ======================== 132 | 133 | A popular way of interfacing python and C code is to use the 134 | `CFFI library `_. 135 | 136 | Imagine you have a C-file with some integrand: 137 | 138 | .. code-block:: C 139 | 140 | // integrand.c 141 | void integrand(double *xarr, int ndim, int nevents, double *out) { 142 | for (int i = 0; i < nevents; i++) { 143 | out[i] = 0.0; 144 | for (int j = 0; j < ndim; j++) { 145 | out[i] += 2.0*xarr[j+i*ndim]; 146 | } 147 | } 148 | } 149 | 150 | You can compile this code and integrate it (no pun intended) with ``vegasflow`` 151 | by using the CFFI library as you can see in `this `_ example. 152 | 153 | .. code-block:: python 154 | 155 | from vegasflow.configflow import DTYPE 156 | import numpy as np 157 | from vegasflow import vegas_wrapper 158 | 159 | from cffi import FFI 160 | ffibuilder = FFI() 161 | 162 | ffibuilder.cdef("void integrand(double*, int, int, double*);") 163 | 164 | with open("integrand.c", "r") as f: 165 | ffibuilder.set_source("_integrand_cffi", f.read()) 166 | 167 | ffibuilder.compile() 168 | 169 | # Now you can read up the compiled C code as a python library 170 | from _integrand_cffi import ffi, lib 171 | 172 | def integrand(xarr, **kwargs): 173 | n_dim = xarr.shape[-1] 174 | result = np.empty(n_events, dtype=DTYPE.as_numpy_dtype) 175 | x_flat = xarr.numpy().flatten() 176 | 177 | p_input = ffi.cast("double*", ffi.from_buffer(x_flat)) 178 | p_output = ffi.cast("double*", ffi.from_buffer(result)) 179 | 180 | lib.integrand(p_input, n_dim, xarr.shape[0], p_output) 181 | return result 182 | 183 | vegas_wrapper(integrand, 5, 10, int(1e5), compilable=False) 184 | 185 | 186 | Note the usage of the ``compilable=False`` flag. 
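If you would still like the rest of the computation to be compiled, one possible workaround (a minimal sketch, not part of the documented interface) is to wrap the python integrand in ``tf.py_function``, which embeds the eager CPU call as a single node of the TensorFlow graph:

.. code-block:: python

    import tensorflow as tf

    def tf_integrand(xarr, **kwargs):
        # tf.py_function runs `integrand` eagerly (so .numpy() works inside)
        # while the surrounding graph can still be compiled
        return tf.py_function(func=integrand, inp=[xarr], Tout=DTYPE)

Whether this pays off depends on the integrand; the ``compilable=False`` route shown above is the straightforward one.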
189 | 
190 | 
191 | Create your own TF-compilable operators
192 | =======================================
193 | 
194 | Tensorflow tries to do its best to compile your integrand to something that can
195 | quickly be evaluated on GPU.
196 | It has no information, however, about specific situations that would allow
197 | for non-trivial optimizations.
198 | 
199 | In these cases one might want to write their own C++ or CUDA code while still
200 | allowing Tensorflow to create a full graph of the computation.
201 | 
202 | Creating new operations in TF is an advanced feature and, when possible,
203 | it is recommended to create your integrand as a composition of TF operators.
204 | If you still want to go ahead we have prepared a `simple example `_
205 | in the repository that can be used as a template for C++ or CUDA custom
206 | operators.
207 | 
208 | The example includes a `makefile `_ that you might need to modify for your particular needs.
209 | 
210 | Note that in order to run the code on both GPU and CPU you will need to provide
211 | GPU- and CPU-capable kernels.
--------------------------------------------------------------------------------
/src/vegasflow/tests/test_algs.py:
--------------------------------------------------------------------------------
1 | """
2 | Checks that the different integration algorithms
3 | are able to run and don't produce a crazy result.
4 | 
5 | Tests a run with a simple function to make sure
6 | everything works.
7 | """
8 | import json
9 | import tempfile
10 | 
11 | import numpy as np
12 | import pytest
13 | import tensorflow as tf
14 | 
15 | from vegasflow import plain_sampler, vegas_sampler
16 | from vegasflow.configflow import DTYPE, run_eager
17 | from vegasflow.plain import PlainFlow
18 | from vegasflow.vflow import VegasFlow
19 | from vegasflow.vflowplus import VegasFlowPlus
20 | 
21 | # Test setup
22 | dim = 2
23 | ncalls = np.int32(1e4)
24 | n_iter = 4
25 | 
26 | 
27 | def example_integrand(xarr, weight=None):
28 |     """Example function that integrates to 1"""
29 |     n_dim = xarr.shape[-1]
30 |     a = tf.constant(0.1, dtype=DTYPE)
31 |     n100 = tf.cast(100 * n_dim, dtype=DTYPE)
32 |     pref = tf.pow(1.0 / a / np.sqrt(np.pi), n_dim)
33 |     coef = tf.reduce_sum(tf.range(n100 + 1))
34 |     coef += tf.reduce_sum(tf.square((xarr - 1.0 / 2.0) / a), axis=1)
35 |     coef -= (n100 + 1) * n100 / 2.0
36 |     return pref * tf.exp(-coef)
37 | 
38 | 
39 | def instance_and_compile(Integrator, mode=0, integrand_function=example_integrand):
40 |     """Wrapper for convenience"""
41 |     if mode == 0:
42 |         integrand = integrand_function
43 |     elif mode == 1:
44 | 
45 |         def integrand(xarr, n_dim=None):
46 |             return integrand_function(xarr)
47 | 
48 |     elif mode == 2:
49 | 
50 |         def integrand(xarr):
51 |             return integrand_function(xarr)
52 | 
53 |     elif mode == 3:
54 | 
55 |         def integrand(xarr, n_dim=None, weight=None):
56 |             return integrand_function(xarr, weight=weight)
57 | 
58 |     int_instance = Integrator(dim, ncalls)
59 |     int_instance.compile(integrand)
60 |     return int_instance
61 | 
62 | 
63 | def check_is_one(result, sigmas=3, target_result=1.0):
64 |     """Wrapper for convenience"""
65 |     res = result[0]
66 |     err = np.mean(result[1] * sigmas)
67 |     # Check that the result is within {sigmas} sigmas of the target
68 |     np.testing.assert_allclose(res, target_result, atol=err)
69 | 
70 | 
71 | @pytest.mark.parametrize("mode", range(4))
72 | def test_VegasFlow(mode):
73 |     """Test
VegasFlow class, importance sampling algorithm""" 74 | vegas_instance = instance_and_compile(VegasFlow, mode) 75 | _ = vegas_instance.run_integration(n_iter) 76 | vegas_instance.freeze_grid() 77 | result = vegas_instance.run_integration(n_iter) 78 | check_is_one(result) 79 | 80 | 81 | def test_VegasFlow_grid_management(): 82 | vegas_instance = instance_and_compile(VegasFlow, 1) 83 | _ = vegas_instance.run_integration(n_iter) 84 | vegas_instance.freeze_grid() 85 | 86 | # Change the number of events 87 | vegas_instance.n_events = 2 * ncalls 88 | new_result = vegas_instance.run_integration(n_iter) 89 | check_is_one(new_result) 90 | 91 | # Unfreeze the grid 92 | vegas_instance.unfreeze_grid() 93 | new_result = vegas_instance.run_integration(n_iter) 94 | check_is_one(new_result) 95 | 96 | # And change the number of calls again 97 | vegas_instance.n_events = 3 * ncalls 98 | new_result = vegas_instance.run_integration(n_iter) 99 | check_is_one(new_result) 100 | 101 | 102 | def test_VegasFlow_save_grid(): 103 | """Test the grid saving feature of vegasflow""" 104 | tmp_filename = tempfile.mktemp() 105 | vegas_instance = instance_and_compile(VegasFlow) 106 | # Run an iteration so the grid is not trivial 107 | _ = vegas_instance.run_integration(1) 108 | current_grid = vegas_instance.divisions.numpy() 109 | # Save and load the grid from the file 110 | vegas_instance.save_grid(tmp_filename) 111 | with open(tmp_filename, "r") as f: 112 | json_grid = np.array(json.load(f)["grid"]) 113 | np.testing.assert_equal(current_grid, json_grid) 114 | 115 | 116 | def test_VegasFlow_load_grid(): 117 | tmp_filename = tempfile.mktemp() 118 | # Get the information from the vegas_instance 119 | vegas_instance = instance_and_compile(VegasFlow) 120 | grid_shape = vegas_instance.divisions.shape 121 | tmp_grid = np.random.rand(*grid_shape) 122 | # Save into some rudimentary json file 123 | jdict = {"grid": tmp_grid.tolist()} 124 | with open(tmp_filename, "w") as f: 125 | json.dump(jdict, f) 126 | # Try to load it 127 | vegas_instance.load_grid(file_name=tmp_filename) 128 | # Check that the loading did work 129 | loaded_grid = vegas_instance.divisions.numpy() 130 | np.testing.assert_equal(loaded_grid, tmp_grid) 131 | # Now try to load a numpy array directly instead 132 | tmp_grid = np.random.rand(*grid_shape) 133 | vegas_instance.load_grid(numpy_grid=tmp_grid) 134 | loaded_grid = vegas_instance.divisions.numpy() 135 | np.testing.assert_equal(loaded_grid, tmp_grid) 136 | # Now check that the errors also work 137 | jdict["BINS"] = 0 138 | with open(tmp_filename, "w") as f: 139 | json.dump(jdict, f) 140 | # Check that it fails if the number of bins is different 141 | with pytest.raises(ValueError): 142 | vegas_instance.load_grid(file_name=tmp_filename) 143 | # Check that it fails if the number of dimensons is different 144 | jdict["dimensions"] = -4 145 | with open(tmp_filename, "w") as f: 146 | json.dump(jdict, f) 147 | with pytest.raises(ValueError): 148 | vegas_instance.load_grid(file_name=tmp_filename) 149 | 150 | 151 | @pytest.mark.parametrize("mode", range(4)) 152 | def test_PlainFlow(mode): 153 | plain_instance = instance_and_compile(PlainFlow, mode) 154 | result = plain_instance.run_integration(n_iter) 155 | check_is_one(result) 156 | 157 | 158 | def test_PlainFlow_change_nevents(): 159 | plain_instance = instance_and_compile(PlainFlow, 0) 160 | result = plain_instance.run_integration(n_iter) 161 | check_is_one(result) 162 | 163 | plain_instance.n_events = 2 * ncalls 164 | new_result = 
plain_instance.run_integration(n_iter) 165 | check_is_one(new_result) 166 | 167 | 168 | def helper_rng_tester(sampling_function, n_events): 169 | """Ensure the random number generated have the correct shape 170 | Return the random numbers and the jacobian""" 171 | rnds, px = sampling_function(n_events) 172 | np.testing.assert_equal(rnds.shape, (n_events, dim)) 173 | return rnds, px 174 | 175 | 176 | def test_rng_generation_plain(n_events=100): 177 | """Test the random number generation with plainflow""" 178 | plain_sampler_instance = instance_and_compile(PlainFlow) 179 | _, px = helper_rng_tester(plain_sampler_instance.generate_random_array, n_events) 180 | np.testing.assert_equal(px.numpy(), 1.0 / n_events) 181 | 182 | 183 | def test_rng_generation_vegasflow(n_events=100): 184 | """Test the random number generation with vegasflow""" 185 | vegas_sampler_instance = instance_and_compile(VegasFlow) 186 | # Train a bit the grid 187 | vegas_sampler_instance.run_integration(2) 188 | _, px = helper_rng_tester(vegas_sampler_instance.generate_random_array, n_events) 189 | np.testing.assert_equal(px.shape, (n_events,)) 190 | 191 | 192 | def test_rng_generation_vegasflowplus(n_events=100): 193 | """Test the random number generation with vegasflow""" 194 | vegas_sampler_instance = instance_and_compile(VegasFlowPlus) 195 | # Train a bit the grid 196 | # vegas_sampler_instance.run_integration(2) 197 | _, px = helper_rng_tester(vegas_sampler_instance.generate_random_array, n_events) 198 | np.testing.assert_equal(px.shape, (n_events,)) 199 | 200 | 201 | def test_rng_generation_wrappers(n_events=100): 202 | """Test the wrappers for the samplers""" 203 | p = plain_sampler(example_integrand, dim, n_events, training_steps=2, return_class=True) 204 | _ = helper_rng_tester(p.generate_random_array, n_events) 205 | v = vegas_sampler(example_integrand, dim, n_events, training_steps=2) 206 | _ = helper_rng_tester(v, n_events) 207 | 208 | 209 | @pytest.mark.parametrize("mode", range(4)) 210 | def test_VegasFlowPlus_ADAPTIVE_SAMPLING(mode): 211 | """Test Vegasflow with Adaptive Sampling on (the default)""" 212 | vflowplus_instance = instance_and_compile(VegasFlowPlus, mode) 213 | result = vflowplus_instance.run_integration(n_iter) 214 | check_is_one(result) 215 | 216 | 217 | def test_VegasFlowPlus_NOT_ADAPTIVE_SAMPLING(): 218 | """Test Vegasflow with Adaptive Sampling off (non-default)""" 219 | vflowplus_instance = VegasFlowPlus(dim, ncalls, adaptive=False) 220 | vflowplus_instance.compile(example_integrand) 221 | result = vflowplus_instance.run_integration(n_iter) 222 | check_is_one(result) 223 | -------------------------------------------------------------------------------- /examples/singletop_lo_tf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Example: integration of singletop production 4 | """ 5 | import sys 6 | import time 7 | import numpy as np 8 | from vegasflow.configflow import DTYPE 9 | import tensorflow as tf 10 | from vegasflow.vflow import vegas_wrapper 11 | 12 | # MC integration setup 13 | dim = 3 14 | ncalls = np.int32(1e5) 15 | n_iter = 5 16 | 17 | # Physics setup 18 | # top mass 19 | mt = tf.constant(173.2, dtype=DTYPE) 20 | # center of mass energy 21 | sqrts = tf.constant(8000, dtype=DTYPE) 22 | # minimum allowed center of mass energy 23 | sqrtsmin = tf.constant(173.2, dtype=DTYPE) 24 | # W-boson mass 25 | mw = tf.constant(80.419, dtype=DTYPE) 26 | # gaw 27 | gaw = tf.constant(2.1054, dtype=DTYPE) 28 | # GF 29 | gf = 
tf.constant(1.16639e-5, dtype=DTYPE) 30 | 31 | 32 | # auxiliary variables 33 | colf_bt = tf.constant(9, dtype=DTYPE) 34 | mt2 = tf.square(mt) 35 | s = tf.square(sqrts) 36 | s2 = tf.square(s) 37 | smin = tf.square(sqrtsmin) 38 | bmax = tf.sqrt(1 - smin / s) 39 | conv = tf.constant(0.3893793e9, dtype=DTYPE) # GeV to pb conversion 40 | gaw2 = tf.square(gaw) 41 | mw2 = tf.square(mw) 42 | gw4 = tf.square(4 * np.sqrt(2) * mw2 * gf) 43 | 44 | 45 | @tf.function 46 | def get_x1x2(xarr): 47 | """Remapping [0,1] to tau-y""" 48 | # building shat 49 | b = bmax * xarr[:, 0] 50 | onemb2 = 1 - tf.square(b) 51 | shat = smin / onemb2 52 | tau = shat / s 53 | 54 | # building rapidity 55 | ymax = -0.5 * tf.math.log(tau) 56 | y = ymax * (2 * xarr[:, 1] - 1) 57 | 58 | # building jacobian 59 | jac = 2 * tau * b * bmax / onemb2 # tau 60 | jac *= 2 * ymax # y 61 | 62 | # building x1 and x2 63 | sqrttau = tf.sqrt(tau) 64 | expy = tf.exp(y) 65 | x1 = sqrttau * expy 66 | x2 = sqrttau / expy 67 | 68 | return shat, jac, x1, x2 69 | 70 | 71 | @tf.function 72 | def make_event(xarr): 73 | """Generate event kinematics""" 74 | shat, jac, x1, x2 = get_x1x2(xarr) 75 | 76 | ecmo2 = tf.sqrt(shat) / 2 77 | cc = ecmo2 * (1 - mt2 / shat) 78 | cos = 1 - 2 * xarr[:, 2] 79 | sinxi = cc * tf.sqrt(1 - cos * cos) 80 | cosxi = cc * cos 81 | zeros = tf.zeros_like(ecmo2, dtype=DTYPE) 82 | 83 | p0 = tf.stack([ecmo2, zeros, zeros, ecmo2]) 84 | p1 = tf.stack([ecmo2, zeros, zeros, -ecmo2]) 85 | p2 = tf.stack([cc, sinxi, zeros, cosxi]) 86 | p3 = tf.stack([tf.sqrt(cc * cc + mt2), -sinxi, zeros, -cosxi]) 87 | 88 | psw = (1 - mt2 / shat) / (8 * np.pi) # psw 89 | psw *= jac # jac for tau, y 90 | flux = 1 / (2 * shat) # flux 91 | 92 | return psw, flux, p0, p1, p2, p3, x1, x2 93 | 94 | 95 | @tf.function 96 | def dot(p1, p2): 97 | """Dot product 4-momenta""" 98 | e = p1[0] * p2[0] 99 | px = p1[1] * p2[1] 100 | py = p1[2] * p2[2] 101 | pz = p1[3] * p2[3] 102 | return e - px - py - pz 103 | 104 | 105 | @tf.function 106 | def u0(p, i): 107 | """Compute the dirac spinor u0""" 108 | 109 | zeros = tf.zeros_like(p[0], dtype=DTYPE) 110 | czeros = tf.complex(zeros, zeros) 111 | ones = tf.ones_like(p[0], dtype=DTYPE) 112 | 113 | # case 1) py == 0 114 | rz = p[3] / p[0] 115 | theta1 = tf.where(rz > 0, zeros, rz) 116 | theta1 = tf.where(rz < 0, np.pi * ones, theta1) 117 | phi1 = zeros 118 | 119 | # case 2) py != 0 120 | rrr = rz 121 | rrr = tf.where(rz < -1, -ones, rz) 122 | rrr = tf.where(rz > 1, ones, rrr) 123 | theta2 = tf.acos(rrr) 124 | rx = p[1] / p[0] 125 | phi2 = zeros 126 | phi2 = tf.where(rx < 0, np.pi * ones, phi2) 127 | 128 | # combine 129 | theta = tf.where(p[1] == 0, theta1, theta2) 130 | phi = tf.where(p[1] == 0, phi1, phi2) 131 | 132 | prefact = tf.complex(np.sqrt(2), zeros) * tf.sqrt(tf.complex(p[0], zeros)) 133 | if i == 1: 134 | a = tf.complex(tf.cos(theta / 2), zeros) 135 | b = tf.complex(tf.sin(theta / 2), zeros) 136 | u0_0 = prefact * a 137 | u0_1 = prefact * b * tf.complex(tf.cos(phi), tf.sin(phi)) 138 | u0_2 = czeros 139 | u0_3 = czeros 140 | else: 141 | a = tf.complex(tf.sin(theta / 2), zeros) 142 | b = tf.complex(tf.cos(theta / 2), zeros) 143 | u0_0 = czeros 144 | u0_1 = czeros 145 | u0_2 = prefact * a * tf.complex(tf.cos(phi), -tf.sin(phi)) 146 | u0_3 = -prefact * b 147 | 148 | return tf.stack([u0_0, u0_1, u0_2, u0_3]) 149 | 150 | 151 | @tf.function 152 | def ubar0(p, i): 153 | """Compute the dirac spinor ubar0""" 154 | 155 | zeros = tf.zeros_like(p[0], dtype=DTYPE) 156 | czeros = tf.complex(zeros, zeros) 157 | ones = tf.ones_like(p[0], 
dtype=DTYPE) 158 | 159 | # case 1) py == 0 160 | rz = p[3] / p[0] 161 | theta1 = tf.where(rz > 0, zeros, rz) 162 | theta1 = tf.where(rz < 0, np.pi * ones, theta1) 163 | phi1 = zeros 164 | 165 | # case 2) py != 0 166 | rrr = rz 167 | rrr = tf.where(rz < -1, -ones, rrr) 168 | rrr = tf.where(rz > 1, ones, rrr) 169 | theta2 = tf.acos(rrr) 170 | rrr = p[1] / p[0] / tf.sin(theta2) 171 | rrr = tf.where(rrr < -1, -ones, rrr) 172 | rrr = tf.where(rrr > 1, ones, rrr) 173 | phi2 = tf.acos(rrr) 174 | ry = p[2] / p[0] 175 | phi2 = tf.where(ry < 0, -phi2, phi2) 176 | 177 | # combine 178 | theta = tf.where(p[1] == 0, theta1, theta2) 179 | phi = tf.where(p[1] == 0, phi1, phi2) 180 | 181 | prefact = tf.complex(np.sqrt(2), zeros) * tf.sqrt(tf.complex(p[0], zeros)) 182 | if i == -1: 183 | a = tf.complex(tf.sin(theta / 2), zeros) 184 | b = tf.complex(tf.abs(tf.cos(theta / 2)), zeros) 185 | ubar0_0 = prefact * a * tf.complex(tf.cos(phi), tf.sin(phi)) 186 | ubar0_1 = -prefact * b 187 | ubar0_2 = czeros 188 | ubar0_3 = czeros 189 | else: 190 | a = tf.complex(tf.cos(theta / 2), zeros) 191 | b = tf.complex(tf.sin(theta / 2), zeros) 192 | ubar0_0 = czeros 193 | ubar0_1 = czeros 194 | ubar0_2 = prefact * a 195 | ubar0_3 = prefact * b * tf.complex(tf.cos(phi), -tf.sin(phi)) 196 | 197 | return tf.stack([ubar0_0, ubar0_1, ubar0_2, ubar0_3]) 198 | 199 | 200 | @tf.function 201 | def za(p1, p2): 202 | ket = u0(p2, 1) 203 | bra = ubar0(p1, -1) 204 | return tf.reduce_sum(bra * ket, axis=0) 205 | 206 | 207 | @tf.function 208 | def zb(p1, p2): 209 | ket = u0(p2, -1) 210 | bra = ubar0(p1, 1) 211 | return tf.reduce_sum(bra * ket, axis=0) 212 | 213 | 214 | @tf.function 215 | def sprod(p1, p2): 216 | a = za(p1, p2) 217 | b = zb(p2, p1) 218 | return tf.math.real(a * b) 219 | 220 | 221 | @tf.function 222 | def qqxtbx(p0, p1, p2, p3): 223 | """Evaluate 0 -> qqbarttbar""" 224 | pw2 = sprod(p0, p1) 225 | wprop = tf.square(pw2 - mw2) + mw2 * gaw2 226 | a = sprod(p0, p2) 227 | b = sprod(p0, p3) 228 | c = sprod(p2, p3) 229 | d = sprod(p3, p1) 230 | return tf.abs((a + mt2 * b / c) * d) * colf_bt / wprop * gw4 / 36 231 | 232 | 233 | @tf.function 234 | def evaluate_matrix_element_square(p0, p1, p2, p3): 235 | """Evaluate Matrix Element square""" 236 | 237 | # massless projection 238 | k = mt2 / dot(p3, p0) / 2 239 | p3 -= p0 * k 240 | 241 | # channels evaluation 242 | c1 = qqxtbx(p2, -p1, p3, -p0) # BBARQBARQT +2 -1 +3 -0 243 | c2 = qqxtbx(-p1, p2, p3, -p0) # BBARQQBART -1 +2 +3 -0 244 | 245 | return tf.stack([c1, c2]) 246 | 247 | 248 | @tf.function 249 | def pdf(fl1, fl2, x1, x2): 250 | """Dummy toy PDF""" 251 | return x1 * x2 252 | 253 | 254 | @tf.function 255 | def build_luminosity(x1, x2): 256 | """Single-top t-channel luminosity""" 257 | lumi1 = pdf(5, 2, x1, x2) + pdf(5, 4, x1, x2) 258 | lumi2 = pdf(5, -1, x1, x2) + pdf(5, -3, x1, x2) 259 | lumis = tf.stack([lumi1, lumi2]) / x1 / x2 260 | return lumis 261 | 262 | 263 | @tf.function 264 | def singletop(xarr, **kwargs): 265 | """Single-top (t-channel) at LO""" 266 | psw, flux, p0, p1, p2, p3, x1, x2 = make_event(xarr) 267 | wgts = evaluate_matrix_element_square(p0, p1, p2, p3) 268 | lumis = build_luminosity(x1, x2) 269 | lumi_me2 = tf.reduce_sum(2 * lumis * wgts, axis=0) 270 | return lumi_me2 * psw * flux * conv 271 | 272 | 273 | if __name__ == "__main__": 274 | """Testing a basic integration""" 275 | print(f"VEGAS MC, ncalls={ncalls}:") 276 | start = time.time() 277 | r = vegas_wrapper(singletop, dim, n_iter, ncalls) 278 | end = time.time() 279 | print(f"time (s): {end-start}") 280 | 
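    # Editor's sketch, not part of the original example: the same integral
    # through the class-based interface, freezing the grid after two warm-up
    # iterations so that later iterations reuse the adapted grid:
    #
    #   from vegasflow import VegasFlow
    #   mc_instance = VegasFlow(dim, ncalls)
    #   mc_instance.compile(singletop)
    #   mc_instance.run_integration(2)  # train the importance-sampling grid
    #   mc_instance.freeze_grid()
    #   print(mc_instance.run_integration(n_iter))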
281 | try: 282 | from vegas import Integrator 283 | except ModuleNotFoundError: 284 | sys.exit(0) 285 | 286 | if len(sys.argv) > 1: 287 | print(" > Also doing the comparison with the original Vegas ") 288 | 289 | def fun(xarr): 290 | x = xarr.reshape(1, -1) 291 | return singletop(x) 292 | 293 | print("Comparing with Lepage's Vegas") 294 | limits = dim * [[0.0, 1.0]] 295 | integrator = Integrator(limits) 296 | start = time.time() 297 | vr = integrator(fun, neval=ncalls, nitn=n_iter) 298 | end = time.time() 299 | print(vr.summary()) 300 | print(f"time (s): {end-start}") 301 | print(f"Per iteration (s): {(end-start)/n_iter}") 302 | -------------------------------------------------------------------------------- /src/vegasflow/vflowplus.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of vegas+ algorithm: 3 | 4 | adaptive importance sampling + adaptive stratified sampling 5 | from https://arxiv.org/abs/2009.05112 6 | 7 | The main interface is the `VegasFlowPlus` class. 8 | """ 9 | 10 | from functools import partial 11 | from itertools import product 12 | import logging 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | from vegasflow.configflow import ( 18 | BETA, 19 | BINS_MAX, 20 | DTYPE, 21 | DTYPEINT, 22 | MAX_NEVAL_HCUBE, 23 | float_me, 24 | fone, 25 | fzero, 26 | int_me, 27 | ) 28 | from vegasflow.monte_carlo import MonteCarloFlow, sampler, wrapper 29 | from vegasflow.utils import consume_array_into_indices 30 | from vegasflow.vflow import VegasFlow, importance_sampling_digest 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | FBINS = float_me(BINS_MAX) 35 | 36 | 37 | @tf.function( 38 | input_signature=[ 39 | tf.TensorSpec(shape=[None, None], dtype=DTYPE), 40 | tf.TensorSpec(shape=[], dtype=DTYPEINT), 41 | tf.TensorSpec(shape=[None], dtype=DTYPEINT), 42 | tf.TensorSpec(shape=[None, None], dtype=DTYPEINT), 43 | tf.TensorSpec(shape=[None, None], dtype=DTYPE), 44 | ] 45 | ) 46 | def generate_samples_in_hypercubes(rnds, n_strat, n_ev, hypercubes, divisions): 47 | """Receives an array of random numbers between 0 and 1 and 48 | distributes them among the hypercubes according to the 49 | number of samples per hypercube specified by n_ev 50 | 51 | Parameters 52 | ---------- 53 | `rnds`: tensor of random numbers between 0 and 1 54 | `n_strat`: tensor with number of stratifications in each dimension 55 | `n_ev`: tensor containing number of samples per hypercube 56 | `hypercubes`: tensor containing all the different hypercubes 57 | `divisions`: vegas grid 58 | 59 | Returns 60 | ------- 61 | `x` : random numbers collocated in hypercubes 62 | `w` : weight of each event 63 | `ind`: division index in which each (n_dim) set of random numbers falls 64 | `segm` : segmentation for later computations 65 | """ 66 | # Use the event-per-hypercube information to fix each random event to a hypercube 67 | indices = tf.repeat(tf.range(tf.shape(hypercubes, out_type=DTYPEINT)[0]), n_ev) 68 | points = float_me(tf.gather(hypercubes, indices)) 69 | n_evs = float_me(tf.gather(n_ev, indices)) 70 | 71 | # Compute in which division of the importance_sampling grid the points fall 72 | xn = tf.transpose(points + rnds) * FBINS / float_me(n_strat) 73 | 74 | ind_xn, x, weights = importance_sampling_digest(xn, divisions) 75 | 76 | # Reweight taking into account the number of events per hypercube 77 | final_weights = weights / n_evs 78 | 79 | segm = indices 80 | return x, final_weights, ind_xn, segm 81 | 82 | 
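# Editor's sketch of how the sampler above distributes events (hypothetical
# helper, not used by the library): 2 dimensions with 2 stratifications per
# dimension give 4 hypercubes, and 3 events per hypercube give 12 events.
def _hypercube_sampling_sketch():
    n_dim = 2
    hcubes = int_me([[i, j] for i in range(2) for j in range(2)])
    n_ev = int_me([3, 3, 3, 3])  # events assigned to each of the 4 hypercubes
    rnds = tf.random.uniform((12, n_dim), dtype=DTYPE)
    # Trivial (untrained) vegas grid: uniform divisions in every dimension
    divisions = float_me(np.tile(np.linspace(0, 1, BINS_MAX + 1), (n_dim, 1)))
    x, w, ind, segm = generate_samples_in_hypercubes(
        rnds, int_me(2), n_ev, hcubes, divisions
    )
    assert x.shape == (12, n_dim)  # one (n_dim,) random point per event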
83 | class VegasFlowPlus(VegasFlow): 84 | """ 85 | Implementation of the VEGAS+ algorithm 86 | """ 87 | 88 | def __init__(self, n_dim, n_events, train=True, adaptive=False, events_limit=None, **kwargs): 89 | # https://github.com/N3PDF/vegasflow/issues/78 90 | if events_limit is None: 91 | logger.info("Events per device limit set to %d", n_events) 92 | events_limit = n_events 93 | elif events_limit < n_events: 94 | logger.warning( 95 | "VegasFlowPlus needs to hold all events in memory at once, " 96 | "setting the `events_limit` to be equal to `n_events=%d`", 97 | n_events, 98 | ) 99 | events_limit = n_events 100 | super().__init__(n_dim, n_events, train, events_limit=events_limit, **kwargs) 101 | 102 | # Save the initial number of events 103 | self._init_calls = n_events 104 | 105 | # Don't use adaptive if the number of dimensions is too big 106 | if n_dim > 13 and adaptive: 107 | self._adaptive = False 108 | logger.warning("Disabling adaptive mode from VegasFlowPlus, too many dimensions!") 109 | else: 110 | self._adaptive = adaptive 111 | 112 | # Initialize stratifications 113 | if self._adaptive: 114 | neval_eff = int(self.n_events / 2) 115 | else: 116 | neval_eff = self.n_events 117 | # The same stratification formula applies in both cases 118 | self._n_strat = tf.math.floor(tf.math.pow(neval_eff / 2, 1 / n_dim)) 119 | 120 | if tf.math.pow(self._n_strat, n_dim) > MAX_NEVAL_HCUBE: 121 | self._n_strat = tf.math.floor(tf.math.pow(1e4, 1 / n_dim)) 122 | 123 | self._n_strat = int_me(self._n_strat) 124 | 125 | # Initialize hypercubes 126 | hypercubes_one_dim = np.arange(0, int(self._n_strat)) 127 | hypercubes = [list(p) for p in product(hypercubes_one_dim, repeat=int(n_dim))] 128 | self._hypercubes = tf.convert_to_tensor(hypercubes, dtype=DTYPEINT) 129 | 130 | if len(hypercubes) != int(tf.math.pow(self._n_strat, n_dim)): 131 | raise ValueError("Hypercubes are not equal to n_strat^n_dim") 132 | 133 | self.min_neval_hcube = int(neval_eff // len(hypercubes)) 134 | self.min_neval_hcube = max(self.min_neval_hcube, 2) 135 | 136 | self.n_ev = tf.fill([1, len(hypercubes)], self.min_neval_hcube) 137 | self.n_ev = int_me(tf.reshape(self.n_ev, [-1])) 138 | self._n_events = int(tf.reduce_sum(self.n_ev)) 139 | self._modified_jac = float_me(1 / len(hypercubes)) 140 | 141 | if self._adaptive: 142 | logger.warning("Variable number of events requires function signatures all across") 143 | 144 | @property 145 | def xjac(self): 146 | return self._modified_jac 147 | 148 | def make_differentiable(self): 149 | """Overrides make_differentiable to make sure the runner has a reference to n_ev""" 150 | runner = super().make_differentiable() 151 | return partial(runner, n_ev=self.n_ev) 152 | 153 | def redistribute_samples(self, arr_var): 154 | """Receives an array with the variance of the integrand in each 155 | hypercube and recalculates the samples per hypercube according 156 | to the VEGAS+ algorithm""" 157 | damped_arr_sdev = tf.pow(arr_var, BETA / 2) 158 | new_n_ev = tf.maximum( 159 | self.min_neval_hcube, 160 | damped_arr_sdev * self._init_calls / 2 / tf.reduce_sum(damped_arr_sdev), 161 | ) 162 | self.n_ev = int_me(new_n_ev) 163 | self.n_events = int(tf.reduce_sum(self.n_ev)) 164 | 165 | def _digest_random_generation(self, rnds, n_ev): 166 | """Generate a random array for a given number of events divided in hypercubes""" 167 | # Get random numbers from hypercubes 168 | x, w, ind, segm = generate_samples_in_hypercubes( 169 | rnds, self._n_strat, n_ev, self._hypercubes, self.divisions 170 | ) 171 | return x, w, ind, segm 172 | 
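    # Editor's note: the stratified assignment in `n_ev` fixes the total number
    # of events per pass, so the override below draws whole passes of
    # `self.n_events` random points and then trims and reweights the output to
    # the `n_events` actually requested.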
173 | def generate_random_array(self, n_events, *args): 174 | """Override the behaviour of ``generate_random_array`` 175 | to accommodate the peculiarities of VegasFlowPlus 176 | """ 177 | rnds = [] 178 | wgts = [] 179 | for _ in range(n_events // self.n_events + 1): 180 | r, w = super().generate_random_array(self.n_events, self.n_ev) 181 | rnds.append(r) 182 | wgts.append(w) 183 | final_r = tf.concat(rnds, axis=0)[:n_events] 184 | final_w = tf.concat(wgts, axis=0)[:n_events] * self.n_events / n_events 185 | return final_r, final_w 186 | 187 | def _run_event(self, integrand, ncalls=None, n_ev=None): 188 | """Run one step of VegasFlowPlus 189 | Similar to the event step for importance sampling VegasFlow 190 | adding the n_ev argument for the segmentation into hypercubes 191 | n_ev is a tensor containing the number of samples per hypercube 192 | 193 | Parameters 194 | ---------- 195 | `integrand`: function to integrate 196 | `ncalls`: how many events to run in this step 197 | `n_ev`: number of samples per hypercube 198 | 199 | Returns 200 | ------- 201 | `res`: sum of the result of the integrand for all events per segment 202 | `res2`: sum of the result squared of the integrand for all events per segment 203 | `arr_res2`: result of the integrand squared per dimension and grid bin 204 | """ 205 | # NOTE: needs to receive both ncalls and n_ev 206 | x, xjac, ind, segm = self._generate_random_array(ncalls, n_ev) 207 | 208 | # compute integrand 209 | tmp = xjac * integrand(x, weight=xjac) 210 | tmp2 = tf.square(tmp) 211 | 212 | # tensors containing the summed components for each hypercube 213 | ress = tf.math.segment_sum(tmp, segm) 214 | ress2 = tf.math.segment_sum(tmp2, segm) 215 | 216 | fn_ev = float_me(n_ev) 217 | arr_var = ress2 * fn_ev - tf.square(ress) 218 | arr_res2 = self._importance_sampling_array_filling(tmp2, ind) 219 | 220 | return ress, arr_var, arr_res2 221 | 222 | def _iteration_content(self): 223 | """Steps to follow per iteration 224 | Differently from importance-sampling Vegas, the result of the integration 225 | is a result _per segment_ and thus the total result needs to be computed at this point 226 | """ 227 | ress, arr_var, arr_res2 = self.run_event(n_ev=self.n_ev) 228 | 229 | # Compute the error 230 | sigmas2 = tf.maximum(arr_var, fzero) 231 | res = tf.reduce_sum(ress) 232 | sigma2 = tf.reduce_sum(sigmas2 / (float_me(self.n_ev) - fone)) 233 | sigma = tf.sqrt(sigma2) 234 | 235 | # If adaptive is active, redistribute the samples 236 | if self._adaptive: 237 | self.redistribute_samples(arr_var) 238 | 239 | if self.train: 240 | self.refine_grid(arr_res2) 241 | 242 | return res, sigma 243 | 244 | def run_event(self, tensorize_events=None, **kwargs): 245 | """Tensorizes the number of events 246 | so they are not python or numpy primitives if self._adaptive=True""" 247 | return super().run_event(tensorize_events=self._adaptive, **kwargs) 248 | 249 | 250 | def vegasflowplus_wrapper(integrand, n_dim, n_iter, total_n_events, **kwargs): 251 | """Convenience wrapper 252 | 253 | Parameters 254 | ---------- 255 | `integrand`: tf.function 256 | `n_dim`: number of dimensions 257 | `n_iter`: number of iterations 258 | `total_n_events`: number of events per iteration 259 | 260 | Returns 261 | ------- 262 | `final_result`: integral value 263 | `sigma`: monte carlo error 264 | """ 265 | return wrapper(VegasFlowPlus, integrand, n_dim, n_iter, total_n_events, **kwargs) 266 | 
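# Editor's sketch of the wrapper above (illustrative integrand name, not part
# of the module): integrate a toy 3-dimensional integrand over 5 iterations of
# 1e5 events each.
#
#   from vegasflow import vegasflowplus_wrapper
#   result = vegasflowplus_wrapper(my_integrand, 3, 5, int(1e5))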
267 | 268 | def vegasflowplus_sampler(*args, **kwargs): 269 | """Convenience wrapper for sampling random numbers 270 | 271 | Parameters 272 | ---------- 273 | `integrand`: tf.function 274 | `n_dim`: number of dimensions 275 | `n_events`: number of events per iteration 276 | `training_steps`: number of training iterations 277 | 278 | Returns 279 | ------- 280 | `sampler`: a reference to the generate_random_array method of the integrator class 281 | """ 282 | return sampler(VegasFlowPlus, *args, **kwargs) 283 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/vegasflow/vflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the VegasFlow class and all its auxiliary functions 3 | 4 | The main interfaces of this module are the `VegasFlow` class and the 5 | `vegas_wrapper` function 6 | """ 7 | 8 | import json 9 | import logging 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from vegasflow.configflow import ( 15 | ALPHA, 16 | BINS_MAX, 17 | DTYPE, 18 | DTYPEINT, 19 | float_me, 20 | fone, 21 | fzero, 22 | int_me, 23 | ione, 24 | ) 25 | from vegasflow.monte_carlo import MonteCarloFlow, sampler, wrapper 26 | from vegasflow.utils import consume_array_into_indices 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | FBINS = float_me(BINS_MAX) 31 | 32 | 33 | @tf.function( 34 | input_signature=[ 35 | tf.TensorSpec(shape=[None, None], dtype=DTYPE), 36 | tf.TensorSpec(shape=[None, BINS_MAX + 1], dtype=DTYPE), 37 | ] 38 | ) 39 | def importance_sampling_digest(xn, divisions): 40 | """Common piece of the importance sampling algorithm. 41 | 42 | Receives a random array of shape (n_dim, n_events) 43 | indicating from which bins of the (n_dim, BINS_MAX+1) grid 44 | the random points have to be sampled 45 | 46 | This algorithm is shared between the simplest form of Vegas 47 | (VegasFlow: only importance sampling) 48 | and Vegas+ (VegasFlowPlus: importance and stratified sampling) 49 | and so it has been lifted to its own function 50 | 51 | Parameters 52 | ---------- 53 | xn: float tensor (n_dim, n_events) 54 | which bins to sample from 55 | divisions: float tensor (n_dim, BINS_MAX+1) 56 | grid of divisions for the importance sampling algorithm 57 | 58 | Returns 59 | ------- 60 | ind_i: integer tensor (n_events, n_dim) 61 | index in the divisions grid from which the points should be sampled 62 | x: float tensor (n_events, n_dim) 63 | random values sampled in the divisions grid 64 | weights: float tensor (n_events,) 65 | weight of the random points 66 | """ 67 | ind_i = int_me(xn) 68 | # Get the value of the left and right sides of the bins 69 | ind_f = ind_i + ione 70 | x_ini = tf.gather(divisions, ind_i, batch_dims=1) 71 | x_fin = tf.gather(divisions, ind_f, batch_dims=1) 72 | # Compute the width of the bins 73 | xdelta = x_fin - x_ini 74 | # Take the decimal part of bin (i.e., how deep within the bin) 75 | aux_rand = xn - tf.math.floor(xn) 76 | x = x_ini + xdelta * aux_rand 77 | # Compute the weight of the points 78 | weights = tf.reduce_prod(xdelta * FBINS, axis=0) 79 | 80 | # Transpose the output to be what the external functions expect 81 | x = tf.transpose(x) 82 | ind_i = tf.transpose(ind_i) 83 | return ind_i, x, weights 84 | 
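# Editor's note on the weight above: a point landing in bin i of width
# xdelta_i is sampled with probability density 1 / (BINS_MAX * xdelta_i) in
# each dimension, so prod(xdelta * FBINS) is the inverse of the sampling pdf,
# i.e. the jacobian of the vegas grid transformation.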
85 | 86 | # Auxiliary functions for Vegas 87 | @tf.function( 88 | input_signature=[ 89 | tf.TensorSpec(shape=[None, None], dtype=DTYPE), 90 | tf.TensorSpec(shape=[None, BINS_MAX + 1], dtype=DTYPE), 91 | ] 92 | ) 93 | def _generate_random_array(rnds, divisions): 94 | """ 95 | Generates the Vegas random array for any number of events 96 | 97 | Parameters 98 | ---------- 99 | rnds: array shaped (None, n_dim) 100 | Random numbers used as an input for Vegas 101 | divisions: array shaped (n_dim, BINS_MAX+1) 102 | vegas grid 103 | 104 | Returns 105 | ------- 106 | x: array (None, n_dim) 107 | Vegas random output 108 | w: array (None,) 109 | Weight of each set of (n_dim) random numbers 110 | div_index: array (None, n_dim) 111 | division index in which each (n_dim) set of random numbers falls 112 | """ 113 | # Get the boundaries of the random numbers 114 | # reg_i = fzero 115 | # reg_f = fone 116 | # Get the index of the division we are interested in 117 | xn = FBINS * (fone - tf.transpose(rnds)) 118 | 119 | # Compute the random number between 0 and 1 120 | # and the index of the bin where it has been sampled from 121 | ind_xn, x, weights = importance_sampling_digest(xn, divisions) 122 | 123 | # Compute the random number between the limits 124 | # commented, for now only from 0 to 1 125 | # x = reg_i + rand_x * (reg_f - reg_i) 126 | return x, weights, ind_xn 127 | 128 | 129 | @tf.function( 130 | input_signature=[ 131 | tf.TensorSpec(shape=[BINS_MAX], dtype=DTYPE), 132 | tf.TensorSpec(shape=[BINS_MAX + 1], dtype=DTYPE), 133 | ] 134 | ) 135 | def refine_grid_per_dimension(t_res_sq, subdivisions): 136 | """ 137 | Modifies the boundaries for the vegas grid for a given dimension 138 | 139 | Parameters 140 | ---------- 141 | `t_res_sq`: tensor 142 | array of results squared per bin 143 | `subdivisions`: tensor 144 | current boundaries for the grid 145 | 146 | Returns 147 | ------- 148 | `new_divisions`: tensor 149 | array with the new boundaries of the grid 150 | """ 151 | # Define some constants 152 | paddings = int_me([[1, 1]]) 153 | tmp_meaner = tf.fill([BINS_MAX - 2], float_me(3.0)) 154 | meaner = tf.pad(tmp_meaner, paddings, constant_values=2.0) 155 | # Pad the vector of results 156 | res_padded = tf.pad(t_res_sq, paddings) 157 | # First we need to smear out the array of results squared 158 | smeared_tensor_tmp = res_padded[1:-1] + res_padded[2:] + res_padded[:-2] 159 | smeared_tensor = tf.maximum(smeared_tensor_tmp / meaner, float_me(1e-30)) 160 | # Now we refine the grid according to 161 | # journal of comp phys, 27, 192-203 (1978) G.P. Lepage
162 | sum_t = tf.reduce_sum(smeared_tensor) 163 | log_t = tf.math.log(smeared_tensor) 164 | aux_t = (1.0 - smeared_tensor / sum_t) / (tf.math.log(sum_t) - log_t) 165 | wei_t = tf.pow(aux_t, ALPHA) 166 | ave_t = tf.reduce_sum(wei_t) / BINS_MAX 167 | 168 | ###### Auxiliary functions for the while loop 169 | @tf.function 170 | def while_check(bin_weight, *args): 171 | """Checks whether the bin has enough weight 172 | to beat the average""" 173 | return bin_weight < ave_t 174 | 175 | @tf.function( 176 | input_signature=[ 177 | tf.TensorSpec(shape=[], dtype=DTYPE), 178 | tf.TensorSpec(shape=[], dtype=DTYPEINT), 179 | tf.TensorSpec(shape=[], dtype=DTYPE), 180 | tf.TensorSpec(shape=[], dtype=DTYPE), 181 | ] 182 | ) 183 | def while_body(bin_weight, n_bin, cur, prev): 184 | """Accumulates the bin weight until it surpasses the average; 185 | once done, returns the limits of the last bin""" 186 | n_bin += 1 187 | bin_weight += wei_t[n_bin] 188 | prev = cur 189 | cur = subdivisions[n_bin + 1] 190 | return bin_weight, n_bin, cur, prev 191 | 192 | ########################### 193 | 194 | # And now resize all bins 195 | new_bins = [fzero] 196 | # Auxiliary variables 197 | bin_weight = fzero 198 | n_bin = -1 199 | cur = fzero 200 | prev = fzero 201 | for _ in range(BINS_MAX - 1): 202 | bin_weight, n_bin, cur, prev = tf.while_loop( 203 | while_check, while_body, (bin_weight, n_bin, cur, prev), parallel_iterations=1 204 | ) 205 | bin_weight -= ave_t 206 | delta = (cur - prev) * bin_weight / wei_t[n_bin] 207 | new_bins.append(cur - delta) 208 | new_bins.append(fone) 209 | 210 | new_divisions = tf.stack(new_bins) 211 | return new_divisions 212 | 213 | 214 | ####### VegasFlow 215 | class VegasFlow(MonteCarloFlow): 216 | """ 217 | Implementation of the importance sampling algorithm from Vegas. 218 | 219 | Parameters 220 | ---------- 221 | n_dim: int 222 | number of dimensions to be integrated 223 | n_events: int 224 | number of events per iteration 225 | train: bool 226 | whether to train the grid 227 | main_dimension: int 228 | in case of vectorial output, main dimension on which to train 229 | """ 230 | 231 | def __init__(self, n_dim, n_events, train=True, main_dimension=0, **kwargs): 232 | super().__init__(n_dim, n_events, **kwargs) 233 | 234 | # If training is True, the grid will be changed after every iteration 235 | # otherwise it will be frozen 236 | self.train = train 237 | 238 | # Initialize grid 239 | self.grid_bins = BINS_MAX + 1 240 | subdivision_np = np.linspace(0, 1, self.grid_bins) 241 | divisions_np = subdivision_np.repeat(n_dim).reshape(-1, n_dim).T 242 | self.divisions = tf.Variable(divisions_np, dtype=DTYPE) 243 | self._main_dimension = main_dimension 244 | 245 | def _can_run_vectorial(self, expected_shape): 246 | # only implemented for the main class at the moment, not for children 247 | if self._main_dimension >= expected_shape[-1]: 248 | raise ValueError( 249 | f"""The main dimension index ({self._main_dimension}) is greater than the dimensionality of the output ({expected_shape[-1]}). 
250 | Remember that arrays in python are 0-indexed!""" 251 | ) 252 | return self.__class__.__name__ == "VegasFlow" 253 | 254 | def make_differentiable(self): 255 | """Freeze the grid if the function is to be called within a graph""" 256 | if self.train: 257 | logger.warning("Freezing the grid") 258 | self.freeze_grid() 259 | return super().make_differentiable() 260 | 261 | def freeze_grid(self): 262 | """Stops the grid from refining any more""" 263 | self.train = False 264 | self._recompile() 265 | 266 | def unfreeze_grid(self): 267 | """Enable the refining of the grid""" 268 | self.train = True 269 | self._recompile() 270 | 271 | def save_grid(self, file_name): 272 | """Save the `divisions` array in a json file 273 | 274 | Parameters 275 | ---------- 276 | `file_name`: str 277 | Filename in which to save the checkpoint 278 | """ 279 | div_np = self.divisions.numpy() 280 | if self._integrand: 281 | int_name = self._integrand.__name__ 282 | else: 283 | int_name = "" 284 | json_dict = { 285 | "dimensions": self.n_dim, 286 | "ALPHA": ALPHA, 287 | "BINS": self.grid_bins, 288 | "integrand": int_name, 289 | "grid": div_np.tolist(), 290 | } 291 | with open(file_name, "w") as f: 292 | json.dump(json_dict, f, indent=True) 293 | 294 | def load_grid(self, file_name=None, numpy_grid=None): 295 | """Load the `divisions` array from a json file 296 | or from a numpy array 297 | 298 | Parameters 299 | ---------- 300 | `file_name`: str 301 | Filename in which the grid json is stored 302 | `numpy_grid`: np.array 303 | Numpy array to substitute divisions with 304 | """ 305 | if file_name is not None and numpy_grid is not None: 306 | raise ValueError( 307 | "Received both a numpy grid and a file_name to load the grid from. " 308 | "Ambiguous call to `load_grid`" 309 | ) 310 | 311 | # If it received a file, loads up the grid 312 | if file_name: 313 | with open(file_name, "r") as f: 314 | json_dict = json.load(f) 315 | # First check the parameters of the grid are unchanged 316 | grid_dim = json_dict.get("dimensions") 317 | grid_bins = json_dict.get("BINS") 318 | # Check that the integrand is the same one 319 | if self._integrand: 320 | integrand_name = self._integrand.__name__ 321 | integrand_grid = json_dict.get("integrand") 322 | if integrand_name != integrand_grid: 323 | logger.warning( 324 | f"The grid was written for the integrand {integrand_grid}, " 325 | f"which is different from {integrand_name}" 326 | ) 327 | # Now that everything is clear, let's load up the grid 328 | numpy_grid = np.array(json_dict["grid"]) 329 | elif numpy_grid is not None: 330 | grid_dim = numpy_grid.shape[0] 331 | grid_bins = numpy_grid.shape[1] 332 | else: 333 | raise ValueError("load_grid was called but no grid was provided!") 334 | # Check that the grid has the right dimensions 335 | if grid_dim is not None and self.n_dim != grid_dim: 336 | raise ValueError( 337 | f"Received a {grid_dim}-dimensional grid while VegasFlow " 338 | f"was instantiated with {self.n_dim} dimensions" 339 | ) 340 | if grid_bins is not None and self.grid_bins != grid_bins: 341 | raise ValueError( 342 | f"The received grid contains {grid_bins} bins while the " 343 | f"current setting is of {self.grid_bins} bins" 344 | ) 345 | if file_name: 346 | logger.info(f" > SUCCESS: Loaded grid from {file_name}") 347 | self.divisions.assign(numpy_grid) 348 | 
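    # Editor's sketch of the save/load round trip above (illustrative names,
    # mirroring the behaviour exercised in the test suite):
    #
    #   mc = VegasFlow(2, int(1e4))
    #   mc.compile(my_integrand)
    #   mc.run_integration(3)                 # adapt the grid
    #   mc.save_grid("grid.json")             # dimensions/BINS/grid as JSON
    #   mc2 = VegasFlow(2, int(1e4))
    #   mc2.load_grid(file_name="grid.json")  # restores `divisions`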
349 | def refine_grid(self, arr_res2): 350 | """Receives an array with the values of the integral squared per 351 | bin per dimension (`arr_res2.shape = (n_dim, self.grid_bins)`) 352 | and reshapes the `divisions` attribute accordingly 353 | 354 | Parameters 355 | ---------- 356 | `arr_res2`: result of the integrand squared per dimension and grid bin 357 | 358 | Function not compiled 359 | """ 360 | for j in range(self.n_dim): 361 | new_divisions = refine_grid_per_dimension(arr_res2[j, :], self.divisions[j, :]) 362 | self.divisions[j, :].assign(new_divisions) 363 | 364 | def _digest_random_generation(self, rnds): 365 | """Generates ``n_events`` random numbers sampled in the 366 | adapted Vegas Grid""" 367 | x, w, ind = _generate_random_array(rnds, self.divisions) 368 | return x, w, ind 369 | 370 | def _importance_sampling_array_filling(self, results2, indices): 371 | """Receives an array of results squared for every event 372 | and an array of indices describing in which bin each result falls. 373 | Fills an array with the total result in each bin to be used by 374 | the importance sampling algorithm 375 | """ 376 | if not self.train: 377 | return [] 378 | 379 | arr_res2 = [] 380 | # If the training is active, save the result of the integral sq 381 | for j in range(self.n_dim): 382 | arr_res2.append( 383 | consume_array_into_indices( 384 | results2, indices[:, j : j + 1], int_me(self.grid_bins - 1) 385 | ) 386 | ) 387 | return tf.reshape(arr_res2, (self.n_dim, -1)) 388 | 389 | def _run_event(self, integrand, ncalls=None): 390 | """Runs one step of Vegas. 391 | 392 | Parameters 393 | ---------- 394 | `integrand`: function to integrate 395 | `ncalls`: how many events to run in this step 396 | 397 | Returns 398 | ------- 399 | `res`: sum of the result of the integrand for all events 400 | `res2`: sum of the result squared of the integrand for all events 401 | `arr_res2`: result of the integrand squared per dimension and grid bin 402 | """ 403 | if ncalls is None: 404 | n_events = self.n_events 405 | else: 406 | n_events = ncalls 407 | 408 | # Generate all random numbers for this iteration 409 | x, xjac, ind = self._generate_random_array(n_events) 410 | 411 | # Now compute the integrand 412 | int_result = integrand(x, weight=xjac) 413 | 414 | if self._vectorial: 415 | xjac = tf.reshape(xjac, (-1, 1)) 416 | tmp = xjac * int_result 417 | tmp2 = tf.square(tmp) 418 | 419 | # Compute the final result for this step 420 | res = tf.reduce_sum(tmp, axis=0) 421 | res2 = tf.reduce_sum(tmp2, axis=0) 422 | 423 | # If this is a vectorial integrand, make sure that only the main dimension 424 | # is used for the grid training 425 | if self._vectorial: 426 | tmp2 = tmp2[:, self._main_dimension] 427 | 428 | arr_res2 = self._importance_sampling_array_filling(tmp2, ind) 429 | 430 | return res, res2, arr_res2 431 | 432 | def _iteration_content(self): 433 | """Steps to follow per iteration""" 434 | # Compute the result 435 | res, res2, arr_res2 = self.run_event() 436 | # Compute the error 437 | err_tmp2 = (self.n_events * res2 - tf.square(res)) / (self.n_events - fone) 438 | sigma = tf.sqrt(tf.maximum(err_tmp2, fzero)) 439 | # If training is active, act post integration 440 | if self.train: 441 | self.refine_grid(arr_res2) 442 | return res, sigma 443 | 444 | def _run_iteration(self): 445 | """Runs one iteration of the Vegas integrator""" 446 | return self._iteration_content() 447 | 448 | 449 | def vegas_wrapper(integrand, n_dim, n_iter, total_n_events, **kwargs): 450 | """Convenience wrapper 451 | 452 | Parameters 453 | ---------- 454 | `integrand`: tf.function 455 | `n_dim`: number of dimensions 456 | `n_iter`: number of iterations 457 | `total_n_events`: number of events per iteration 458 | 459 | Returns 460 | ------- 461 | `final_result`: 
integral value 462 | `sigma`: monte carlo error 463 | """ 464 | return wrapper(VegasFlow, integrand, n_dim, n_iter, total_n_events, **kwargs) 465 | 466 | 467 | def vegas_sampler(*args, **kwargs): 468 | """Convenience wrapper for sampling random numbers 469 | 470 | Parameters 471 | ---------- 472 | `integrand`: tf.function 473 | `n_dim`: number of dimensions 474 | `n_events`: number of events per iteration 475 | `training_steps`: number of training iterations 476 | 477 | Returns 478 | ------- 479 | `sampler`: a reference to the generate_random_array method of the integrator class 480 | """ 481 | return sampler(VegasFlow, *args, **kwargs) 482 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loaded into the active Python interpreter and may 5 | # run arbitrary code. 6 | extension-pkg-whitelist=numpy,tensorflow 7 | 8 | # Specify a score threshold to be exceeded before program exits with error. 9 | fail-under=10 10 | 11 | # Add files or directories to the blacklist. They should be base names, not 12 | # paths. 13 | ignore=CVS 14 | 15 | # Add files or directories matching the regex patterns to the blacklist. The 16 | # regex matches against base names, not paths. 17 | ignore-patterns= 18 | 19 | # Python code to execute, usually for sys.path manipulation such as 20 | # pygtk.require(). 21 | #init-hook= 22 | 23 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 24 | # number of processors available to use. 25 | jobs=2 26 | 27 | # Control the amount of potential inferred values when inferring a single 28 | # object. This can help the performance when dealing with large functions or 29 | # complex, nested conditions. 30 | limit-inference-results=100 31 | 32 | # List of plugins (as comma separated values of python module names) to load, 33 | # usually to register additional checkers. 34 | load-plugins= 35 | 36 | # Pickle collected data for later comparisons. 37 | persistent=yes 38 | 39 | # When enabled, pylint would attempt to guess common misconfiguration and emit 40 | # user-friendly hints instead of false-positive error messages. 41 | suggestion-mode=yes 42 | 43 | # Allow loading of arbitrary C extensions. Extensions are imported into the 44 | # active Python interpreter and may run arbitrary code. 45 | unsafe-load-any-extension=no 46 | 47 | 48 | [MESSAGES CONTROL] 49 | 50 | # Only show warnings with the listed confidence levels. Leave empty to show 51 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. 52 | confidence= 53 | 54 | # Disable the message, report, category or checker with the given id(s). You 55 | # can either give multiple identifiers separated by comma (,) or put this 56 | # option multiple times (only on the command line, not in the configuration 57 | # file where it should appear only once). You can also use "--disable=all" to 58 | # disable everything first and then reenable specific checks. For example, if 59 | # you want to run only the similarities checker, you can use "--disable=all 60 | # --enable=similarities". If you want to run only the classes checker, but have 61 | # no Warning level messages displayed, use "--disable=all --enable=classes 62 | # --disable=W". 
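# For example (editor's note, not part of the upstream template), to run only
# the similarities checker on this package from the command line:
#   pylint --disable=all --enable=similarities src/vegasflow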
63 | disable=print-statement, 64 | parameter-unpacking, 65 | unpacking-in-except, 66 | old-raise-syntax, 67 | backtick, 68 | long-suffix, 69 | old-ne-operator, 70 | old-octal-literal, 71 | import-star-module-level, 72 | non-ascii-bytes-literal, 73 | raw-checker-failed, 74 | bad-inline-option, 75 | locally-disabled, 76 | file-ignored, 77 | suppressed-message, 78 | useless-suppression, 79 | deprecated-pragma, 80 | use-symbolic-message-instead, 81 | apply-builtin, 82 | basestring-builtin, 83 | buffer-builtin, 84 | cmp-builtin, 85 | coerce-builtin, 86 | execfile-builtin, 87 | file-builtin, 88 | long-builtin, 89 | raw_input-builtin, 90 | reduce-builtin, 91 | standarderror-builtin, 92 | unicode-builtin, 93 | xrange-builtin, 94 | coerce-method, 95 | delslice-method, 96 | getslice-method, 97 | setslice-method, 98 | no-absolute-import, 99 | old-division, 100 | dict-iter-method, 101 | dict-view-method, 102 | next-method-called, 103 | metaclass-assignment, 104 | indexing-exception, 105 | raising-string, 106 | reload-builtin, 107 | oct-method, 108 | hex-method, 109 | nonzero-method, 110 | cmp-method, 111 | input-builtin, 112 | round-builtin, 113 | intern-builtin, 114 | unichr-builtin, 115 | map-builtin-not-iterating, 116 | zip-builtin-not-iterating, 117 | range-builtin-not-iterating, 118 | filter-builtin-not-iterating, 119 | using-cmp-argument, 120 | eq-without-hash, 121 | div-method, 122 | idiv-method, 123 | rdiv-method, 124 | exception-message-attribute, 125 | invalid-str-codec, 126 | sys-max-int, 127 | bad-python3-import, 128 | deprecated-string-function, 129 | deprecated-str-translate-call, 130 | invalid-name, 131 | too-few-public-methods, 132 | deprecated-itertools-function, 133 | deprecated-types-field, 134 | next-method-defined, 135 | dict-items-not-iterating, 136 | dict-keys-not-iterating, 137 | dict-values-not-iterating, 138 | deprecated-operator-function, 139 | deprecated-urllib-function, 140 | xreadlines-attribute, 141 | deprecated-sys-function, 142 | exception-escape, 143 | comprehension-escape, 144 | E1123, # pylint is not able to deal with tensorflow 145 | E1120, # same as above 146 | C0330, # black indentation when breaking long lines is better 147 | 148 | 149 | 150 | # Enable the message, report, category or checker with the given id(s). You can 151 | # either give multiple identifiers separated by comma (,) or put this option 152 | # multiple times (only on the command line, not in the configuration file where 153 | # it should appear only once). See also the "--disable" option for examples. 154 | enable=c-extension-no-member 155 | 156 | 157 | [REPORTS] 158 | 159 | # Python expression which should return a score less than or equal to 10. You 160 | # have access to the variables 'error', 'warning', 'refactor', and 'convention' 161 | # which contain the number of messages in each category, as well as 'statement' 162 | # which is the total number of statements analyzed. This score is used by the 163 | # global evaluation report (RP0004). 164 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 165 | 166 | # Template used to display messages. This is a python new-style format string 167 | # used to format the message information. See doc for all details. 168 | #msg-template= 169 | 170 | # Set the output format. Available formats are text, parseable, colorized, json 171 | # and msvs (visual studio). You can also give a reporter class, e.g. 172 | # mypackage.mymodule.MyReporterClass. 
173 | output-format=text 174 | 175 | # Tells whether to display a full report or only the messages. 176 | reports=no 177 | 178 | # Activate the evaluation score. 179 | score=yes 180 | 181 | 182 | [REFACTORING] 183 | 184 | # Maximum number of nested blocks for function / method body 185 | max-nested-blocks=5 186 | 187 | # Complete name of functions that never returns. When checking for 188 | # inconsistent-return-statements if a never returning function is called then 189 | # it will be considered as an explicit return statement and no message will be 190 | # printed. 191 | never-returning-functions=sys.exit 192 | 193 | 194 | [STRING] 195 | 196 | # This flag controls whether inconsistent-quotes generates a warning when the 197 | # character used as a quote delimiter is used inconsistently within a module. 198 | check-quote-consistency=no 199 | 200 | # This flag controls whether the implicit-str-concat should generate a warning 201 | # on implicit string concatenation in sequences defined over several lines. 202 | check-str-concat-over-line-jumps=no 203 | 204 | 205 | [TYPECHECK] 206 | 207 | # List of decorators that produce context managers, such as 208 | # contextlib.contextmanager. Add to this list to register other decorators that 209 | # produce valid context managers. 210 | contextmanager-decorators=contextlib.contextmanager 211 | 212 | # List of members which are set dynamically and missed by pylint inference 213 | # system, and so shouldn't trigger E1101 when accessed. Python regular 214 | # expressions are accepted. 215 | generated-members= 216 | 217 | # Tells whether missing members accessed in mixin class should be ignored. A 218 | # mixin class is detected if its name ends with "mixin" (case insensitive). 219 | ignore-mixin-members=yes 220 | 221 | # Tells whether to warn about missing members when the owner of the attribute 222 | # is inferred to be None. 223 | ignore-none=yes 224 | 225 | # This flag controls whether pylint should warn about no-member and similar 226 | # checks whenever an opaque object is returned when inferring. The inference 227 | # can return multiple potential results while evaluating a Python object, but 228 | # some branches might not be evaluated, which results in partial inference. In 229 | # that case, it might be useful to still emit no-member and other checks for 230 | # the rest of the inferred objects. 231 | ignore-on-opaque-inference=yes 232 | 233 | # List of class names for which member attributes should not be checked (useful 234 | # for classes with dynamically set attributes). This supports the use of 235 | # qualified names. 236 | ignored-classes=optparse.Values,thread._local,_thread._local 237 | 238 | # List of module names for which member attributes should not be checked 239 | # (useful for modules/projects where namespaces are manipulated during runtime 240 | # and thus existing member attributes cannot be deduced by static analysis). It 241 | # supports qualified module names, as well as Unix pattern matching. 242 | ignored-modules=tensorflow 243 | 244 | # Show a hint with possible names when a member name was not found. The aspect 245 | # of finding the hint is based on edit distance. 246 | missing-member-hint=yes 247 | 248 | # The minimum edit distance a name should have in order to be considered a 249 | # similar match for a missing member name. 250 | missing-member-hint-distance=1 251 | 252 | # The total number of similar names that should be taken in consideration when 253 | # showing a hint for a missing member. 
254 | missing-member-max-choices=1 255 | 256 | # List of decorators that change the signature of a decorated function. 257 | signature-mutators= 258 | 259 | 260 | [BASIC] 261 | 262 | # Naming hint for argument names 263 | argument-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 264 | 265 | # Regular expression matching correct argument names 266 | argument-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 267 | 268 | # Naming hint for attribute names 269 | attr-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 270 | 271 | # Regular expression matching correct attribute names 272 | attr-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 273 | 274 | # Bad variable names which should always be refused, separated by a comma 275 | bad-names=foo,bar,baz,toto,tutu,tata 276 | 277 | # Naming hint for class attribute names 278 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 279 | 280 | # Regular expression matching correct class attribute names 281 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 282 | 283 | # Naming hint for class names 284 | class-name-hint=[A-Z_][a-zA-Z0-9]+$ 285 | 286 | # Regular expression matching correct class names 287 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 288 | 289 | # Naming hint for constant names 290 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 291 | 292 | # Regular expression matching correct constant names 293 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 294 | 295 | # Minimum line length for functions/classes that require docstrings, shorter 296 | # ones are exempt. 297 | docstring-min-length=-1 298 | 299 | # Naming hint for function names 300 | function-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 301 | 302 | # Regular expression matching correct function names 303 | function-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 304 | 305 | # Good variable names which should always be accepted, separated by a comma 306 | good-names=i,j,k,ex,Run,_ 307 | 308 | # Good variable names regexes, separated by a comma. If names match any regex, 309 | # they will always be accepted 310 | good-names-rgxs= 311 | 312 | # Include a hint for the correct naming format with invalid-name. 313 | include-naming-hint=no 314 | 315 | # Naming hint for inline iteration names 316 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ 317 | 318 | # Regular expression matching correct inline iteration names 319 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 320 | 321 | # Naming hint for method names 322 | method-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 323 | 324 | # Regular expression matching correct method names 325 | method-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$ 326 | 327 | # Naming hint for module names 328 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 329 | 330 | # Regular expression matching correct module names 331 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 332 | 333 | # Colon-delimited sets of names that determine each other's naming style when 334 | # the name regexes allow several styles. 335 | name-group= 336 | 337 | # Regular expression which should only match function or class names that do 338 | # not require a docstring. 339 | no-docstring-rgx=^_ 340 | 341 | # List of decorators that produce properties, such as abc.abstractproperty. Add 342 | # to this list to register other decorators that produce valid properties. 343 | # These decorators are taken in consideration only for invalid-name. 
344 | property-classes=abc.abstractproperty
345 |
346 | # Naming hint for variable names
347 | variable-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$
348 |
349 | # Regular expression matching correct variable names
350 | variable-rgx=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$
351 |
352 |
353 | [VARIABLES]
354 |
355 | # List of additional names supposed to be defined in builtins. Remember that
356 | # you should avoid defining new builtins when possible.
357 | additional-builtins=
358 |
359 | # Tells whether unused global variables should be treated as a violation.
360 | allow-global-unused-variables=yes
361 |
362 | # List of strings which can identify a callback function by name. A callback
363 | # name must start or end with one of those strings.
364 | callbacks=cb_,_cb
365 |
366 | # A regular expression matching the name of dummy variables (i.e. expected to
367 | # not be used).
368 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
369 |
370 | # Argument names that match this expression will be ignored. Default to name
371 | # with leading underscore.
372 | ignored-argument-names=_.*|^ignored_|^unused_
373 |
374 | # Tells whether we should check for unused import in __init__ files.
375 | init-import=no
376 |
377 | # List of qualified module names which can have objects that can redefine
378 | # builtins.
379 | redefining-builtins-modules=six.moves,future.builtins
380 |
381 |
382 | [FORMAT]
383 |
384 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
385 | expected-line-ending-format=
386 |
387 | # Regexp for a line that is allowed to be longer than the limit.
388 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
389 |
390 | # Number of spaces of indent required inside a hanging or continued line.
391 | indent-after-paren=4
392 |
393 | # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
394 | # tab).
395 | indent-string='    '
396 |
397 | # Maximum number of characters on a single line.
398 | max-line-length=100
399 |
400 | # Maximum number of lines in a module.
401 | max-module-lines=1000
402 |
403 | # List of optional constructs for which whitespace checking is disabled. `dict-
404 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
405 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
406 | # `empty-line` allows space-only lines.
407 | no-space-check=trailing-comma,dict-separator
408 |
409 | # Allow the body of a class to be on the same line as the declaration if body
410 | # contains single statement.
411 | single-line-class-stmt=no
412 |
413 | # Allow the body of an if to be on the same line as the test if there is no
414 | # else.
415 | single-line-if-stmt=no
416 |
417 |
418 | [MISCELLANEOUS]
419 |
420 | # List of note tags to take in consideration, separated by a comma.
421 | notes=FIXME,
422 |       XXX,
423 |       TODO
424 |
425 | # Regular expression of note tags to take in consideration.
426 | #notes-rgx=
427 |
428 |
429 | [SIMILARITIES]
430 |
431 | # Ignore comments when computing similarities.
432 | ignore-comments=yes
433 |
434 | # Ignore docstrings when computing similarities.
435 | ignore-docstrings=yes
436 |
437 | # Ignore imports when computing similarities.
438 | ignore-imports=no
439 |
440 | # Minimum lines number of a similarity.
441 | min-similarity-lines=4
442 |
443 |
444 | [SPELLING]
445 |
446 | # Limits count of emitted suggestions for spelling mistakes.
447 | max-spelling-suggestions=4
448 |
449 | # Spelling dictionary name. Available dictionaries: none.
To make it work, 450 | # install the python-enchant package. 451 | spelling-dict= 452 | 453 | # List of comma separated words that should not be checked. 454 | spelling-ignore-words= 455 | 456 | # A path to a file that contains the private dictionary; one word per line. 457 | spelling-private-dict-file= 458 | 459 | # Tells whether to store unknown words to the private dictionary (see the 460 | # --spelling-private-dict-file option) instead of raising a message. 461 | spelling-store-unknown-words=no 462 | 463 | 464 | [IMPORTS] 465 | 466 | # List of modules that can be imported at any level, not just the top level 467 | # one. 468 | allow-any-import-level= 469 | 470 | # Allow wildcard imports from modules that define __all__. 471 | allow-wildcard-with-all=no 472 | 473 | # Analyse import fallback blocks. This can be used to support both Python 2 and 474 | # 3 compatible code, which means that the block might have code that exists 475 | # only in one or another interpreter, leading to false positives when analysed. 476 | analyse-fallback-blocks=no 477 | 478 | # Deprecated modules which should not be used, separated by a comma. 479 | deprecated-modules=optparse,tkinter.tix 480 | 481 | # Create a graph of external dependencies in the given file (report RP0402 must 482 | # not be disabled). 483 | ext-import-graph= 484 | 485 | # Create a graph of every (i.e. internal and external) dependencies in the 486 | # given file (report RP0402 must not be disabled). 487 | import-graph= 488 | 489 | # Create a graph of internal dependencies in the given file (report RP0402 must 490 | # not be disabled). 491 | int-import-graph= 492 | 493 | # Force import order to recognize a module as part of the standard 494 | # compatibility libraries. 495 | known-standard-library= 496 | 497 | # Force import order to recognize a module as part of a third party library. 498 | known-third-party=enchant 499 | 500 | # Couples of modules and preferred modules, separated by a comma. 501 | preferred-modules= 502 | 503 | 504 | [DESIGN] 505 | 506 | # Maximum number of arguments for function / method. 507 | max-args=5 508 | 509 | # Maximum number of attributes for a class (see R0902). 510 | max-attributes=7 511 | 512 | # Maximum number of boolean expressions in an if statement (see R0916). 513 | max-bool-expr=5 514 | 515 | # Maximum number of branch for function / method body. 516 | max-branches=12 517 | 518 | # Maximum number of locals for function / method body 519 | max-locals=22 520 | 521 | # Maximum number of parents for a class (see R0901). 522 | max-parents=7 523 | 524 | # Maximum number of public methods for a class (see R0904). 525 | max-public-methods=20 526 | 527 | # Maximum number of return / yield for function / method body. 528 | max-returns=6 529 | 530 | # Maximum number of statements in function / method body. 531 | max-statements=50 532 | 533 | # Minimum number of public methods for a class (see R0903). 534 | min-public-methods=2 535 | 536 | 537 | [CLASSES] 538 | 539 | # List of method names used to declare (i.e. assign) instance attributes. 540 | defining-attr-methods=__init__, 541 | __new__, 542 | setUp, 543 | __post_init__ 544 | 545 | # List of member names, which should be excluded from the protected access 546 | # warning. 547 | exclude-protected=_asdict, 548 | _fields, 549 | _replace, 550 | _source, 551 | _make 552 | 553 | # List of valid names for the first argument in a class method. 554 | valid-classmethod-first-arg=cls 555 | 556 | # List of valid names for the first argument in a metaclass class method. 
557 | valid-metaclass-classmethod-first-arg=cls
558 |
559 |
560 | [EXCEPTIONS]
561 |
562 | # Exceptions that will emit a warning when being caught. Defaults to
563 | # "BaseException, Exception".
564 | overgeneral-exceptions=BaseException,
565 |                        Exception
566 |
--------------------------------------------------------------------------------
/doc/source/how_to.rst:
--------------------------------------------------------------------------------
1 | .. _howto-label:
2 |
3 | =====
4 | Usage
5 | =====
6 |
7 | ``VegasFlow`` is a python library that provides a number of tools to perform Monte Carlo integration of functions.
8 | In this guide we do our best to explain the steps to follow in order to perform a successful calculation with ``VegasFlow``.
9 | If, after reading this, you have any doubts, questions (or ideas for
10 | improvements!) please don't hesitate to contact us by `opening an issue on GitHub
11 | <https://github.com/N3PDF/vegasflow/issues>`_.
12 |
13 |
14 | .. contents::
15 |    :local:
16 |    :depth: 1
17 |
18 |
19 | Integrating with VegasFlow
20 | ==========================
21 |
22 | Basic usage
23 | ^^^^^^^^^^^
24 |
25 | Integrating a function with ``VegasFlow`` is done in three basic steps:
26 |
27 | 1. **Instantiating an integrator**: At the time of instantiation it is necessary to provide
28 |    a number of dimensions and a number of calls per iteration.
29 |    The reason for giving this information beforehand is to allow for optimization.
30 |
31 | .. code-block:: python
32 |
33 |     from vegasflow import VegasFlow
34 |
35 |     dims = 3
36 |     n_calls = int(1e7)
37 |     vegas_instance = VegasFlow(dims, n_calls)
38 |
39 | 2. **Compiling the integrand**: The integrand needs to be given to the integrator for compilation.
40 |    Compilation serves a dual purpose: it first registers the integrand and then compiles it
41 |    using the ``tf.function`` decorator.
42 |
43 | .. code-block:: python
44 |
45 |     import tensorflow as tf
46 |
47 |     def example_integrand(xarr, weight=None):
48 |         s = tf.reduce_sum(xarr, axis=1)
49 |         result = tf.pow(0.1/s, 2)
50 |         return result
51 |
52 |     vegas_instance.compile(example_integrand)
53 |
54 | 3. **Running the integration**: Once everything is in place, we just need to inform the integrator of the number of
55 |    iterations we want.
56 |
57 | .. code-block:: python
58 |
59 |     n_iter = 5
60 |     result = vegas_instance.run_integration(n_iter)
61 |
62 |
63 | Constructing the integrand
64 | ^^^^^^^^^^^^^^^^^^^^^^^^^^
65 | Constructing an integrand for ``VegasFlow`` is similar to constructing an integrand for any other algorithm, with a small difference:
66 | the output of the integrand should be a tensor of results instead of just one number.
67 | While most integration algorithms will take a function and then evaluate said function ``n`` times (to calculate ``n`` events),
68 | ``VegasFlow`` takes the approach of evaluating as many events as possible at once.
69 | As such, the input random array (``xarr``) is a tensor of shape ``(n_events, n_dim)`` instead of the usual ``(n_dim,)``
70 | and, correspondingly, the output result is not a scalar but rather a tensor of shape ``(n_events,)``.
71 |
72 | Note that the ``example_integrand`` contains only ``TensorFlow`` functions and methods and operations between ``TensorFlow`` variables:
73 |
74 | .. code-block:: python
75 |
76 |     def example_integrand(xarr, weight=None):
77 |         s = tf.reduce_sum(xarr, axis=1)
78 |         result = tf.pow(0.1/s, 2)
79 |         return result
80 |
81 |
82 | By making the ``VegasFlow`` integrand depend only on python and ``TensorFlow`` primitives, the code can be understood by
83 | ``TensorFlow`` and compiled to run on CPU, GPU or other hardware accelerators,
84 | as well as optimized based on `XLA <https://www.tensorflow.org/xla>`_.
85 |
86 | It is possible, however (and often useful when prototyping), to integrate functions not
87 | based on ``TensorFlow`` by passing the ``compilable`` flag at compile time.
88 | This will spare the compilation of the integrand (while maintaining the compilation of
89 | the integration algorithm).
90 |
91 | .. code-block:: python
92 |
93 |     import numpy as np
94 |
95 |     def example_integrand(xarr, weight=None):
96 |         s = np.sum(xarr, axis=1)
97 |         result = np.square(0.1/s)
98 |         return result
99 |
100 |     vegas_instance.compile(example_integrand, compilable=False)
101 |
102 | .. note:: Integrands must always accept as first argument the random numbers (``xarr``)
103 |    and can also accept the keyword argument ``weight``. The ``compile`` method of the integrator
104 |    will try to find the most adequate signature in each situation.
105 |
106 |
107 | It is also possible to completely avoid compilation
108 | by leveraging ``TensorFlow``'s `eager execution <https://www.tensorflow.org/guide/eager>`_, as
109 | explained at :ref:`eager-label`.
110 |
111 | Integrating vector functions
112 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
113 |
114 | It is also possible to integrate vector-valued functions with most algorithms included in ``VegasFlow`` by simply modifying
115 | the integrand to return a vector of values per event instead of a scalar (in other words, the output shape of the result
116 | should be ``(n_events, n_outputs)``).
117 |
118 | .. code-block:: python
119 |
120 |     @tf.function
121 |     def test_function(xarr):
122 |         res = tf.square((xarr - 1.0) ** 2)
123 |         return tf.exp(-res)
124 |
125 |
126 | For adaptive algorithms, however, only one of the output dimensions is taken into account to adapt the grid
127 | (by default it will be the first output).
128 | In ``VegasFlow`` it is possible to modify this behaviour with the ``main_dimension`` keyword argument.
129 |
130 |
131 | .. code-block:: python
132 |
133 |     vegas = VegasFlow(dim, ncalls, main_dimension=1)
134 |
135 |
136 | ``VegasFlow`` will automatically (by evaluating the integrand with a small number of events) try to
137 | discover whether the function is vector-valued and will check a) whether the algorithm can integrate vector-valued integrals
138 | and b) whether the ``main_dimension`` index is contained in the dimensionality of the output.
139 |
140 |
141 | .. note:: Remember that python lists and arrays are 0-indexed and, as such, for an output with 2 components the index of the last dimension is 1 and not 2!
142 |
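Putting the pieces above together, a minimal end-to-end sketch could look as follows; the two-output integrand and the numbers used here are purely illustrative:

.. code-block:: python

    import tensorflow as tf
    from vegasflow import VegasFlow

    dims = 3
    n_calls = int(1e5)

    def vector_integrand(xarr, weight=None):
        # Two outputs per event: output shape is (n_events, 2)
        first = tf.reduce_sum(xarr, axis=1)
        second = tf.reduce_prod(xarr, axis=1)
        return tf.stack([first, second], axis=1)

    # Adapt the grid following the second output (index 1)
    vegas_instance = VegasFlow(dims, n_calls, main_dimension=1)
    vegas_instance.compile(vector_integrand)
    result = vegas_instance.run_integration(5)
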
143 |
144 | Choosing the correct types
145 | ^^^^^^^^^^^^^^^^^^^^^^^^^^
146 |
147 | A common pitfall when writing ``TensorFlow``-compilable integrands is to mix different precision types.
148 | If a function is compiled with a 32-bit float input, not only will it not work when called with a 64-bit
149 | float, it will fail catastrophically.
150 | The types in ``VegasFlow`` can be controlled via :ref:`environ-label`, but we also provide the
151 | ``float_me`` and ``int_me`` functions in order to ensure that all variables in the program have consistent
152 | types.
153 |
154 | These functions are wrappers around `tf.cast <https://www.tensorflow.org/api_docs/python/tf/cast>`__.
155 |
156 | .. code-block:: python
157 |
158 |     from vegasflow import float_me, int_me
159 |     import tensorflow as tf
160 |
161 |     constant = float_me(0.1)
162 |
163 |     def example_integrand(xarr, weight=None):
164 |         s = tf.reduce_sum(xarr, axis=1)
165 |         result = tf.pow(constant/s, 2)
166 |         return result
167 |
168 |     vegas_instance.compile(example_integrand)
169 |
170 |
171 |
172 | Integration wrappers
173 | ^^^^^^^^^^^^^^^^^^^^
174 |
175 | Although manually instantiating the integrator allows for finer-grained control
176 | of the integration, it is also possible to use wrappers which automatically do most of the work
177 | behind the scenes.
178 |
179 | .. code-block:: python
180 |
181 |     from vegasflow import vegas_wrapper
182 |
183 |     result = vegas_wrapper(example_integrand, dims, n_iter, n_calls, compilable=False)
184 |
185 |
186 | The full list of integration algorithms and wrappers can be consulted at :ref:`intalg-label`.
187 |
188 |
189 | Tips and Tricks
190 | ===============
191 |
192 | Changing the integration limits
193 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
194 |
195 | By default ``VegasFlow`` provides random numbers only in the 0 to 1 range (and so all integrals are expected to be integrals from 0 to 1).
196 | It is possible, however, to choose other ranges by passing the ``xmin`` and ``xmax`` variables to the initializer of the algorithm.
197 |
198 | Note that if any limit is to be changed, all ``xmin`` and ``xmax`` values must be provided:
199 |
200 | .. code-block:: python
201 |
202 |     from vegasflow import VegasFlow
203 |
204 |     dimensions = 2
205 |     vegas_instance = VegasFlow(dimensions, n_calls, xmin=[0, -4], xmax=[1, 10])
206 |
207 |
208 | Seeding the random number generator
209 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
210 |
211 | Seeding operations in ``TensorFlow`` is not always trivial.
212 | We include in all integrators the method ``set_seed``, which is a wrapper around
213 | ``TensorFlow``'s own `seed method <https://www.tensorflow.org/api_docs/python/tf/random/set_seed>`_.
214 |
215 | .. code-block:: python
216 |
217 |     from vegasflow import VegasFlow
218 |
219 |     vegas_instance = VegasFlow(dimensions, n_calls)
220 |     vegas_instance.set_seed(7)
221 |
222 |
223 | This is equivalent to:
224 |
225 | .. code-block:: python
226 |
227 |     from vegasflow import VegasFlow
228 |     import tensorflow as tf
229 |
230 |     vegas_instance = VegasFlow(dimensions, n_calls)
231 |     tf.random.set_seed(7)
232 |
233 |
234 | This seed is what ``TensorFlow`` calls a global seed and is then used to generate operation-level seeds.
235 | In graph mode (see :ref:`eager-label`) all top level ``tf.function``-compiled functions branch out
236 | of the same initial state.
237 | As a consequence, if we were to run two separate instances of ``VegasFlow``,
238 | despite running sequentially, they would both run with the same seed.
239 | Note that this only occurs if the seed is manually set.
240 |
241 | .. code-block:: python
242 |
243 |     from vegasflow import vegas_wrapper
244 |     import tensorflow as tf
245 |
246 |     tf.random.set_seed(7)
247 |     result_1 = vegas_wrapper(example_integrand, dims, n_iter, n_calls)
248 |     result_2 = vegas_wrapper(example_integrand, dims, n_iter, n_calls)
249 |     assert result_1 == result_2
250 |
251 |
252 | The way ``TensorFlow`` seeding works can be consulted `here <https://www.tensorflow.org/api_docs/python/tf/random/set_seed>`_.
253 |
254 | .. note:: Even when using a seed, reproducibility is not guaranteed between two different versions of TensorFlow.
255 |
256 |
257 | Constructing differentiable and compilable integrations
258 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
259 |
260 | An interface to generate integration callables that can be used inside other TensorFlow code (for instance, inside a Neural Network)
261 | is provided through the ``make_differentiable`` method.
262 | This method will make the necessary changes to the integration, mainly
263 | freezing the grid and ensuring that only one device is used,
264 | and it returns a callable that can be used as just another TensorFlow function.
265 |
266 | In the following example, we generate a function to be integrated
267 | (which can depend on external input through the mutable variable ``z``).
268 | Afterwards, the function is compiled (and trained) as a normal integrand,
269 | until we call ``make_differentiable``.
270 | At that point the grid is frozen and a ``runner`` is returned, which will
271 | run the integration and return its result.
272 | The ``runner`` can now be used inside a ``tf.function``-compiled function
273 | and gradients can be computed as shown below.
274 |
275 |
276 | .. code-block:: python
277 |
278 |     from vegasflow import VegasFlow, float_me
279 |     import tensorflow as tf
280 |
281 |     dims = 4
282 |     n_calls = int(1e4)
283 |     vegas_instance = VegasFlow(dims, n_calls, verbose=False)
284 |     z = tf.Variable(float_me(1.0))
285 |
286 |     def example_integrand(x, **kwargs):
287 |         y = tf.reduce_sum(x, axis=1)
288 |         return y*z
289 |
290 |     vegas_instance.compile(example_integrand)
291 |     # Now we run a few iterations to train the grid; their results can be discarded
292 |     _ = vegas_instance.run_integration(3)
293 |
294 |     runner = vegas_instance.make_differentiable()
295 |
296 |     @tf.function
297 |     def some_complicated_function(x):
298 |         integration_result, error, _ = runner()
299 |         return x*integration_result
300 |
301 |     my_x = float_me(4.0)
302 |     result = some_complicated_function(my_x)
303 |
304 |     def compute_and_print_gradient():
305 |         with tf.GradientTape() as tape:
306 |             tape.watch(my_x)
307 |             y = some_complicated_function(my_x)
308 |
309 |         grad = tape.gradient(y, my_x)
310 |         print(f"Result {y.numpy():.3}, gradient: {grad.numpy():.3}")
311 |
312 |     compute_and_print_gradient()
313 |     z.assign(float_me(4.0))
314 |     compute_and_print_gradient()
315 |
316 |
317 | Just In Time (jit) compilation
318 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
319 |
320 | When compiling very large functions, ``tensorflow`` might take much more memory than expected.
321 | In those cases, and for recent versions of ``tensorflow``, it might be beneficial to compile
322 | the function using XLA, which you can enable with the optional argument ``jit_compile=True`` (`see the tensorflow docs <https://www.tensorflow.org/api_docs/python/tf/function>`_).
323 |
324 | .. code-block:: python
325 |
326 |     @tf.function(jit_compile=True)
327 |     def some_complicated_function(x):
328 |         integration_result, error, _ = runner()
329 |         return x*integration_result
330 |
331 |
332 | Running in distributed systems
333 | ==============================
334 |
335 | ``vegasflow`` implements an easy interface to distributed systems via
336 | the `dask <https://dask.org/>`_ library.
337 | In order to enable it, it is enough to call the ``set_distribute`` method
338 | of the instantiated integrator class.
339 | This method takes a `dask_jobqueue <https://jobqueue.dask.org/>`_ cluster
340 | to send the jobs to.
341 |
342 | A full example can be found in the ``examples/cluster_dask.py`` file, where
343 | a `SLURM <https://slurm.schedmd.com/>`_ cluster is used as an example.
344 |
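A minimal sketch, assuming a SLURM queue is available, could look like the following; the queue name, the resources requested and the number of jobs are illustrative placeholders to be adapted to your own system:

.. code-block:: python

    import tensorflow as tf
    from dask_jobqueue import SLURMCluster
    from vegasflow import VegasFlow

    def example_integrand(xarr, weight=None):
        return tf.reduce_sum(xarr, axis=1)

    dims = 3
    n_calls = int(1e6)

    # Illustrative resources; adapt them to your own queue
    cluster = SLURMCluster(queue="regular", cores=4, memory="8GiB")
    cluster.scale(jobs=2)

    vegas_instance = VegasFlow(dims, n_calls)
    vegas_instance.compile(example_integrand)
    vegas_instance.set_distribute(cluster)
    result = vegas_instance.run_integration(5)
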
345 | .. note:: When the distributed capabilities of dask are being used, ``VegasFlow`` "forfeits" control of the devices on which to run, trusting ``TensorFlow``'s defaults. To use, for instance, two GPUs in one single node while using dask, the user should send two separate dask jobs, each targeting a different GPU.
346 |
347 |
348 |
349 | Global configuration
350 | ====================
351 |
352 | Verbosity
353 | ^^^^^^^^^
354 |
355 | ``VegasFlow`` uses the internal logging capabilities of python by
356 | creating a new logger handle named ``vegasflow``.
357 | As with any sane python library, you can modify the behavior of the logger with the following lines:
358 |
359 | .. code-block:: python
360 |
361 |     import logging
362 |
363 |     log_dict = {
364 |         "0" : logging.ERROR,
365 |         "1" : logging.WARNING,
366 |         "2" : logging.INFO,
367 |         "3" : logging.DEBUG
368 |     }
369 |     logger_vegasflow = logging.getLogger('vegasflow')
370 |     logger_vegasflow.setLevel(log_dict["0"])
371 |
372 | where the log level can be any level defined in the ``log_dict`` dictionary.
373 |
374 | Since ``VegasFlow`` is meant to be interfaced with non-python code, it is also
375 | possible to control the behaviour through the environment variable ``VEGASFLOW_LOG_LEVEL``; in that case any of the keys in ``log_dict`` can be used. For instance:
376 |
377 | .. code-block:: bash
378 |
379 |     export VEGASFLOW_LOG_LEVEL=1
380 |
381 | will suppress all logger information other than ``WARNING`` and ``ERROR``.
382 |
383 |
384 |
385 | .. _environ-label:
386 |
387 | Environment
388 | ^^^^^^^^^^^
389 |
390 | ``VegasFlow`` is based on ``TensorFlow`` and as such all environment variables that
391 | have an effect on ``TensorFlow``'s behavior will also have an effect on ``VegasFlow``.
392 |
393 | Here we describe only some of what we found to be the most useful variables; a short usage example follows the list.
394 | For a complete description of the variables controlling the GPU-behavior of ``TensorFlow`` please refer to
395 | the `nvidia official documentation <https://docs.nvidia.com/deeplearning/frameworks/tensorflow-user-guide/index.html>`_.
396 |
397 | - ``TF_CPP_MIN_LOG_LEVEL``: controls the ``TensorFlow`` logging level. It is set to 1 by default so that only errors are printed.
398 | - ``VEGASFLOW_LOG_LEVEL``: controls the ``VegasFlow`` logging level. Set to 3 by default so that everything is printed.
399 | - ``VEGASFLOW_FLOAT``: controls the ``VegasFlow`` float precision. Default is 64 for 64-bits. Accepts: 64, 32.
400 | - ``VEGASFLOW_INT``: controls the ``VegasFlow`` integer precision. Default is 32 for 32-bits. Accepts: 64, 32.
401 |
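For instance, to run a calculation in single precision while suppressing all logging information below ``WARNING``, one could combine them as follows:

.. code-block:: bash

    export VEGASFLOW_FLOAT=32
    export VEGASFLOW_LOG_LEVEL=1
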
402 |
403 | Choosing integration device
404 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^
405 |
406 | The ``CUDA_VISIBLE_DEVICES`` environment variable will tell ``TensorFlow``
407 | (and thus ``VegasFlow``) on which device(s) it should run.
408 | If this variable is not set, it will default to using all available GPUs and avoid running on the CPU.
409 | In order to use the CPU you can hide the GPU by setting
410 | ``export CUDA_VISIBLE_DEVICES=""``.
411 |
412 | If you have a set-up with more than one GPU, you can select which one you
413 | want to use for the integration by setting the environment variable to the
414 | right device, e.g., ``export CUDA_VISIBLE_DEVICES=0``.
415 |
416 |
417 |
418 | .. _eager-label:
419 |
420 | Eager Vs Graph-mode
421 | ^^^^^^^^^^^^^^^^^^^
422 |
423 | When performing computationally expensive tasks, ``TensorFlow``'s graph mode is preferred.
424 | When compiling you will notice that the first iteration of the integration takes a
425 | bit longer; this is normal and is due to the creation of the graph.
426 | Subsequent iterations will be faster.
427 |
428 | Graph-mode, however, is not debugger friendly, as the code is read only once, when compiling the graph.
429 | You can instead enable ``TensorFlow``'s `eager execution <https://www.tensorflow.org/guide/eager>`_.
430 | With eager mode the code is run sequentially, as you would expect from normal python code;
431 | this will allow you, for instance, to throw in instances of ``pdb.set_trace()``.
432 | In order to use eager execution we provide the ``run_eager`` wrapper.
433 |
434 | .. code-block:: python
435 |
436 |     from vegasflow import run_eager
437 |
438 |     run_eager()       # Enable eager-mode
439 |     run_eager(False)  # Disable
440 |
441 |
442 | This is a wrapper around the following lines of code:
443 |
444 | .. code-block:: python
445 |
446 |     import tensorflow as tf
447 |     tf.config.run_functions_eagerly(True)
448 |
449 | or, if you are using versions of ``TensorFlow`` older than 2.3:
450 |
451 | .. code-block:: python
452 |
453 |     import tensorflow as tf
454 |     tf.config.experimental_run_functions_eagerly(True)
455 |
456 |
457 | Eager mode also enables the usage of the library as a `standard` python library,
458 | allowing you to integrate non-tensorflow integrands.
459 | These integrands, as they are not understood by ``TensorFlow``, are not run using
460 | GPU kernels, while the rest of ``VegasFlow`` will still be run on GPU if possible.
461 |
462 |
463 | Histograms
464 | ==========
465 |
466 | A commonly used feature in Monte Carlo calculations is the generation of histograms.
467 | In order to generate them while at the same time keeping all the features of ``VegasFlow``,
468 | such as GPU computing, it is necessary to ensure that the histogram generation is also wrapped with the ``@tf.function`` directive.
469 |
470 | Below we show one such example (how the histogram is actually generated and saved is up to the user).
471 | The first step is to create a ``Variable`` tensor which will be used to fill the histograms.
472 | This is a crucial step (and the only fixed step) as this tensor will be accumulated internally by ``VegasFlow``.
473 |
474 | .. code-block:: python
475 |
476 |     import tensorflow as tf
477 |     from vegasflow.utils import consume_array_into_indices
478 |     from vegasflow.configflow import fzero, fone, int_me, DTYPE
479 |
480 |     HISTO_BINS = int_me(2)
481 |     cumulator_tensor = tf.Variable(tf.zeros(HISTO_BINS, dtype=DTYPE))
482 |
483 |     @tf.function
484 |     def histogram_collector(results, variables):
485 |         """ This function will receive a tensor (results)
486 |         and the variables corresponding to those integrand results.
487 |         In the example integrand below, these correspond to
488 |         `final_result` and `histogram_values` respectively.
489 |         `cumulator_tensor` holds the current value of the histogram,
490 |         which will be updated with the new partial results """
491 |         # Fill a histogram with HISTO_BINS (2) bins, (0 to 0.5, 0.5 to 1)
492 |         # First generate the indices with TF
493 |         indices = tf.histogram_fixed_width_bins(
494 |             variables, [fzero, fone], nbins=HISTO_BINS
495 |         )
496 |         t_indices = tf.transpose(indices)
497 |         # Then consume the results with the utility we provide
498 |         partial_hist = consume_array_into_indices(results, t_indices, HISTO_BINS)
499 |         # Then update the accumulated histogram
500 |         new_histograms = partial_hist + cumulator_tensor
501 |         cumulator_tensor.assign(new_histograms)
502 |
503 |     @tf.function
504 |     def integrand_example(xarr, weight=fone):
505 |         # some complicated calculation that generates
506 |         # a final_result (one value per event) and some histogram values:
507 |         final_result = tf.constant(42, dtype=tf.float64) * tf.ones_like(xarr[:, 0])
508 |         histogram_values = xarr
509 |         histogram_collector(final_result * weight, histogram_values)
510 |         return final_result
511 |
512 | Finally we can call ``VegasFlow``, remembering to pass down the accumulator tensor, which will be filled in with the histograms.
513 | Note that here we are only filling in one histogram, and so the histogram tuple contains only one element, but any number of histograms may be filled.
514 |
515 |
516 | .. code-block:: python
517 |
518 |     histogram_tuple = (cumulator_tensor,)
519 |     results = mc_instance.run_integration(n_iter, histograms=histogram_tuple)
520 |
521 |
522 | We include an example of an integrand which generates histograms in ``examples/histogram_ex.py``.
523 |
524 | Generate conditions
525 | ===================
526 |
527 | A very common case when integrating using Monte Carlo methods is the addition of non-trivial cuts to the
528 | integration space.
529 | It is not obvious how to implement cuts in a consistent manner on a GPU or using ``TensorFlow``
530 | routines when we have to combine several conditions.
531 | We provide the ``generate_condition_function`` auxiliary function, which generates
532 | a ``TensorFlow``-compiled function for the necessary number of conditions.
533 |
534 | For instance, let's take the case of a parton collision simulation, in which
535 | we want to constrain the phase space of the two final state particles to the region
536 | in which the two particles have a transverse momentum above 15 GeV, or any of them have
537 | a rapidity below 4.
538 |
539 | We first generate the condition we want to apply using ``generate_condition_function``.
540 |
541 | .. code-block:: python
542 |
543 |     from vegasflow.utils import generate_condition_function
544 |
545 |     f_cond = generate_condition_function(3, condition=['and', 'or'])
546 |
547 |
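Before using it inside an integrand, it may help to see what the generated function returns on a few toy conditions. This sketch assumes, consistently with the physics requirement described above, that the conditions are combined from left to right, i.e. ``(c_1 and c_2) or c_3``:

.. code-block:: python

    import tensorflow as tf

    c_1 = tf.constant([True, True, False, False])
    c_2 = tf.constant([True, False, False, False])
    c_3 = tf.constant([False, True, False, True])

    # Assumed combination: (c_1 and c_2) or c_3
    # mask -> [True, True, False, True]
    # idx  -> the indices of the True entries, ready for tf.scatter_nd
    mask, idx = f_cond(c_1, c_2, c_3)
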
548 | Now we can use the ``f_cond`` function in our integrand.
549 | This ``f_cond`` function accepts three arguments and returns a mask of all of them
550 | and the ``True`` indices.
551 |
552 | .. code-block:: python
553 |
554 |     import tensorflow as tf
555 |     from vegasflow import vegas_wrapper
556 |
557 |     def two_particle(xarr, **kwargs):
558 |         # Complicated calculation of the phase space
559 |         pt_jet_1 = xarr[:, 0]*100 + 5
560 |         pt_jet_2 = xarr[:, 1]*100 + 5
561 |         rapidity = xarr[:, 2]*50
562 |         # Generate the conditions
563 |         c_1 = pt_jet_1 > 15
564 |         c_2 = pt_jet_2 > 15
565 |         c_3 = rapidity < 4
566 |         mask, idx = f_cond(c_1, c_2, c_3)
567 |         # Now we can mask away the unwanted results
568 |         good_vals = tf.boolean_mask(xarr[:, 3], mask, axis=0)
569 |         # Perform the very complicated calculation
570 |         result = tf.square(good_vals)
571 |         # Return a sparse tensor so that only the actual results have a value
572 |         ret = tf.scatter_nd(idx, result, shape=c_1.shape)
573 |         return ret
574 |
575 |     result = vegas_wrapper(two_particle, 4, 3, 100, compilable=False)
576 |
577 | Note that we use the mask to remove the values that are not part of the phase space.
578 | If the phase space to be integrated is much smaller than the integration region,
579 | removing unwanted values can have a huge impact on the calculation from the
580 | point of view of speed and memory, so we recommend removing them instead of just
581 | zeroing them.
582 |
583 | The resulting array, however, must have one value per event, so before returning
584 | the array to ``VegasFlow`` we use ``tf.scatter_nd`` to create a sparse tensor
585 | where all values are set to 0 except the indices defined in ``idx``, which
586 | take the values defined by ``result``.
--------------------------------------------------------------------------------