├── .gitignore ├── LICENSE.md ├── MANIFEST.in ├── README.md ├── qrnn.py ├── setup.py ├── src ├── Makefile ├── fo_pool_op.h ├── fo_pool_op_cpu.cpp ├── fo_pool_op_cpu.h ├── fo_pool_op_gpu.cpp ├── fo_pool_op_gpu.h ├── fo_pool_op_kernel.cu ├── fo_pool_op_kernel.h ├── third_party │ ├── nsync.h │ ├── nsync_atomic.h │ ├── nsync_counter.h │ ├── nsync_cpp.h │ ├── nsync_cv.h │ ├── nsync_debug.h │ ├── nsync_mu.h │ ├── nsync_mu_wait.h │ ├── nsync_note.h │ ├── nsync_once.h │ ├── nsync_time.h │ ├── nsync_time_internal.h │ └── nsync_waiter.h ├── thread_pool.cpp └── thread_pool.h └── test └── test_fo_pool.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.o 3 | .d 4 | .DS_Store 5 | ._.DS_Store 6 | __pycache__ 7 | *.pyc 8 | *.egg-info 9 | build/ 10 | dist/ 11 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Jonathan Raiman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include src *.cpp *.h *.cu 2 | include src/Makefile 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quasi-Recurrent Neural Network (QRNN) for Tensorflow 2 | 3 | This repository contains a Tensorflow implementation of [Salesforce Research](https://einstein.ai/)'s [Quasi-Recurrent Neural Networks](https://arxiv.org/abs/1611.01576) paper. It supports batch-major or time-major inputs in single or double precision. 4 | 5 | From the authors: 6 | > The QRNN provides similar accuracy to the LSTM but can be betwen 2 and 17 times faster than the highly optimized NVIDIA cuDNN LSTM implementation depending on the use case. 
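The core primitive behind the layer is `fo_pool`, the forget-gate pooling step of the QRNN, which the module exposes directly. The snippet below is a minimal sketch of calling it on batch-major `[Batch, Time, Channels]` tensors; the shapes, tensor names and random data are illustrative only, and the initial state is passed explicitly:

```
import numpy as np
import tensorflow as tf
import qrnn

# candidate activations and forget-gate pre-activations (illustrative shapes)
x = tf.placeholder(tf.float32, [None, None, 64])
forget = tf.placeholder(tf.float32, [None, None, 64])

# fo_pool expects the forget gate to already lie in (0, 1)
initial_state = tf.zeros((tf.shape(x)[0], tf.shape(x)[2]), dtype=x.dtype)
pooled = qrnn.fo_pool(tf.tanh(x), tf.nn.sigmoid(forget),
                      initial_state=initial_state)

with tf.Session() as sess:
    feed = {x: np.random.randn(8, 20, 64).astype(np.float32),
            forget: np.random.randn(8, 20, 64).astype(np.float32)}
    print(sess.run(pooled, feed).shape)  # (8, 20, 64)
```

The higher-level `qrnn()` layer wraps this op with a 1-D convolution and gating; see the module docstring in `qrnn.py` for an end-to-end example.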
7 | 8 | To install, simply run: 9 | 10 | `pip3 install qrnn` 11 | 12 | If you use this code or their results in your research, you should cite: 13 | 14 | ``` 15 | @article{bradbury2016quasi, 16 | title={{Quasi-Recurrent Neural Networks}}, 17 | author={Bradbury, James and Merity, Stephen and Xiong, Caiming and Socher, Richard}, 18 | journal={International Conference on Learning Representations (ICLR 2017)}, 19 | year={2017} 20 | } 21 | ``` 22 | 23 | The original PyTorch implementation of the QRNN can be found [here](https://github.com/salesforce/pytorch-qrnn). 24 | 25 | ### Requirements 26 | 27 | - Tensorflow 1.4 (`pip install tensorflow` or `pip install tensorflow-gpu`) 28 | - GCC 29 | - CUDA (optional, needed for GPU support) 30 | 31 | ### Testing 32 | 33 | ``` 34 | python3 test/test_fo_pool.py 35 | ``` 36 | 37 | ### TODOs: 38 | 39 | - create wheels for Fedora, Ubuntu, etc... 40 | -------------------------------------------------------------------------------- /qrnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Quasi-Recurrent Neural Network (QRNN) for Tensorflow 3 | ---------------------------------------------------- 4 | 5 | This repository contains a Tensorflow implementation of 6 | [Salesforce Research](https://einstein.ai/)'s 7 | [Quasi-Recurrent Neural Networks](https://arxiv.org/abs/1611.01576) 8 | paper. It supports batch-major or time-major inputs in 9 | single or double precision. 10 | 11 | From the authors: 12 | > The QRNN provides similar accuracy to the LSTM but can be betwen 13 | > 2 and 17 times faster than the highly optimized NVIDIA cuDNN 14 | > LSTM implementation depending on the use case. 15 | 16 | If you use this code or their results in your research, you should cite: 17 | 18 | @article{bradbury2016quasi, 19 | title={{Quasi-Recurrent Neural Networks}}, 20 | author={Bradbury, James and Merity, Stephen and Xiong, Caiming and Socher, Richard}, 21 | journal={International Conference on Learning Representations (ICLR 2017)}, 22 | year={2017} 23 | } 24 | 25 | Usage 26 | ----- 27 | 28 | Use QRNNs as you would use LSTMs or RNNs, to 29 | encode order-specific information: 30 | 31 | ``` 32 | import qrnn 33 | 34 | # input sequence in Batch, Time, Channels format: 35 | inputs = tf.placeholder(tf.float32, [None, None, 128]) 36 | encoded = qrnn.qrnn(inputs) 37 | 38 | with tf.Session() as sess: 39 | sess.run(tf.global_variables_initializer()) 40 | out = sess.run(encoded, {inputs: my_data}) 41 | ``` 42 | 43 | """ 44 | import tensorflow as tf 45 | from os.path import join, dirname, realpath 46 | 47 | SCRIPT_DIR = dirname(realpath(__file__)) 48 | 49 | def get_ext_filename(ext_name): 50 | from distutils.sysconfig import get_config_var 51 | ext_path = ext_name.split('.') 52 | ext_suffix = get_config_var('EXT_SUFFIX') 53 | return join(*ext_path) + ext_suffix 54 | 55 | 56 | qrnn_lib = tf.load_op_library(join(SCRIPT_DIR, get_ext_filename("qrnn_lib"))) 57 | 58 | time_major_fo_pool_unsliced = qrnn_lib.time_major_fo_pool 59 | time_major_bwd_fo_pool = qrnn_lib.time_major_bwd_fo_pool 60 | 61 | batch_major_fo_pool_unsliced = qrnn_lib.batch_major_fo_pool 62 | batch_major_bwd_fo_pool = qrnn_lib.batch_major_bwd_fo_pool 63 | 64 | @tf.RegisterGradient("TimeMajorFoPool") 65 | def _fo_pool_grad(op, grad): 66 | return time_major_bwd_fo_pool(h=op.outputs[0], x=op.inputs[0], 67 | forget=op.inputs[1], gh=grad) 68 | 69 | @tf.RegisterGradient("BatchMajorFoPool") 70 | def _fo_pool_grad(op, grad): 71 | return batch_major_bwd_fo_pool(h=op.outputs[0], 
x=op.inputs[0], 72 | forget=op.inputs[1], gh=grad) 73 | 74 | 75 | def fo_pool(x, forget, initial_state=None, time_major=False): 76 | """Applies a single layer Quasi-Recurrent Neural Network (QRNN) to an input sequence. 77 | Args: 78 | x: Tensor, input values in [Batch, Time, Channels] format, 79 | float32 or double 80 | or [Time, Batch, Channels] if time_major 81 | forget: Tensor, input values in [Batch, Time, Channels] format, 82 | float32 or double. Usually in the range 0-1. 83 | or [Time, Batch, Channels] if time_major 84 | initial_state: Tensor, initial hidden state values in [Batch, Channels] format, 85 | float32 or double. 86 | 87 | Returns: 88 | Tensor: fo_pooled output, [Batch, Time, Channels] format 89 | or [Time, Batch, Channels] if time_major 90 | """ 91 | if initial_state is None: 92 | initial_state = tf.zeros((tf.shape(x)[1] if time_major else tf.shape(x)[0], 93 | tf.shape(x)[2]), dtype=tf.dtype) 94 | if time_major: 95 | return time_major_fo_pool_unsliced(x, forget, initial_state)[1:] 96 | else: 97 | return batch_major_fo_pool_unsliced(x, forget, initial_state)[:, 1:] 98 | 99 | 100 | def qrnn(inputs, num_outputs, window=2, output_gate=True, 101 | activation_fn=tf.tanh, gate_activation_fn=tf.nn.sigmoid, 102 | padding="SAME", initial_state=None, time_major=False, scope=None, 103 | **kwargs): 104 | """Applies a single layer Quasi-Recurrent Neural Network (QRNN) to an input sequence. 105 | Args: 106 | inputs: Tensor, input values in [Batch, Time, Channels] format, 107 | float32 or double, or [Time, Batch, Channels] if time_major 108 | window: int, number of values each gating depends on (default=2). 109 | num_outputs: int, Number of output channels 110 | keep_prob: float, zoneout dropout probability 111 | is_training: bool, whether to apply dropout mask 112 | output_gate: bool, use a gating mechanism on the output 113 | activation_fn: function, default tanh 114 | gate_activation_fn: function, default sigmoid 115 | padding: str, SAME or VALID. 116 | initial_state: Tensor/None, optional, initializes the QRNN to that value. 117 | time_major: bool, whether inputs have time-dimension first or second. 118 | scope: str/None, what to prefix the name the variables under this layer. 
119 | 120 | Returns: 121 | Tensor : qrnn_output, [Batch, Time, Channels] or 122 | [Time, Batch, Channels] if time_major 123 | """ 124 | with tf.variable_scope(scope or "QRNNLayer"): 125 | conv1d_channels = 3 * num_outputs if output_gate else 2 * num_outputs 126 | if time_major: 127 | # go to batch_major for convolution if needed 128 | inputs_batch_major = tf.transpose(inputs, (1, 0, 2), name="InputsBatchMajor") 129 | else: 130 | inputs_batch_major = inputs 131 | gate_values = tf.layers.conv1d(inputs_batch_major, 132 | filters=conv1d_channels, 133 | kernel_size=window, 134 | strides=1, 135 | data_format="channels_last", 136 | name="QRNNConv1D", 137 | padding=padding, 138 | **kwargs) 139 | if time_major: 140 | # return to time_major if needed 141 | gate_values = tf.transpose(gate_values, (1, 0, 2), 142 | name="GateValuesTimeMajor") 143 | 144 | gate_values = tf.split(gate_values, 3 if output_gate else 2, axis=2) 145 | if output_gate: 146 | x, forget, output = gate_values 147 | else: 148 | x, forget = gate_values 149 | 150 | with tf.name_scope("GateActivations"): 151 | if activation_fn is not None: 152 | x = activation_fn(x) 153 | if gate_activation_fn is not None: 154 | forget = gate_activation_fn(forget) 155 | 156 | with tf.name_scope("FoPool"): 157 | c = fo_pool(x, forget, initial_state=initial_state, 158 | time_major=time_major) 159 | 160 | with tf.name_scope("OutputGate"): 161 | h = gate_activation_fn(c) if output_gate else c 162 | return h, c 163 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from os.path import join, dirname, realpath, relpath, splitext, abspath, exists, getmtime, relpath, lexists, islink 4 | from os import walk, sep, remove, listdir, stat, symlink, pathsep 5 | from setuptools import setup 6 | from distutils.extension import Extension 7 | from distutils.command.build_ext import build_ext 8 | from setuptools.command.develop import develop 9 | from setuptools.command.install import install 10 | import warnings 11 | import tempfile 12 | import subprocess 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | def find_in_path(name, path): 18 | "Find a file in a search path" 19 | #adapted fom http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/ 20 | solutions = [] 21 | for dir in path.split(pathsep): 22 | binpath = join(dir, name) 23 | if exists(binpath): 24 | solutions.append(abspath(binpath)) 25 | if len(solutions) >= 1: 26 | if any("usr" in sol for sol in solutions): 27 | solutions = [sol for sol in solutions if "usr" in sol] 28 | return solutions[0] 29 | return None 30 | 31 | 32 | def locate_cuda(): 33 | """Locate the CUDA environment on the system 34 | Returns a dict with keys 'home', 'nvcc', 'include', and 'lib64' 35 | and values giving the absolute path to each directory. 36 | Starts by looking for the CUDAHOME env variable. If not found, everything 37 | is based on finding 'nvcc' in the PATH. 
38 | """ 39 | 40 | # first check if the CUDAHOME env variable is in use 41 | if "CUDAHOME" in os.environ: 42 | home = os.environ["CUDAHOME"] 43 | nvcc = join(home, "bin", "nvcc") 44 | else: 45 | # otherwise, search the PATH for NVCC 46 | nvcc = find_in_path("nvcc", os.environ["PATH"]) 47 | print() 48 | if nvcc is None: 49 | return None 50 | home = dirname(dirname(nvcc)) 51 | print(home) 52 | cudaconfig = {"nvcc": nvcc, 53 | "include": [join(home, "include"), join(home, "include", "cuda")], 54 | "lib64": [join(home, "lib64"), join(home, "lib")]} 55 | for k, v in cudaconfig.items(): 56 | if isinstance(v, str): 57 | v = [v] 58 | all_missing = all(not exists(path) for path in v) 59 | if all_missing: 60 | raise EnvironmentError("The CUDA %s path could not be located in %r" % (k, v)) 61 | 62 | return cudaconfig 63 | 64 | 65 | def customize_compiler_for_nvcc(self): 66 | """inject deep into distutils to customize how the dispatch 67 | to gcc/nvcc works. 68 | If you subclass UnixCCompiler, it's not trivial to get your subclass 69 | injected in, and still have the right customizations (i.e. 70 | distutils.sysconfig.customize_compiler) run on it. So instead of going 71 | the OO route, I have this. Note, it's kindof like a wierd functional 72 | subclassing going on.""" 73 | 74 | # tell the compiler it can processes .cu 75 | self.src_extensions.append('.cu') 76 | 77 | # save references to the default compiler_so and _comple methods 78 | default_compiler_so = self.compiler_so 79 | super = self._compile 80 | 81 | # now redefine the _compile method. This gets executed for each 82 | # object but distutils doesn't have the ability to change compilers 83 | # based on source extension: we add it. 84 | def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts): 85 | if os.path.splitext(src)[1] == ".cu": 86 | # use the cuda for .cu files 87 | self.set_executable("compiler_so", CUDA["nvcc"]) 88 | # use only a subset of the extra_postargs, which are 1-1 translated 89 | # from the extra_compile_args in the Extension class 90 | postargs = extra_postargs["nvcc"] 91 | else: 92 | postargs = extra_postargs["gcc"] 93 | 94 | super(obj, src, ext, cc_args, postargs, pp_opts) 95 | # reset the default compiler_so, which we might have changed for cuda 96 | self.compiler_so = default_compiler_so 97 | 98 | # inject our redefined _compile method into the class 99 | self._compile = _compile 100 | 101 | 102 | def check_openmp_presence(): 103 | source = """ 104 | #include 105 | int main() { 106 | #ifdef _OPENMP 107 | return 0; 108 | #else 109 | breaks_on_purpose 110 | #endif 111 | } 112 | """ 113 | with tempfile.NamedTemporaryFile() as foutput: 114 | with tempfile.NamedTemporaryFile() as ftest: 115 | with open(ftest.name, "wt") as fout: 116 | fout.write(source) 117 | try: 118 | out = subprocess.check_output(["g++", ftest.name, "-o", foutput.name, "-fopenmp"]) 119 | return True 120 | except subprocess.CalledProcessError: 121 | return False 122 | 123 | 124 | 125 | # run the customize_compiler 126 | class custom_build_ext(build_ext): 127 | def build_extensions(self): 128 | customize_compiler_for_nvcc(self.compiler) 129 | build_ext.build_extensions(self) 130 | 131 | 132 | def find_files_by_suffix(path, suffix): 133 | """Recursively find files with specific suffix in a directory""" 134 | for relative_path, dirs, files in walk(path): 135 | for fname in files: 136 | if fname.endswith(suffix): 137 | yield join(path, relative_path, fname) 138 | 139 | CUDA = locate_cuda() 140 | SRC_DIR = join(dirname(realpath(__file__)), "src") 141 | 
TF_LIB = tf.sysconfig.get_lib() 142 | TF_INCLUDE = tf.sysconfig.get_include() 143 | TF_CUDA = tf.test.is_built_with_cuda() 144 | HAS_OPENMP = check_openmp_presence() 145 | 146 | if TF_CUDA and CUDA is None: 147 | warnings.warn("qrnn can run on gpu, but nvcc was not found in your path. " 148 | "Either add it to your path, or set the $CUDAHOME variable.") 149 | 150 | USE_CUDA = TF_CUDA and CUDA is not None 151 | 152 | 153 | cu_sources = list(find_files_by_suffix(SRC_DIR, ".cu")) 154 | cpp_sources = list(find_files_by_suffix(SRC_DIR, ".cpp")) 155 | 156 | 157 | cmdclass = {} 158 | include_dirs = [np.get_include(), TF_INCLUDE, SRC_DIR, join(SRC_DIR, "third_party")] 159 | TF_FLAGS = ["-D_MWAITXINTRIN_H_INCLUDED", "-D_FORCE_INLINES", "-D_GLIBCXX_USE_CXX11_ABI=0"] 160 | gcc_extra_compile_args = ["-g", "-std=c++11", "-fPIC", "-O3", "-march=native", "-mtune=native"] + TF_FLAGS 161 | 162 | nvcc_extra_compile_args = [] 163 | extra_link_args = ["-fPIC"] 164 | if HAS_OPENMP: 165 | gcc_extra_compile_args.append("-fopenmp") 166 | extra_link_args.append("-fopenmp") 167 | 168 | if sys.platform == 'darwin': 169 | gcc_extra_compile_args.append('-stdlib=libc++') 170 | nvcc_extra_compile_args.append('-stdlib=libc++') 171 | extra_link_args.append('-stdlib=libc++') 172 | else: 173 | extra_link_args.append("-shared") 174 | 175 | 176 | if USE_CUDA: 177 | cmdclass["build_ext"] = custom_build_ext 178 | gcc_extra_compile_args.extend(["-D", "GOOGLE_CUDA"]) 179 | include_dirs.extend(CUDA["include"]) 180 | nvcc_extra_compile_args.extend(TF_FLAGS + ["-std=c++11", "-D", "GOOGLE_CUDA=1", 181 | "-I", TF_INCLUDE, 182 | "-x", "cu", "--compiler-options", "'-fPIC'", 183 | "--gpu-architecture=sm_30", "-lineinfo", 184 | "-Xcompiler", "-std=c++98"] + ["-I" + path for path in CUDA["include"]]) 185 | extra_compile_args = {"gcc": gcc_extra_compile_args, 186 | "nvcc": nvcc_extra_compile_args} 187 | runtime_library_dirs = CUDA['lib64'] 188 | else: 189 | cu_sources = [] 190 | extra_compile_args = gcc_extra_compile_args 191 | runtime_library_dirs = [] 192 | 193 | 194 | ext = Extension("qrnn_lib", 195 | sources=cu_sources + cpp_sources, 196 | library_dirs=[TF_LIB], 197 | libraries=["tensorflow_framework"], 198 | language="c++", 199 | runtime_library_dirs=runtime_library_dirs, 200 | # this syntax is specific to this build system 201 | # we're only going to use certain compiler args with nvcc and not with gcc 202 | # the implementation of this trick is in customize_compiler() below 203 | extra_compile_args=extra_compile_args, 204 | extra_link_args=extra_link_args, 205 | include_dirs=include_dirs) 206 | 207 | setup(name='qrnn', 208 | # random metadata. 
there's more you can supploy 209 | author="Jonathan Raiman", 210 | author_email="jonathanraiman@gmail.com", 211 | version="0.2.2", 212 | install_requires=["numpy", "tensorflow>=1.4"], 213 | ext_modules = [ext], 214 | py_modules=["qrnn"], 215 | # inject our custom trigger 216 | cmdclass=cmdclass, 217 | zip=False) 218 | -------------------------------------------------------------------------------- /src/Makefile: -------------------------------------------------------------------------------- 1 | # Tensorflow includes and defines 2 | TF_INC=$(shell python3 -c 'from __future__ import print_function; import tensorflow as tf; print(tf.sysconfig.get_include())') 3 | TF_CUDA=$(shell python3 -c 'from __future__ import print_function; import tensorflow as tf; print(int(tf.test.is_built_with_cuda()))') 4 | 5 | TF_FLAGS=-D_MWAITXINTRIN_H_INCLUDED -D_FORCE_INLINES -D_GLIBCXX_USE_CXX11_ABI=0 6 | TF_LIB=$(shell python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') 7 | 8 | # Dependencies 9 | DEPDIR:=.d 10 | $(shell mkdir -p $(DEPDIR) >/dev/null) 11 | DEPFLAGS=-MT $@ -MMD -MP -MF $(DEPDIR)/$*.Td 12 | 13 | # Define our sources, compiling CUDA code if it's enabled 14 | ifeq ($(TF_CUDA), 1) 15 | SOURCES=$(wildcard *.cpp *.cu) 16 | GOOGLE_CUDA_FLAG=-D GOOGLE_CUDA=1 17 | GOOGLE_CUDA_INCLUDES= -I/usr/include/cuda 18 | else 19 | SOURCES=$(wildcard *.cpp) 20 | GOOGLE_CUDA_FLAG= 21 | GOOGLE_CUDA_INCLUDES= 22 | endif 23 | 24 | # Define objects and shared_library 25 | OBJECTS=$(addsuffix .o, $(basename $(SOURCES))) 26 | LIBRARY=qrnn_lib.so 27 | 28 | # Compiler flags 29 | INCLUDES= -I $(TF_INC) $(GOOGLE_CUDA_INCLUDES) 30 | 31 | CPPFLAGS=-g -std=c++11 $(TF_FLAGS) $(INCLUDES) -fPIC -fopenmp -O2 -march=native -mtune=native $(GOOGLE_CUDA_FLAG) 32 | NVCCFLAGS=-std=c++11 -D GOOGLE_CUDA=$(TF_CUDA) $(TF_FLAGS) $(INCLUDES) \ 33 | -x cu --compiler-options "-fPIC" --gpu-architecture=sm_30 -lineinfo \ 34 | -Xcompiler -std=c++98 35 | 36 | LDFLAGS = -fPIC -fopenmp -L$(TF_LIB) -ltensorflow_framework 37 | 38 | # Compiler directives 39 | COMPILE.cpp = g++ $(DEPFLAGS) $(CPPFLAGS) -c 40 | COMPILE.nvcc = nvcc --compiler-options " $(DEPFLAGS)" $(NVCCFLAGS) -c 41 | 42 | all : $(LIBRARY) 43 | 44 | %.o : %.cpp 45 | $(COMPILE.cpp) $< 46 | 47 | %.o : %.cu 48 | $(COMPILE.nvcc) $< 49 | 50 | clean : 51 | rm -f $(OBJECTS) $(LIBRARY) 52 | 53 | $(LIBRARY) : $(OBJECTS) 54 | g++ -shared $(OBJECTS) -o $(LIBRARY) $(LDFLAGS) 55 | 56 | $(DEPDIR)/%.d: ; 57 | .PRECIOUS: $(DEPDIR)/%.d 58 | 59 | -include $(patsubst %,$(DEPDIR)/%.d,$(basename $(SRCS))) 60 | -------------------------------------------------------------------------------- /src/fo_pool_op.h: -------------------------------------------------------------------------------- 1 | #ifndef QRNN_FO_POOL_OP_H 2 | #define QRNN_FO_POOL_OP_H 3 | 4 | // tf_qrnn namespace start and stop defines 5 | #define TF_QRNN_NAMESPACE_BEGIN namespace tf_qrnn { 6 | #define TF_QRNN_NAMESPACE_STOP } 7 | 8 | // namespace start and stop defines 9 | #define TF_QRNN_FO_POOL_NAMESPACE_BEGIN namespace { 10 | #define TF_QRNN_FO_POOL_NAMESPACE_STOP } 11 | 12 | TF_QRNN_NAMESPACE_BEGIN 13 | TF_QRNN_FO_POOL_NAMESPACE_BEGIN 14 | 15 | // General definition of the FoPool op, which will be specialised in: 16 | // - fo_pool_op_cpu.h for CPUs 17 | // - fo_pool_op_gpu.cuh for CUDA devices 18 | // Concrete template instantions of this class are provided in: 19 | // - fo_pool_op_cpu.cpp for CPUs 20 | // - fo_pool_op_gpu.cu for CUDA devices 21 | template 22 | class FoPool {}; 23 | 24 | template 25 | class BwdFoPool {}; 26 | 27 
| TF_QRNN_FO_POOL_NAMESPACE_STOP 28 | TF_QRNN_NAMESPACE_STOP 29 | 30 | #endif // #ifndef QRNN_FO_POOL_OP_H 31 | -------------------------------------------------------------------------------- /src/fo_pool_op_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "fo_pool_op_cpu.h" 2 | 3 | #include "tensorflow/core/framework/shape_inference.h" 4 | 5 | TF_QRNN_NAMESPACE_BEGIN 6 | TF_QRNN_FO_POOL_NAMESPACE_BEGIN 7 | 8 | using tensorflow::shape_inference::InferenceContext; 9 | using tensorflow::shape_inference::ShapeHandle; 10 | using tensorflow::shape_inference::DimensionHandle; 11 | using tensorflow::Status; 12 | 13 | auto time_major_fo_pool_shape_function = [](InferenceContext* c) { 14 | // Dummies for tests 15 | ShapeHandle input; 16 | DimensionHandle d; 17 | 18 | ShapeHandle in_x = c->input(0); 19 | // Assert 'x' number of dimensions 20 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_x, 3, &input), 21 | "x must have shape [None, None, None] but is " + 22 | c->DebugString(in_x)); 23 | ShapeHandle in_forget = c->input(1); 24 | // Assert 'forget' number of dimensions 25 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_forget, 3, &input), 26 | "forget must have shape [None, None, None] but is " + 27 | c->DebugString(in_forget)); 28 | 29 | ShapeHandle in_hinit = c->input(2); 30 | // Assert 'hinit' number of dimensions 31 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_hinit, 2, &input), 32 | "hinit must have shape [None, None] but is " + 33 | c->DebugString(in_hinit)); 34 | 35 | std::vector dims(3); 36 | for (int i = 1; i < 3; i++) { 37 | TF_RETURN_IF_ERROR( 38 | c->Merge(c->Dim(in_x, i), c->Dim(in_hinit, i - 1), &dims[i])); 39 | } 40 | 41 | for (int i = 0; i < 3; i++) { 42 | TF_RETURN_IF_ERROR( 43 | c->Merge(c->Dim(in_x, i), c->Dim(in_forget, i), &dims[i])); 44 | } 45 | 46 | TF_RETURN_IF_ERROR(c->Add(c->Dim(in_x, 0), 47 | static_cast(1), 48 | &dims[0])); 49 | 50 | c->set_output(0, c->MakeShape(dims)); 51 | return Status::OK(); 52 | }; 53 | 54 | auto batch_major_fo_pool_shape_function = [](InferenceContext* c) { 55 | // Dummies for tests 56 | ShapeHandle input; 57 | DimensionHandle d; 58 | 59 | ShapeHandle in_x = c->input(0); 60 | // Assert 'x' number of dimensions 61 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_x, 3, &input), 62 | "x must have shape [None, None, None] but is " + 63 | c->DebugString(in_x)); 64 | ShapeHandle in_forget = c->input(1); 65 | // Assert 'forget' number of dimensions 66 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_forget, 3, &input), 67 | "forget must have shape [None, None, None] but is " + 68 | c->DebugString(in_forget)); 69 | 70 | ShapeHandle in_hinit = c->input(2); 71 | // Assert 'hinit' number of dimensions 72 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_hinit, 2, &input), 73 | "hinit must have shape [None, None] but is " + 74 | c->DebugString(in_hinit)); 75 | 76 | std::vector dims(3); 77 | TF_RETURN_IF_ERROR( 78 | c->Merge(c->Dim(in_x, 0), c->Dim(in_hinit, 0), &dims[0])); 79 | TF_RETURN_IF_ERROR( 80 | c->Merge(c->Dim(in_x, 2), c->Dim(in_hinit, 1), &dims[2])); 81 | 82 | for (int i = 0; i < 3; i++) { 83 | TF_RETURN_IF_ERROR( 84 | c->Merge(c->Dim(in_x, i), c->Dim(in_forget, i), &dims[i])); 85 | } 86 | 87 | TF_RETURN_IF_ERROR(c->Add(c->Dim(in_x, 1), 88 | static_cast(1), 89 | &dims[1])); 90 | 91 | c->set_output(0, c->MakeShape(dims)); 92 | return Status::OK(); 93 | }; 94 | 95 | // Register the FoPool operator. 
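// The ops below use the shape functions defined above, which return the shape
// of 'x' with its time dimension (dim 0 for the time-major op, dim 1 for the
// batch-major op) grown by one: the extra leading timestep of the output holds
// the initial state h_{-1}, and the Python wrapper in qrnn.py slices it off.
// Each op is registered for both float and double via the 'FT' type attribute.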
96 | REGISTER_OP("TimeMajorFoPool") 97 | .Input("x: FT") 98 | .Input("forget: FT") 99 | .Input("initial_state: FT") 100 | .Output("output: FT") 101 | .Attr("FT: {float, double} = DT_FLOAT") 102 | .Doc(R"doc(QRNN fo_pool operation.)doc") 103 | .SetShapeFn(time_major_fo_pool_shape_function); 104 | 105 | REGISTER_OP("BatchMajorFoPool") 106 | .Input("x: FT") 107 | .Input("forget: FT") 108 | .Input("initial_state: FT") 109 | .Output("output: FT") 110 | .Attr("FT: {float, double} = DT_FLOAT") 111 | .Doc(R"doc(QRNN fo_pool operation.)doc") 112 | .SetShapeFn(batch_major_fo_pool_shape_function); 113 | 114 | REGISTER_KERNEL_BUILDER( 115 | Name("TimeMajorFoPool") 116 | .TypeConstraint("FT") 117 | .Device(tensorflow::DEVICE_CPU), 118 | FoPool); 119 | 120 | REGISTER_KERNEL_BUILDER( 121 | Name("TimeMajorFoPool") 122 | .TypeConstraint("FT") 123 | .Device(tensorflow::DEVICE_CPU), 124 | FoPool); 125 | 126 | REGISTER_KERNEL_BUILDER( 127 | Name("BatchMajorFoPool") 128 | .TypeConstraint("FT") 129 | .Device(tensorflow::DEVICE_CPU), 130 | FoPool); 131 | 132 | REGISTER_KERNEL_BUILDER( 133 | Name("BatchMajorFoPool") 134 | .TypeConstraint("FT") 135 | .Device(tensorflow::DEVICE_CPU), 136 | FoPool); 137 | 138 | auto time_major_bwd_fo_pool_shape_function = [](InferenceContext* c) { 139 | // Dummies for tests 140 | ShapeHandle input; 141 | DimensionHandle d; 142 | 143 | ShapeHandle in_h = c->input(0); 144 | ShapeHandle in_x = c->input(1); 145 | ShapeHandle in_forget = c->input(2); 146 | ShapeHandle in_gh = c->input(3); 147 | 148 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_h, 3, &input), 149 | "h must have shape [None, None, None] but is " + 150 | c->DebugString(in_h)); 151 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_x, 3, &input), 152 | "x must have shape [None, None, None] but is " + 153 | c->DebugString(in_h)); 154 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_forget, 3, &input), 155 | "forget must have shape [None, None, None] but is " + 156 | c->DebugString(in_forget)); 157 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_gh, 3, &input), 158 | "gh must have shape [None, None, None] but is " + 159 | c->DebugString(in_gh)); 160 | 161 | std::vector dims({ 162 | c->Dim(in_gh, 1), 163 | c->Dim(in_gh, 2) 164 | }); 165 | 166 | c->set_output(0, in_x); 167 | c->set_output(1, in_forget); 168 | c->set_output(2, c->MakeShape(dims)); 169 | 170 | return Status::OK(); 171 | }; 172 | 173 | auto batch_major_bwd_fo_pool_shape_function = [](InferenceContext* c) { 174 | // Dummies for tests 175 | ShapeHandle input; 176 | DimensionHandle d; 177 | 178 | ShapeHandle in_h = c->input(0); 179 | ShapeHandle in_x = c->input(1); 180 | ShapeHandle in_forget = c->input(2); 181 | ShapeHandle in_gh = c->input(3); 182 | 183 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_h, 3, &input), 184 | "h must have shape [None, None, None] but is " + 185 | c->DebugString(in_h)); 186 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_x, 3, &input), 187 | "x must have shape [None, None, None] but is " + 188 | c->DebugString(in_h)); 189 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_forget, 3, &input), 190 | "forget must have shape [None, None, None] but is " + 191 | c->DebugString(in_forget)); 192 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(in_gh, 3, &input), 193 | "gh must have shape [None, None, None] but is " + 194 | c->DebugString(in_gh)); 195 | 196 | std::vector dims({ 197 | c->Dim(in_gh, 0), 198 | c->Dim(in_gh, 2) 199 | }); 200 | 201 | c->set_output(0, in_x); 202 | c->set_output(1, in_forget); 203 | c->set_output(2, c->MakeShape(dims)); 
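// As in the time-major case, the gradients w.r.t. 'x' and 'forget' share the
// shapes of those inputs, while the gradient w.r.t. the initial state is a
// [batch, channels] shape inferred from 'gh'.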
204 | 205 | return Status::OK(); 206 | }; 207 | 208 | REGISTER_OP("TimeMajorBwdFoPool") 209 | .Input("h: FT") 210 | .Input("x: FT") 211 | .Input("forget: FT") 212 | .Input("gh: FT") 213 | .Output("gx: FT") 214 | .Output("gf: FT") 215 | .Output("ginitial_state: FT") 216 | .Attr("FT: {float, double} = DT_FLOAT") 217 | .Doc(R"doc(QRNN fo_pool gradient operation.)doc") 218 | .SetShapeFn(time_major_bwd_fo_pool_shape_function); 219 | 220 | REGISTER_OP("BatchMajorBwdFoPool") 221 | .Input("h: FT") 222 | .Input("x: FT") 223 | .Input("forget: FT") 224 | .Input("gh: FT") 225 | .Output("gx: FT") 226 | .Output("gf: FT") 227 | .Output("ginitial_state: FT") 228 | .Attr("FT: {float, double} = DT_FLOAT") 229 | .Doc(R"doc(QRNN fo_pool gradient operation.)doc") 230 | .SetShapeFn(batch_major_bwd_fo_pool_shape_function); 231 | 232 | REGISTER_KERNEL_BUILDER( 233 | Name("TimeMajorBwdFoPool") 234 | .TypeConstraint("FT") 235 | .Device(tensorflow::DEVICE_CPU), 236 | BwdFoPool); 237 | 238 | REGISTER_KERNEL_BUILDER( 239 | Name("TimeMajorBwdFoPool") 240 | .TypeConstraint("FT") 241 | .Device(tensorflow::DEVICE_CPU), 242 | BwdFoPool); 243 | 244 | REGISTER_KERNEL_BUILDER( 245 | Name("BatchMajorBwdFoPool") 246 | .TypeConstraint("FT") 247 | .Device(tensorflow::DEVICE_CPU), 248 | BwdFoPool); 249 | 250 | REGISTER_KERNEL_BUILDER( 251 | Name("BatchMajorBwdFoPool") 252 | .TypeConstraint("FT") 253 | .Device(tensorflow::DEVICE_CPU), 254 | BwdFoPool); 255 | 256 | TF_QRNN_FO_POOL_NAMESPACE_STOP 257 | TF_QRNN_NAMESPACE_STOP 258 | -------------------------------------------------------------------------------- /src/fo_pool_op_cpu.h: -------------------------------------------------------------------------------- 1 | #ifndef QRNN_FO_POOL_CPU_H 2 | #define QRNN_FO_POOL_CPU_H 3 | 4 | #include "fo_pool_op.h" 5 | #include "thread_pool.h" 6 | 7 | // Required in order for Eigen::ThreadPoolDevice to be an actual type 8 | #define EIGEN_USE_THREADS 9 | 10 | #include "tensorflow/core/framework/op.h" 11 | #include "tensorflow/core/framework/op_kernel.h" 12 | #include "tensorflow/core/util/work_sharder.h" 13 | #include "tensorflow/core/framework/types.h" 14 | 15 | TF_QRNN_NAMESPACE_BEGIN 16 | TF_QRNN_FO_POOL_NAMESPACE_BEGIN 17 | 18 | typedef tensorflow::int64 int64; 19 | 20 | // For simpler partial specialisation 21 | typedef Eigen::ThreadPoolDevice CPUDevice; 22 | 23 | template 24 | void time_major_fo_pool(tensorflow::OpKernelContext* context, 25 | FT *dst, const FT *x, const FT *f, const FT *initial_state, int SEQ, int batch_size, int HIDDEN) { 26 | /* 27 | Note: destination is assumed to be one timestep longer than f or x where dst[0] = h_{-1} 28 | This means dst array has a separate index than that of f or x 29 | */ 30 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 31 | const tensorflow::int64 cost = SEQ * HIDDEN * 1000; 32 | Shard(worker_threads.num_threads, worker_threads.num_threads, batch_size, cost, 33 | [&batch_size, x, f, initial_state, dst, &HIDDEN, &SEQ](const int start, const int limit) { 34 | for (int batch_id = start; batch_id < limit; ++batch_id) { 35 | for (int hid = 0; hid < HIDDEN; hid++) { 36 | dst[batch_id * HIDDEN + hid] = initial_state[batch_id * HIDDEN + hid]; 37 | for (int ts = 0 + 1; ts < SEQ + 1; ts++) { 38 | // Good sanity check for debugging - only perform additions to a zeroed chunk of memory 39 | // Addition seems atomic or near atomic - you should get incorrect answers if doubling up via threads 40 | // Note: the index i needs to be offset by one as f[0] (f_t) is used for dst[1] 
(h_t) etc 41 | // To move timesteps, we step HIDDEN * batch_size 42 | // To move batches, we move HIDDEN 43 | // To move neurons, we move +- 1 44 | // Note: dst[dst_i] = ts * 100 + batch_id * 10 + hid; is useful for debugging 45 | int i = (ts - 1) * HIDDEN * batch_size + batch_id * HIDDEN + hid; 46 | int dst_i = (ts - 0) * HIDDEN * batch_size + batch_id * HIDDEN + hid; 47 | int dst_iminus1 = (ts - 1) * HIDDEN * batch_size + batch_id * HIDDEN + hid; 48 | dst[dst_i] = f[i] * x[i]; 49 | dst[dst_i] += (1 - f[i]) * dst[dst_iminus1]; 50 | } 51 | } 52 | } 53 | }); 54 | } 55 | 56 | template 57 | void batch_major_fo_pool(tensorflow::OpKernelContext* context, 58 | FT *dst, const FT *x, const FT *f, const FT *initial_state, int SEQ, int batch_size, int HIDDEN) { 59 | /* 60 | Note: destination is assumed to be one timestep longer than f or x where dst[0] = h_{-1} 61 | This means dst array has a separate index than that of f or x 62 | */ 63 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 64 | const tensorflow::int64 cost = SEQ * HIDDEN * 1000; 65 | Shard(worker_threads.num_threads, worker_threads.num_threads, batch_size, cost, 66 | [&batch_size, x, f, initial_state, dst, &HIDDEN, &SEQ](const int start, const int limit) { 67 | for (int batch_id = start; batch_id < limit; ++batch_id) { 68 | for (int hid = 0; hid < HIDDEN; hid++) { 69 | dst[batch_id * HIDDEN * (SEQ + 1) + hid] = initial_state[batch_id * HIDDEN + hid]; 70 | for (int ts = 0 + 1; ts < SEQ + 1; ts++) { 71 | // Good sanity check for debugging - only perform additions to a zeroed chunk of memory 72 | // Addition seems atomic or near atomic - you should get incorrect answers if doubling up via threads 73 | // Note: the index i needs to be offset by one as f[0] (f_t) is used for dst[1] (h_t) etc 74 | // To move timesteps, we step HIDDEN * batch_size 75 | // To move batches, we move HIDDEN 76 | // To move neurons, we move +- 1 77 | // Note: dst[dst_i] = ts * 100 + batch_id * 10 + hid; is useful for debugging 78 | int i = (ts - 1) * HIDDEN + batch_id * HIDDEN * SEQ + hid; 79 | int dst_i = (ts - 0) * HIDDEN + batch_id * HIDDEN * (SEQ + 1) + hid; 80 | int dst_iminus1 = (ts - 1) * HIDDEN + batch_id * HIDDEN * (SEQ + 1) + hid; 81 | dst[dst_i] = f[i] * x[i]; 82 | dst[dst_i] += (1 - f[i]) * dst[dst_iminus1]; 83 | } 84 | } 85 | } 86 | }); 87 | } 88 | 89 | template 90 | void time_major_bwd_fo_pool(tensorflow::OpKernelContext* context, 91 | const FT *h, const FT *x, const FT *f, const FT *gh, FT *gx, FT *gf, FT *ginitial_state, 92 | int SEQ, int batch_size, int HIDDEN) { 93 | /* 94 | Note: h is assumed to be one timestep longer than f, x, gf, gx, or gh where dst[0] = h_{-1} 95 | This means dst array has a separate index than that of f or x 96 | */ 97 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 98 | const tensorflow::int64 cost = SEQ * HIDDEN * 1000; 99 | Shard(worker_threads.num_threads, worker_threads.num_threads, 100 | batch_size, cost, [&batch_size, h, f, x, gh, gf, gx, ginitial_state, &HIDDEN, &SEQ](const int start, const int limit) { 101 | for (int batch_id = start; batch_id < limit; ++batch_id) { 102 | for (int hid = 0; hid < HIDDEN; hid++) { 103 | double running_f = 0; 104 | for (int ts = SEQ - 1 + 1; ts >= 0 + 1; ts--) { 105 | int i = (ts - 1) * HIDDEN * batch_size + batch_id * HIDDEN + hid; 106 | int dst_iminus1 = (ts - 1) * HIDDEN * batch_size + batch_id * HIDDEN + hid; 107 | // 108 | running_f += gh[dst_iminus1]; 109 | // Gradient of X 110 | gx[i] = f[i] * running_f; 111 | // 
Gradient of F 112 | gf[i] = (x[i] - h[dst_iminus1]) * running_f; 113 | // The line below is likely more numerically stable than (1 - f[i]) * running_f; 114 | running_f = running_f - f[i] * running_f; 115 | } 116 | ginitial_state[batch_id * HIDDEN + hid] = running_f + gh[batch_id * HIDDEN + hid]; 117 | } 118 | } 119 | }); 120 | } 121 | 122 | template 123 | void batch_major_bwd_fo_pool(tensorflow::OpKernelContext* context, 124 | const FT *h, const FT *x, const FT *f, const FT *gh, FT *gx, FT *gf, FT *ginitial_state, 125 | int SEQ, int batch_size, int HIDDEN) { 126 | /* 127 | Note: h is assumed to be one timestep longer than f, x, gf, gx, or gh where dst[0] = h_{-1} 128 | This means dst array has a separate index than that of f or x 129 | */ 130 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 131 | const tensorflow::int64 cost = SEQ * HIDDEN * 1000; 132 | Shard(worker_threads.num_threads, worker_threads.num_threads, 133 | batch_size, cost, [&batch_size, h, f, x, gh, gf, gx, ginitial_state, &HIDDEN, &SEQ](const int start, const int limit) { 134 | for (int batch_id = start; batch_id < limit; ++batch_id) { 135 | for (int hid = 0; hid < HIDDEN; hid++) { 136 | double running_f = 0; 137 | for (int ts = SEQ - 1 + 1; ts >= 0 + 1; ts--) { 138 | int i = (ts - 1) * HIDDEN + batch_id * HIDDEN * SEQ + hid; 139 | int dst_iminus1 = (ts - 1) * HIDDEN + batch_id * HIDDEN * (SEQ + 1) + hid; 140 | // 141 | running_f += gh[dst_iminus1]; 142 | // Gradient of X 143 | gx[i] = f[i] * running_f; 144 | // Gradient of F 145 | gf[i] = (x[i] - h[dst_iminus1]) * running_f; 146 | // The line below is likely more numerically stable than (1 - f[i]) * running_f; 147 | running_f = running_f - f[i] * running_f; 148 | } 149 | ginitial_state[batch_id * HIDDEN + hid] = running_f + gh[batch_id * HIDDEN * (SEQ + 1) + hid]; 150 | } 151 | } 152 | }); 153 | } 154 | 155 | // Specialise the FoPool op for CPUs 156 | template 157 | class FoPool : public tensorflow::OpKernel { 158 | public: 159 | explicit FoPool(tensorflow::OpKernelConstruction * context) : 160 | tensorflow::OpKernel(context) {} 161 | 162 | void Compute(tensorflow::OpKernelContext * context) override { 163 | namespace tf = tensorflow; 164 | 165 | // Create reference to input Tensorflow tensors 166 | const auto & in_x = context->input(0); 167 | const auto & in_forget = context->input(1); 168 | const auto & in_initial_state = context->input(2); 169 | 170 | 171 | // Extract Eigen tensors 172 | auto x = in_x.flat().data(); 173 | auto forget = in_forget.flat().data(); 174 | auto initial_state = in_initial_state.flat().data(); 175 | 176 | // Allocate output tensors 177 | // Allocate space for output tensor 'output' 178 | tf::Tensor * output_ptr = nullptr; 179 | auto in_x_shape = in_x.shape(); 180 | tf::TensorShape output_shape = in_x_shape; 181 | if (time_major) { 182 | output_shape.set_dim(0, output_shape.dim_size(0) + 1); 183 | } else { 184 | output_shape.set_dim(1, output_shape.dim_size(1) + 1); 185 | } 186 | OP_REQUIRES_OK(context, context->allocate_output( 187 | 0, output_shape, &output_ptr)); 188 | auto out = output_ptr->flat().data(); 189 | if (time_major) { 190 | time_major_fo_pool(context, 191 | out, 192 | x, 193 | forget, 194 | initial_state, 195 | in_x_shape.dim_size(0), 196 | output_shape.dim_size(1), 197 | output_shape.dim_size(2)); 198 | } else { 199 | batch_major_fo_pool(context, 200 | out, 201 | x, 202 | forget, 203 | initial_state, 204 | in_x_shape.dim_size(1), 205 | output_shape.dim_size(0), 206 | output_shape.dim_size(2)); 207 | } 
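// In both branches the helpers receive (SEQ, batch_size, HIDDEN): SEQ comes
// from the un-padded input 'x', while the batch and hidden sizes are read from
// the output shape, whose time dimension has already been grown by one.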
208 | } 209 | }; 210 | 211 | template 212 | class BwdFoPool : public tensorflow::OpKernel { 213 | public: 214 | explicit BwdFoPool(tensorflow::OpKernelConstruction * context) : 215 | tensorflow::OpKernel(context) {} 216 | 217 | void Compute(tensorflow::OpKernelContext * context) override { 218 | namespace tf = tensorflow; 219 | 220 | const auto& in_h = context->input(0); 221 | const auto& in_x = context->input(1); 222 | const auto& in_forget = context->input(2); 223 | const auto& in_gh = context->input(3); 224 | 225 | // Extract Eigen tensors 226 | auto h = in_h.flat().data(); 227 | auto x = in_x.flat().data(); 228 | auto forget = in_forget.flat().data(); 229 | auto gh = in_gh.flat().data(); 230 | 231 | // Allocate output tensors 232 | // Allocate space for output tensor 'output' 233 | tf::Tensor * out_gx = nullptr; 234 | tf::Tensor * out_gf = nullptr; 235 | tf::Tensor * out_ginitial_state = nullptr; 236 | 237 | auto in_x_shape = in_x.shape(); 238 | tf::TensorShape grad_shape = in_x_shape; 239 | int batch_size = time_major ? in_x_shape.dim_size(1) : in_x_shape.dim_size(0); 240 | tf::TensorShape ginitial_state_shape({batch_size, 241 | in_x_shape.dim_size(2)}); 242 | 243 | OP_REQUIRES_OK(context, context->allocate_output( 244 | 0, grad_shape, &out_gx)); 245 | OP_REQUIRES_OK(context, context->allocate_output( 246 | 1, grad_shape, &out_gf)); 247 | OP_REQUIRES_OK(context, context->allocate_output( 248 | 2, ginitial_state_shape, &out_ginitial_state)); 249 | auto gx = out_gx->flat().data(); 250 | auto gf = out_gf->flat().data(); 251 | auto ginitial_state = out_ginitial_state->flat().data(); 252 | 253 | if (time_major) { 254 | time_major_bwd_fo_pool(context, 255 | h, 256 | x, 257 | forget, 258 | gh, 259 | gx, 260 | gf, 261 | ginitial_state, 262 | grad_shape.dim_size(0), 263 | grad_shape.dim_size(1), 264 | grad_shape.dim_size(2)); 265 | } else { 266 | batch_major_bwd_fo_pool(context, 267 | h, 268 | x, 269 | forget, 270 | gh, 271 | gx, 272 | gf, 273 | ginitial_state, 274 | grad_shape.dim_size(1), 275 | grad_shape.dim_size(0), 276 | grad_shape.dim_size(2)); 277 | } 278 | } 279 | }; 280 | 281 | TF_QRNN_FO_POOL_NAMESPACE_STOP 282 | TF_QRNN_NAMESPACE_STOP 283 | 284 | #endif // #ifndef QRNN_FO_POOL_OP_CPU_H 285 | -------------------------------------------------------------------------------- /src/fo_pool_op_gpu.cpp: -------------------------------------------------------------------------------- 1 | #if GOOGLE_CUDA 2 | 3 | #include "fo_pool_op_gpu.h" 4 | #include "fo_pool_op_kernel.h" 5 | 6 | TF_QRNN_NAMESPACE_BEGIN 7 | TF_QRNN_FO_POOL_NAMESPACE_BEGIN 8 | 9 | // Register a GPU kernel for FoPool 10 | 11 | /* TIME MAJOR */ 12 | 13 | REGISTER_KERNEL_BUILDER( 14 | Name("TimeMajorFoPool") 15 | .TypeConstraint("FT") 16 | .Device(tensorflow::DEVICE_GPU), 17 | FoPool); 18 | 19 | REGISTER_KERNEL_BUILDER( 20 | Name("TimeMajorBwdFoPool") 21 | .TypeConstraint("FT") 22 | .Device(tensorflow::DEVICE_GPU), 23 | BwdFoPool); 24 | 25 | REGISTER_KERNEL_BUILDER( 26 | Name("TimeMajorFoPool") 27 | .TypeConstraint("FT") 28 | .Device(tensorflow::DEVICE_GPU), 29 | FoPool); 30 | 31 | REGISTER_KERNEL_BUILDER( 32 | Name("TimeMajorBwdFoPool") 33 | .TypeConstraint("FT") 34 | .Device(tensorflow::DEVICE_GPU), 35 | BwdFoPool); 36 | 37 | /* BATCH MAJOR */ 38 | 39 | REGISTER_KERNEL_BUILDER( 40 | Name("BatchMajorFoPool") 41 | .TypeConstraint("FT") 42 | .Device(tensorflow::DEVICE_GPU), 43 | FoPool); 44 | 45 | REGISTER_KERNEL_BUILDER( 46 | Name("BatchMajorBwdFoPool") 47 | .TypeConstraint("FT") 48 | .Device(tensorflow::DEVICE_GPU), 49 | 
BwdFoPool); 50 | 51 | REGISTER_KERNEL_BUILDER( 52 | Name("BatchMajorFoPool") 53 | .TypeConstraint("FT") 54 | .Device(tensorflow::DEVICE_GPU), 55 | FoPool); 56 | 57 | REGISTER_KERNEL_BUILDER( 58 | Name("BatchMajorBwdFoPool") 59 | .TypeConstraint("FT") 60 | .Device(tensorflow::DEVICE_GPU), 61 | BwdFoPool); 62 | 63 | 64 | TF_QRNN_FO_POOL_NAMESPACE_STOP 65 | TF_QRNN_NAMESPACE_STOP 66 | 67 | #endif // #if GOOGLE_CUDA 68 | -------------------------------------------------------------------------------- /src/fo_pool_op_gpu.h: -------------------------------------------------------------------------------- 1 | #if GOOGLE_CUDA 2 | 3 | #ifndef QRNN_FO_POOL_OP_GPU_CUH 4 | #define QRNN_FO_POOL_OP_GPU_CUH 5 | 6 | #include "fo_pool_op.h" 7 | #include "fo_pool_op_kernel.h" 8 | 9 | // Required in order for Eigen::GpuDevice to be an actual type 10 | #define EIGEN_USE_GPU 11 | 12 | #include "tensorflow/core/framework/op.h" 13 | #include "tensorflow/core/framework/op_kernel.h" 14 | 15 | TF_QRNN_NAMESPACE_BEGIN 16 | TF_QRNN_FO_POOL_NAMESPACE_BEGIN 17 | 18 | // For simpler partial specialisation 19 | typedef Eigen::GpuDevice GPUDevice; 20 | 21 | /* TIME MAJOR */ 22 | 23 | // Specialise the FoPool op for GPUs 24 | template 25 | class FoPool : public tensorflow::OpKernel { 26 | public: 27 | explicit FoPool(tensorflow::OpKernelConstruction * context) : 28 | tensorflow::OpKernel(context) {} 29 | 30 | void Compute(tensorflow::OpKernelContext * context) override { 31 | namespace tf = tensorflow; 32 | 33 | // Create variables for input tensors 34 | const auto & in_x = context->input(0); 35 | const auto & in_forget = context->input(1); 36 | const auto & in_hinit = context->input(2); 37 | 38 | // Allocate output tensors 39 | // Allocate space for output tensor 'output' 40 | tf::Tensor * output_ptr = nullptr; 41 | auto in_x_shape = in_x.shape(); 42 | tf::TensorShape output_shape = in_x_shape; 43 | if (time_major) { 44 | output_shape.set_dim(0, output_shape.dim_size(0) + 1); 45 | } else { 46 | output_shape.set_dim(1, output_shape.dim_size(1) + 1); 47 | } 48 | OP_REQUIRES_OK(context, context->allocate_output( 49 | 0, output_shape, &output_ptr)); 50 | 51 | // Get pointers to flattened tensor data buffers 52 | const auto fin_x = in_x.flat().data(); 53 | const auto fin_forget = in_forget.flat().data(); 54 | const auto fin_hinit = in_hinit.flat().data(); 55 | auto fout_output = output_ptr->flat().data(); 56 | 57 | 58 | // Get the GPU device 59 | const auto & device = context->eigen_device(); 60 | 61 | // Call the qrnn_fo_pool CUDA kernel 62 | if (time_major) { 63 | TimeMajorFoPoolLauncher(fout_output, fin_x, fin_forget, fin_hinit, 64 | in_x_shape.dim_size(0), 65 | output_shape.dim_size(1), 66 | output_shape.dim_size(2), 67 | device.stream()); 68 | } else { 69 | BatchMajorFoPoolLauncher(fout_output, fin_x, fin_forget, fin_hinit, 70 | in_x_shape.dim_size(1), 71 | output_shape.dim_size(0), 72 | output_shape.dim_size(2), 73 | device.stream()); 74 | } 75 | } 76 | }; 77 | 78 | 79 | template 80 | class BwdFoPool : public tensorflow::OpKernel { 81 | public: 82 | explicit BwdFoPool(tensorflow::OpKernelConstruction * context) : 83 | tensorflow::OpKernel(context) {} 84 | 85 | void Compute(tensorflow::OpKernelContext * context) override { 86 | namespace tf = tensorflow; 87 | 88 | const auto& in_h = context->input(0); 89 | const auto& in_x = context->input(1); 90 | const auto& in_forget = context->input(2); 91 | const auto& in_gh = context->input(3); 92 | 93 | // Extract Eigen tensors 94 | auto h = in_h.flat().data(); 95 | auto x = 
in_x.flat().data(); 96 | auto forget = in_forget.flat().data(); 97 | auto gh = in_gh.flat().data(); 98 | 99 | // Allocate output tensors 100 | // Allocate space for output tensor 'output' 101 | tf::Tensor * out_gf = nullptr; 102 | tf::Tensor * out_gx = nullptr; 103 | tf::Tensor * out_ginitial_state = nullptr; 104 | 105 | auto in_x_shape = in_x.shape(); 106 | tf::TensorShape grad_shape = in_x_shape; 107 | int batch_size = time_major ? in_x_shape.dim_size(1) : in_x_shape.dim_size(0); 108 | tf::TensorShape ginitial_state_shape({batch_size, 109 | in_x_shape.dim_size(2)}); 110 | 111 | OP_REQUIRES_OK(context, context->allocate_output( 112 | 0, grad_shape, &out_gx)); 113 | OP_REQUIRES_OK(context, context->allocate_output( 114 | 1, grad_shape, &out_gf)); 115 | OP_REQUIRES_OK(context, context->allocate_output( 116 | 2, ginitial_state_shape, &out_ginitial_state)); 117 | auto gf = out_gf->flat().data(); 118 | auto gx = out_gx->flat().data(); 119 | auto ginitial_state = out_ginitial_state->flat().data(); 120 | 121 | // Get the GPU device 122 | const auto & device = context->eigen_device(); 123 | 124 | if (time_major) { 125 | TimeMajorBwdFoPoolLauncher(h, 126 | x, 127 | forget, 128 | gh, 129 | gx, 130 | gf, 131 | ginitial_state, 132 | grad_shape.dim_size(0), 133 | grad_shape.dim_size(1), 134 | grad_shape.dim_size(2), 135 | device.stream()); 136 | } else { 137 | BatchMajorBwdFoPoolLauncher(h, 138 | x, 139 | forget, 140 | gh, 141 | gx, 142 | gf, 143 | ginitial_state, 144 | grad_shape.dim_size(1), 145 | grad_shape.dim_size(0), 146 | grad_shape.dim_size(2), 147 | device.stream()); 148 | } 149 | } 150 | }; 151 | 152 | TF_QRNN_FO_POOL_NAMESPACE_STOP 153 | TF_QRNN_NAMESPACE_STOP 154 | 155 | #endif // #ifndef QRNN_FO_POOL_OP_GPU_CUH 156 | 157 | #endif // #if GOOGLE_CUDA 158 | -------------------------------------------------------------------------------- /src/fo_pool_op_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "fo_pool_op_kernel.h" 2 | 3 | struct KernelParams { 4 | dim3 grid; 5 | dim3 blocks; 6 | KernelParams(int HIDDEN, int batch_size) : 7 | grid(std::ceil(double(HIDDEN / double(min(HIDDEN, 512)))), batch_size, 1), 8 | blocks(min(HIDDEN, 512), 1, 1) {}; 9 | }; 10 | 11 | 12 | /* TIME MAJOR */ 13 | 14 | template 15 | __global__ 16 | void time_major_fo_pool(FT *dst, const FT *x, const FT *f, const FT *initial_state, int SEQ, int batch_size, int HIDDEN) { 17 | /* 18 | Note: destination is assumed to be one timestep longer than f or x where dst[0] = h_{-1} 19 | This means dst array has a separate index than that of f or x 20 | */ 21 | int hid = blockIdx.x * blockDim.x + threadIdx.x; 22 | int batch_id = blockIdx.y * blockDim.y + threadIdx.y; 23 | if (hid >= HIDDEN || batch_id >= batch_size) 24 | return; 25 | // 26 | dst[batch_id * HIDDEN + hid] = initial_state[batch_id * HIDDEN + hid]; 27 | for (int ts = 0 + 1; ts < SEQ + 1; ts++) { 28 | // Good sanity check for debugging - only perform additions to a zeroed chunk of memory 29 | // Addition seems atomic or near atomic - you should get incorrect answers if doubling up via threads 30 | // Note: the index i needs to be offset by one as f[0] (f_t) is used for dst[1] (h_t) etc 31 | // To move timesteps, we step HIDDEN * batch_size 32 | // To move batches, we move HIDDEN 33 | // To move neurons, we move +- 1 34 | // Note: dst[dst_i] = ts * 100 + batch_id * 10 + hid; is useful for debugging 35 | int i = (ts - 1) * HIDDEN * batch_size + batch_id * HIDDEN + hid; 36 | int dst_i = (ts - 0) * HIDDEN * batch_size 
+ batch_id * HIDDEN + hid; 37 | int dst_iminus1 = (ts - 1) * HIDDEN * batch_size + batch_id * HIDDEN + hid; 38 | dst[dst_i] = f[i] * x[i]; 39 | dst[dst_i] += (1 - f[i]) * dst[dst_iminus1]; 40 | } 41 | } 42 | 43 | template 44 | __global__ 45 | void time_major_bwd_fo_pool(const FT *h, const FT *x, const FT *f, const FT *gh, FT *gx, FT *gf, FT *ginitial_state, int SEQ, int batch_size, int HIDDEN) { 46 | /* 47 | Note: h is assumed to be one timestep longer than f, x, gf, gx, or gh where dst[0] = h_{-1} 48 | This means dst array has a separate index than that of f or x 49 | */ 50 | int hid = blockIdx.x * blockDim.x + threadIdx.x; 51 | int batch_id = blockIdx.y * blockDim.y + threadIdx.y; 52 | if (hid >= HIDDEN || batch_id >= batch_size) 53 | return; 54 | // 55 | double running_f = 0; 56 | for (int ts = SEQ - 1 + 1; ts >= 0 + 1; ts--) { 57 | int i = (ts - 1) * HIDDEN * batch_size + batch_id * HIDDEN + hid; 58 | int dst_iminus1 = (ts - 1) * HIDDEN * batch_size + batch_id * HIDDEN + hid; 59 | // 60 | running_f += gh[dst_iminus1]; 61 | // Gradient of X 62 | gx[i] = f[i] * running_f; 63 | // Gradient of F 64 | gf[i] = (x[i] - h[dst_iminus1]) * running_f; 65 | // 66 | // The line below is likely more numerically stable than (1 - f[i]) * running_f; 67 | running_f = running_f - f[i] * running_f; 68 | } 69 | ginitial_state[batch_id * HIDDEN + hid] = running_f + gh[batch_id * HIDDEN + hid]; 70 | } 71 | 72 | void TimeMajorFoPoolLauncher(float *dst, const float *x, const float *f, const float *initial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream) { 73 | KernelParams l(HIDDEN, batch_size); 74 | time_major_fo_pool<<>>(dst, x, f, initial_state, SEQ, batch_size, HIDDEN); 75 | } 76 | void TimeMajorFoPoolLauncher(double *dst, const double *x, const double *f, const double *initial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream) { 77 | KernelParams l(HIDDEN, batch_size); 78 | time_major_fo_pool<<>>(dst, x, f, initial_state, SEQ, batch_size, HIDDEN); 79 | } 80 | void TimeMajorBwdFoPoolLauncher(const float *h, const float *x, const float *f, const float *gh, float *gx, float *gf, float *ginitial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream) { 81 | KernelParams l(HIDDEN, batch_size); 82 | time_major_bwd_fo_pool<<>>(h, x, f, gh, gx, gf, ginitial_state, SEQ, batch_size, HIDDEN); 83 | } 84 | void TimeMajorBwdFoPoolLauncher(const double *h, const double *x, const double *f, const double *gh, double *gx, double *gf, double *ginitial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream) { 85 | KernelParams l(HIDDEN, batch_size); 86 | time_major_bwd_fo_pool<<>>(h, x, f, gh, gx, gf, ginitial_state, SEQ, batch_size, HIDDEN); 87 | } 88 | 89 | 90 | /* BATCH MAJOR */ 91 | 92 | template 93 | __global__ 94 | void batch_major_fo_pool(FT *dst, const FT *x, const FT *f, const FT *initial_state, int SEQ, int batch_size, int HIDDEN) { 95 | /* 96 | Note: destination is assumed to be one timestep longer than f or x where dst[0] = h_{-1} 97 | This means dst array has a separate index than that of f or x 98 | */ 99 | int hid = blockIdx.x * blockDim.x + threadIdx.x; 100 | int batch_id = blockIdx.y * blockDim.y + threadIdx.y; 101 | if (hid >= HIDDEN || batch_id >= batch_size) 102 | return; 103 | // 104 | dst[batch_id * HIDDEN * (SEQ + 1) + hid] = initial_state[batch_id * HIDDEN + hid]; 105 | for (int ts = 0 + 1; ts < SEQ + 1; ts++) { 106 | // Good sanity check for debugging - only perform additions to a zeroed chunk of memory 107 | // Addition seems atomic or near atomic - 
you should get incorrect answers if doubling up via threads 108 | // Note: the index i needs to be offset by one as f[0] (f_t) is used for dst[1] (h_t) etc 109 | // To move timesteps, we step HIDDEN * batch_size 110 | // To move batches, we move HIDDEN 111 | // To move neurons, we move +- 1 112 | // Note: dst[dst_i] = ts * 100 + batch_id * 10 + hid; is useful for debugging 113 | int i = (ts - 1) * HIDDEN + batch_id * HIDDEN * SEQ + hid; 114 | int dst_i = (ts - 0) * HIDDEN + batch_id * HIDDEN * (SEQ + 1) + hid; 115 | int dst_iminus1 = (ts - 1) * HIDDEN + batch_id * HIDDEN * (SEQ + 1) + hid; 116 | dst[dst_i] = f[i] * x[i]; 117 | dst[dst_i] += (1 - f[i]) * dst[dst_iminus1]; 118 | } 119 | } 120 | 121 | template 122 | __global__ 123 | void batch_major_bwd_fo_pool(const FT *h, const FT *x, const FT *f, const FT *gh, FT *gx, FT *gf, FT *ginitial_state, int SEQ, int batch_size, int HIDDEN) { 124 | /* 125 | Note: h is assumed to be one timestep longer than f, x, gf, gx, or gh where dst[0] = h_{-1} 126 | This means dst array has a separate index than that of f or x 127 | */ 128 | int hid = blockIdx.x * blockDim.x + threadIdx.x; 129 | int batch_id = blockIdx.y * blockDim.y + threadIdx.y; 130 | if (hid >= HIDDEN || batch_id >= batch_size) 131 | return; 132 | // 133 | double running_f = 0; 134 | for (int ts = SEQ - 1 + 1; ts >= 0 + 1; ts--) { 135 | int i = (ts - 1) * HIDDEN + batch_id * HIDDEN * SEQ + hid; 136 | int dst_iminus1 = (ts - 1) * HIDDEN + batch_id * HIDDEN * (SEQ + 1) + hid; 137 | // 138 | running_f += gh[dst_iminus1]; 139 | // Gradient of X 140 | gx[i] = f[i] * running_f; 141 | // Gradient of F 142 | gf[i] = (x[i] - h[dst_iminus1]) * running_f; 143 | // 144 | // The line below is likely more numerically stable than (1 - f[i]) * running_f; 145 | running_f = running_f - f[i] * running_f; 146 | } 147 | ginitial_state[batch_id * HIDDEN + hid] = running_f + gh[batch_id * HIDDEN * (SEQ + 1) + hid]; 148 | } 149 | 150 | 151 | void BatchMajorFoPoolLauncher(float *dst, const float *x, const float *f, const float *initial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream) { 152 | KernelParams l(HIDDEN, batch_size); 153 | batch_major_fo_pool<<>>(dst, x, f, initial_state, SEQ, batch_size, HIDDEN); 154 | } 155 | void BatchMajorFoPoolLauncher(double *dst, const double *x, const double *f, const double *initial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream) { 156 | KernelParams l(HIDDEN, batch_size); 157 | batch_major_fo_pool<<>>(dst, x, f, initial_state, SEQ, batch_size, HIDDEN); 158 | } 159 | void BatchMajorBwdFoPoolLauncher(const float *h, const float *x, const float *f, const float *gh, float *gx, float *gf, float *ginitial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream) { 160 | KernelParams l(HIDDEN, batch_size); 161 | batch_major_bwd_fo_pool<<>>(h, x, f, gh, gx, gf, ginitial_state, SEQ, batch_size, HIDDEN); 162 | } 163 | void BatchMajorBwdFoPoolLauncher(const double *h, const double *x, const double *f, const double *gh, double *gx, double *gf, double *ginitial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream) { 164 | KernelParams l(HIDDEN, batch_size); 165 | batch_major_bwd_fo_pool<<>>(h, x, f, gh, gx, gf, ginitial_state, SEQ, batch_size, HIDDEN); 166 | } 167 | -------------------------------------------------------------------------------- /src/fo_pool_op_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef RECURRENT_FORGET_MULT_OP_KERNEL_H 2 | #define RECURRENT_FORGET_MULT_OP_KERNEL_H 
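// Declarations of the CUDA kernel launchers implemented in
// fo_pool_op_kernel.cu. Forward launchers fill 'dst', which is one timestep
// longer than 'x' and 'f' (its first timestep holds the initial state);
// backward launchers take the saved outputs 'h' and upstream gradients 'gh'
// and produce 'gx', 'gf' and 'ginitial_state'. SEQ, batch_size and HIDDEN
// describe the un-padded input, and every launch is enqueued on the supplied
// CUDA stream.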
3 | 4 | #include 5 | 6 | /* TIME MAJOR */ 7 | 8 | void TimeMajorFoPoolLauncher(float *dst, const float *x, const float *f, const float *initial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream); 9 | void TimeMajorBwdFoPoolLauncher(const float *h, const float *x, const float *f, const float *gh, float *gx, float *gf, float *ginitial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream); 10 | 11 | void TimeMajorFoPoolLauncher(double *dst, const double *x, const double *f, const double *initial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream); 12 | void TimeMajorBwdFoPoolLauncher(const double *h, const double *x, const double *f, const double *gh, double *gx, double *gf, double *ginitial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream); 13 | 14 | /* BATCH MAJOR */ 15 | 16 | void BatchMajorFoPoolLauncher(float *dst, const float *x, const float *f, const float *initial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream); 17 | void BatchMajorBwdFoPoolLauncher(const float *h, const float *x, const float *f, const float *gh, float *gx, float *gf, float *ginitial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream); 18 | 19 | void BatchMajorFoPoolLauncher(double *dst, const double *x, const double *f, const double *initial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream); 20 | void BatchMajorBwdFoPoolLauncher(const double *h, const double *x, const double *f, const double *gh, double *gx, double *gf, double *ginitial_state, int SEQ, int batch_size, int HIDDEN, cudaStream_t stream); 21 | 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /src/third_party/nsync.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_H_ 16 | #define NSYNC_PUBLIC_NSYNC_H_ 17 | 18 | #include "nsync_mu.h" 19 | #include "nsync_mu_wait.h" 20 | #include "nsync_cv.h" 21 | #include "nsync_note.h" 22 | #include "nsync_counter.h" 23 | #include "nsync_waiter.h" 24 | #include "nsync_once.h" 25 | #include "nsync_debug.h" 26 | 27 | #endif /*NSYNC_PUBLIC_NSYNC_H_*/ 28 | -------------------------------------------------------------------------------- /src/third_party/nsync_atomic.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_ATOMIC_H_ 16 | #define NSYNC_PUBLIC_NSYNC_ATOMIC_H_ 17 | 18 | #include "nsync_cpp.h" 19 | 20 | /* This file is not to be included directly by the client. It exists because 21 | on some platforms, one cannot use a simple uint32_t with atomic operations. 22 | */ 23 | #if NSYNC_ATOMIC_TYPECHECK 24 | #include 25 | NSYNC_CPP_START_ 26 | typedef struct { uint32_t value; } nsync_atomic_uint32_; 27 | NSYNC_CPP_END_ 28 | #define NSYNC_ATOMIC_UINT32_INIT_ { 0 } 29 | #define NSYNC_ATOMIC_UINT32_LOAD_(p) ((p)->value) 30 | #define NSYNC_ATOMIC_UINT32_STORE_(p,v) ((p)->value = (v)) 31 | #define NSYNC_ATOMIC_UINT32_PTR_(p) (&(p)->value) 32 | 33 | #elif NSYNC_ATOMIC_C11 34 | #include 35 | NSYNC_CPP_START_ 36 | typedef atomic_uint_least32_t nsync_atomic_uint32_; 37 | NSYNC_CPP_END_ 38 | #define NSYNC_ATOMIC_UINT32_INIT_ 0 39 | #define NSYNC_ATOMIC_UINT32_LOAD_(p) (*(p)) 40 | #define NSYNC_ATOMIC_UINT32_STORE_(p,v) (*(p) = (v)) 41 | #define NSYNC_ATOMIC_UINT32_PTR_(p) (p) 42 | 43 | #elif NSYNC_ATOMIC_CPP11 44 | #include 45 | NSYNC_CPP_START_ 46 | typedef std::atomic nsync_atomic_uint32_; 47 | NSYNC_CPP_END_ 48 | #define NSYNC_ATOMIC_UINT32_INIT_ ATOMIC_VAR_INIT (0) 49 | #define NSYNC_ATOMIC_UINT32_LOAD_(p) (std::atomic_load (p)) 50 | #define NSYNC_ATOMIC_UINT32_STORE_(p,v) (std::atomic_store ((p), (uint32_t) (v))) 51 | #define NSYNC_ATOMIC_UINT32_PTR_(p) (p) 52 | 53 | #else 54 | #include 55 | NSYNC_CPP_START_ 56 | typedef uint32_t nsync_atomic_uint32_; 57 | NSYNC_CPP_END_ 58 | #define NSYNC_ATOMIC_UINT32_INIT_ 0 59 | #define NSYNC_ATOMIC_UINT32_LOAD_(p) (*(p)) 60 | #define NSYNC_ATOMIC_UINT32_STORE_(p,v) (*(p) = (v)) 61 | #define NSYNC_ATOMIC_UINT32_PTR_(p) (p) 62 | #endif 63 | 64 | #endif /*NSYNC_PUBLIC_NSYNC_ATOMIC_H_*/ 65 | -------------------------------------------------------------------------------- /src/third_party/nsync_counter.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_COUNTER_H_ 16 | #define NSYNC_PUBLIC_NSYNC_COUNTER_H_ 17 | 18 | #include 19 | #include "nsync_cpp.h" 20 | #include "nsync_mu.h" 21 | #include "nsync_atomic.h" 22 | #include "nsync_time.h" 23 | 24 | NSYNC_CPP_START_ 25 | 26 | struct nsync_dll_element_s_; 27 | 28 | /* An nsync_counter represents an unsigned integer that can count up and down, 29 | and wake waiters when zero. */ 30 | typedef struct nsync_counter_s_ *nsync_counter; 31 | 32 | /* Return a freshly allocated nsync_counter with the specified value, 33 | of NULL if an nsync_counter cannot be created. 34 | 35 | Any non-NULL returned value should be passed to nsync_counter_free() when no 36 | longer needed. */ 37 | nsync_counter nsync_counter_new (uint32_t value); 38 | 39 | /* Free resources associated with c. 
Requires that c was allocated by 40 | nsync_counter_new(), and no concurrent or future operations are applied to 41 | c. */ 42 | void nsync_counter_free (nsync_counter c); 43 | 44 | /* Add delta to c, and return its new value. It is a checkable runtime error 45 | to decrement c below 0, or to increment c (i.e., apply a delta > 0) after a 46 | waiter has waited. */ 47 | uint32_t nsync_counter_add (nsync_counter c, int32_t delta); 48 | 49 | /* Return the current value of c. */ 50 | uint32_t nsync_counter_value (nsync_counter c); 51 | 52 | /* Wait until c has value 0, or until abs_deadline, then return 53 | the value of c. It is a checkable runtime error to increment c after 54 | a waiter may have been woken due to the counter reaching zero. 55 | If abs_deadline==nsync_time_no_deadline, the deadline 56 | is far in the future. */ 57 | uint32_t nsync_counter_wait (nsync_counter c, nsync_time abs_deadline); 58 | 59 | NSYNC_COUNTER_CPP_OVERLOAD_ 60 | NSYNC_CPP_END_ 61 | 62 | #endif /*NSYNC_PUBLIC_NSYNC_COUNTER_H_*/ 63 | -------------------------------------------------------------------------------- /src/third_party/nsync_cpp.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_CPP_H_ 16 | #define NSYNC_PUBLIC_NSYNC_CPP_H_ 17 | 18 | /* This header file permits compilation via a C++ compiler using the macros 19 | NSYNC_CPP_START_, NSYNC_CPP_END_, and NSYNC_CPP_USING_. 20 | 21 | NSYNC_CPP_START_ and NSYNC_CPP_END_ surround C code in the public library. 22 | They put all public symbols into the "nsync" name space. 23 | 24 | NSYNC_CPP_USING_ is used before C code (used for testing) that might use 25 | public exports from this package. It makes symbols in the "nsync" 26 | name space available without the "nsync::" prefix. 27 | 28 | NSYNC_C_START_ and NSYNC_C_END_ surround C code in the C++ modules. 29 | */ 30 | 31 | #if defined(__cplusplus) 32 | #define NSYNC_CPP_START_ namespace nsync { 33 | #define NSYNC_CPP_END_ } 34 | #define NSYNC_CPP_USING_ using namespace nsync; 35 | #define NSYNC_C_START_ extern "C" { 36 | #define NSYNC_C_END_ } 37 | #else 38 | #define NSYNC_CPP_START_ 39 | #define NSYNC_CPP_END_ 40 | #define NSYNC_CPP_USING_ 41 | #define NSYNC_C_START_ 42 | #define NSYNC_C_END_ 43 | #endif 44 | 45 | #endif /*NSYNC_PUBLIC_NSYNC_CPP_H_*/ 46 | -------------------------------------------------------------------------------- /src/third_party/nsync_cv.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_CV_H_ 16 | #define NSYNC_PUBLIC_NSYNC_CV_H_ 17 | 18 | #include 19 | #include "nsync_cpp.h" 20 | #include "nsync_mu.h" 21 | #include "nsync_atomic.h" 22 | #include "nsync_time.h" 23 | 24 | NSYNC_CPP_START_ 25 | 26 | struct nsync_dll_element_s_; 27 | struct nsync_note_s_; 28 | 29 | /* An nsync_cv is a condition variable in the style of Mesa, Java, POSIX, and Go's sync.Cond. 30 | It allows a thread to wait for a condition on state protected by a mutex, 31 | and to proceed with the mutex held and the condition true. 32 | 33 | See also nsync_mu_wait() and nsync_mu_wait_with_deadline(), which implement conditional 34 | critical sections. In many cases, they are easier to use than condition 35 | variables. 36 | 37 | Usage: 38 | 39 | after making the desired predicate true, call: 40 | nsync_cv_signal (&cv); // If at most one thread can make use of the predicate becoming true. 41 | or 42 | nsync_cv_broadcast (&cv); // If multiple threads can make use of the predicate becoming true. 43 | 44 | To wait for a predicate with no deadline (assuming nsync_cv_broadcast() or 45 | nsync_cv_signal() is called whenever the predicate becomes true): 46 | nsync_mu_lock (μ) 47 | while (!some_predicate_protected_by_mu) { // the while-loop is required. 48 | nsync_cv_wait (&cv, &mu); 49 | } 50 | // predicate is now true 51 | nsync_mu_unlock (&mu); 52 | 53 | To wait for a predicate with a deadline (assuming nsync_cv_broadcast() or 54 | nsync_cv_signal() is called whenever the predicate becomes true): 55 | nsync_mu_lock (&mu); 56 | while (!some_predicate_protected_by_mu && 57 | nsync_cv_wait_with_deadline (&cv, &mu, abs_deadline, cancel_note) == 0) { 58 | } 59 | if (some_predicate_protected_by_mu) { // predicate is true 60 | } else { // predicate is false, and deadline expired, or cancel_note was notified. 61 | } 62 | nsync_mu_unlock (&mu); 63 | or, if the predicate is complex and you wish to write it just once and 64 | inline, you could use the following instead of the for-loop above: 65 | nsync_mu_lock (&mu); 66 | int pred_is_true = 0; 67 | int outcome = 0; 68 | while (!(pred_is_true = some_predicate_protected_by_mu) && outcome == 0) { 69 | outcome = nsync_cv_wait_with_deadline (&cv, &mu, abs_deadline, cancel_note); 70 | } 71 | if (pred_is_true) { // predicate is true 72 | } else { // predicate is false, and deadline expired, or cancel_note was notified. 73 | } 74 | nsync_mu_unlock (&mu); 75 | 76 | As the examples show, Mesa-style condition variables require that waits use 77 | a loop that tests the predicate anew after each wait. It may be surprising 78 | that these are preferred over the precise wakeups offered by the condition 79 | variables in Hoare monitors. Imprecise wakeups make more efficient use of 80 | the critical section, because threads can enter it while a woken thread is 81 | still emerging from the scheduler, which may take thousands of cycles. 
82 | Further, they make the programme easier to read and debug by making the 83 | predicate explicit locally at the wait, where the predicate is about to be 84 | assumed; the reader does not have to infer the predicate by examining all 85 | the places where wakeups may occur. */ 86 | typedef struct nsync_cv_s_ { 87 | nsync_atomic_uint32_ word; /* see bits below */ 88 | struct nsync_dll_element_s_ *waiters; /* points to tail of list of waiters; under mu. */ 89 | } nsync_cv; 90 | 91 | /* An nsync_cv should be zeroed to initialize, which can be accomplished by 92 | initializing with static initializer NSYNC_CV_INIT, or by setting the entire 93 | struct to 0, or using nsync_cv_init(). */ 94 | #define NSYNC_CV_INIT { NSYNC_ATOMIC_UINT32_INIT_, 0 } 95 | void nsync_cv_init (nsync_cv *cv); 96 | 97 | /* Wake at least one thread if any are currently blocked on *cv. If 98 | the chosen thread is a reader on an nsync_mu, wake all readers and, if 99 | possible, a writer. */ 100 | void nsync_cv_signal (nsync_cv *cv); 101 | 102 | /* Wake all threads currently blocked on *cv. */ 103 | void nsync_cv_broadcast (nsync_cv *cv); 104 | 105 | /* Atomically release "mu" (which must be held on entry) and block the caller 106 | on *cv. Wait until awakened by a call to nsync_cv_signal() or 107 | nsync_cv_broadcast(), or a spurious wakeup; then reacquire "mu", and return. 108 | Equivalent to a call to nsync_mu_wait_with_deadline() with 109 | abs_deadline==nsync_time_no_deadline, and cancel_note==NULL. Callers should use 110 | nsync_cv_wait() in a loop, as with all standard Mesa-style condition 111 | variables. See examples above. */ 112 | void nsync_cv_wait (nsync_cv *cv, nsync_mu *mu); 113 | 114 | /* Atomically release "mu" (which must be held on entry) 115 | and block the calling thread on *cv. It then waits until awakened by a 116 | call to nsync_cv_signal() or nsync_cv_broadcast() (or a spurious wakeup), or by the time 117 | reaching abs_deadline, or by cancel_note being notified. In all cases, it 118 | reacquires "mu", and returns the reason for the call returned (0, ETIMEDOUT, 119 | or ECANCELED). Use abs_deadline==nsync_time_no_deadline for no deadline, and 120 | cancel_note==NULL for no cancellation. wait_with_deadline() should be used in a 121 | loop, as with all Mesa-style condition variables. See examples above. 122 | 123 | There are two reasons for using an absolute deadline, rather than a relative 124 | timeout---these are why pthread_cond_timedwait() also uses an absolute 125 | deadline. First, condition variable waits have to be used in a loop; with 126 | an absolute times, the deadline does not have to be recomputed on each 127 | iteration. Second, in most real programmes, some activity (such as an RPC 128 | to a server, or when guaranteeing response time in a UI), there is a 129 | deadline imposed by the specification or the caller/user; relative delays 130 | can shift arbitrarily with scheduling delays, and so after multiple waits 131 | might extend beyond the expected deadline. Relative delays tend to be more 132 | convenient mostly in tests and trivial examples than they are in real 133 | programmes. */ 134 | int nsync_cv_wait_with_deadline (nsync_cv *cv, nsync_mu *mu, 135 | nsync_time abs_deadline, 136 | struct nsync_note_s_ *cancel_note); 137 | 138 | /* Like nsync_cv_wait_with_deadline(), but allow an arbitrary lock *v to be used, 139 | given its (*lock)(mu) and (*unlock)(mu) routines. 
*/ 140 | int nsync_cv_wait_with_deadline_generic (nsync_cv *cv, 141 | void *mu, void (*lock) (void *), void (*unlock) (void *), 142 | nsync_time abs_deadline, 143 | struct nsync_note_s_ *cancel_note); 144 | 145 | NSYNC_CV_CPP_OVERLOAD_ 146 | NSYNC_CPP_END_ 147 | 148 | #endif /*NSYNC_PUBLIC_NSYNC_CV_H_*/ 149 | -------------------------------------------------------------------------------- /src/third_party/nsync_debug.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_DEBUG_H_ 16 | #define NSYNC_PUBLIC_NSYNC_DEBUG_H_ 17 | 18 | /* Debugging operations for mutexes and condition variables. 19 | 20 | These operations should not be relied upon for normal functionality. The 21 | implementation may be slow, output formats may change, and the 22 | implementation is free to yield the empty string. */ 23 | 24 | #include "nsync_cpp.h" 25 | #include "nsync_mu.h" 26 | #include "nsync_cv.h" 27 | 28 | NSYNC_CPP_START_ 29 | 30 | /* Place in buf[0,..,n-1] a nul-terminated, human readable string indicative of 31 | some of the internal state of the mutex or condition variable, and return 32 | buf. If n>=4, buffer overflow is indicated by placing the characters "..." 33 | at the end of the string. 34 | 35 | The *_and_waiters() variants attempt to output the waiter lists in addition 36 | to the basic state. These variants may acquire internal locks and follow 37 | internal pointers. Thus, they are riskier if invoked in an address space 38 | whose overall health is uncertain. */ 39 | char *nsync_mu_debug_state (nsync_mu *mu, char *buf, int n); 40 | char *nsync_cv_debug_state (nsync_cv *cv, char *buf, int n); 41 | char *nsync_mu_debug_state_and_waiters (nsync_mu *mu, char *buf, int n); 42 | char *nsync_cv_debug_state_and_waiters (nsync_cv *cv, char *buf, int n); 43 | 44 | /* Like nsync_*_debug_state_and_waiters(), but ignoring all locking and safety 45 | considerations, and using an internal, possibly static buffer that may be 46 | overwritten by subsequent or concurrent calls to these routines. These 47 | variants should be used only from an interactive debugger, when all other 48 | threads are stopped; the debugger is expected to recover from errors. */ 49 | char *nsync_mu_debugger (nsync_mu *mu); 50 | char *nsync_cv_debugger (nsync_cv *cv); 51 | 52 | NSYNC_CPP_END_ 53 | 54 | #endif /*NSYNC_PUBLIC_NSYNC_DEBUG_H_*/ 55 | -------------------------------------------------------------------------------- /src/third_party/nsync_mu.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_MU_H_ 16 | #define NSYNC_PUBLIC_NSYNC_MU_H_ 17 | 18 | #include 19 | #include "nsync_cpp.h" 20 | #include "nsync_atomic.h" 21 | 22 | NSYNC_CPP_START_ 23 | 24 | struct nsync_dll_element_s_; 25 | 26 | /* An nsync_mu is a lock. If initialized to all zeroes, it is valid and unlocked. 27 | 28 | An nsync_mu can be "free", held by a single thread (aka fiber, goroutine) in 29 | "write" (exclusive) mode, or by many threads in "read" (shared) mode. A 30 | thread that acquires it should eventually release it. It is illegal to 31 | acquire an nsync_mu in one thread and release it in another. It is 32 | illegal for a thread to reacquire an nsync_mu while holding it (even a 33 | second share of a "read" lock). 34 | 35 | Example usage: 36 | static struct foo { 37 | nsync_mu mu; // protects invariant a+b==0 on fields below. 38 | int a; 39 | int b; 40 | } p = { NSYNC_MU_INIT, 0, 0 }; 41 | .... 42 | nsync_mu_lock (&p.mu); 43 | // The current thread now has exclusive access to p.a and p.b; invariant assumed true. 44 | p.a++; 45 | p.b--; // restore invariant p.a+p.b==0 before releasing p.mu 46 | nsync_mu_unlock (&p.mu) 47 | 48 | Mutexes can be used with condition variables; see nsync_cv.h. 49 | 50 | nsync_mu_wait() and nsync_mu_wait_with_deadline() can be used instead of 51 | condition variables. See nsync_mu_wait.h for more details. 52 | Example use of nsync_mu_wait() to wait for p.a==0, using definition above: 53 | int a_is_zero (const void *condition_arg) { 54 | return (((const struct foo *)condition_arg)->a == 0); 55 | } 56 | ... 57 | nsync_mu_lock (&p.mu); 58 | nsync_mu_wait (&p.mu, &a_is_zero, &p, NULL); 59 | // The current thread now has exclusive access to p.a and p.b, and p.a==0. 60 | ... 61 | nsync_mu_unlock (&p.mu); */ 62 | typedef struct nsync_mu_s_ { 63 | nsync_atomic_uint32_ word; /* internal use only */ 64 | struct nsync_dll_element_s_ *waiters; /* internal use only */ 65 | } nsync_mu; 66 | 67 | /* An nsync_mu should be zeroed to initialize, which can be accomplished by 68 | initializing with static initializer NSYNC_MU_INIT, or by setting the entire 69 | structure to all zeroes, or using nsync_mu_init(). */ 70 | #define NSYNC_MU_INIT { NSYNC_ATOMIC_UINT32_INIT_, 0 } 71 | void nsync_mu_init (nsync_mu *mu); 72 | 73 | /* Block until *mu is free and then acquire it in writer mode. 74 | Requires that the calling thread not already hold *mu in any mode. */ 75 | void nsync_mu_lock (nsync_mu *mu); 76 | 77 | /* Unlock *mu, which must have been acquired in write mode by the calling 78 | thread, and wake waiters, if appropriate. */ 79 | void nsync_mu_unlock (nsync_mu *mu); 80 | 81 | /* Attempt to acquire *mu in writer mode without blocking, and return non-zero 82 | iff successful. Return non-zero with high probability if *mu was free 83 | on entry. */ 84 | int nsync_mu_trylock (nsync_mu *mu); 85 | 86 | /* Block until *mu can be acquired in reader mode and then acquire it. 87 | Requires that the calling thread not already hold *mu in any mode. 
*/ 88 | void nsync_mu_rlock (nsync_mu *mu); 89 | 90 | /* Unlock *mu, which must have been acquired in read mode by the calling 91 | thread, and wake waiters, if appropriate. */ 92 | void nsync_mu_runlock (nsync_mu *mu); 93 | 94 | /* Attempt to acquire *mu in reader mode without blocking, and return non-zero 95 | iff successful. Return non-zero with high probability if *mu was free on 96 | entry. Perhaps fail to acquire if a writer is waiting, to avoid starvation. 97 | */ 98 | int nsync_mu_rtrylock (nsync_mu *mu); 99 | 100 | /* May abort if *mu is not held in write mode by the calling thread. */ 101 | void nsync_mu_assert_held (const nsync_mu *mu); 102 | 103 | /* May abort if *mu is not held in read or write mode 104 | by the calling thread. */ 105 | void nsync_mu_rassert_held (const nsync_mu *mu); 106 | 107 | /* Return whether *mu is held in read mode. 108 | Requires that the calling thread holds *mu in some mode. */ 109 | int nsync_mu_is_reader (const nsync_mu *mu); 110 | 111 | NSYNC_CPP_END_ 112 | 113 | #endif /*NSYNC_PUBLIC_NSYNC_MU_H_*/ 114 | -------------------------------------------------------------------------------- /src/third_party/nsync_mu_wait.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_MU_WAIT_H_ 16 | #define NSYNC_PUBLIC_NSYNC_MU_WAIT_H_ 17 | 18 | /* nsync_mu_wait() and nsync_mu_wait_with_deadline() can be used instead of condition 19 | variables. In many straightforward situations they are of equivalent 20 | performance and are somewhat easier to use, because unlike condition 21 | variables, they do not require that the waits be placed in a loop, and they 22 | do not require explicit wakeup calls. Example: 23 | 24 | Definitions: 25 | static nsync_mu mu = NSYNC_MU_INIT; 26 | static int i = 0; // protected by mu 27 | // Condition for use with nsync_mu_wait(). 28 | static int int_is_zero (const void *v) { return (*(const int *)v == 0); } 29 | 30 | Waiter: 31 | nsync_mu_lock (&mu); 32 | // Wait until i is zero. 33 | nsync_mu_wait (&mu, &int_is_zero, &i, NULL); 34 | // i is known to be zero here. 35 | // ... 36 | nsync_mu_unlock (&mu); 37 | 38 | 39 | Thread potentially making i zero: 40 | nsync_mu_lock (&mu); 41 | i--; 42 | // No need to signal that i may have become zero. The unlock call below 43 | // will evaluate waiters' conditions to decide which to wake. 44 | nsync_mu_unlock (&mu); 45 | 46 | It is legal to use conditional critical sections and condition variables 47 | on the same mutex. 48 | 49 | -------------- 50 | 51 | The implementation benefits from determining whether waiters are waiting for 52 | the same condition; it may then evaluate a condition once on behalf 53 | of several waiters. 
Two waiters have equal condition if their "condition" 54 | pointers are equal, and either: 55 | - their "condition_arg" pointers are equal, or 56 | - "condition_arg_eq" is non-null and 57 | (*condition_arg_eq) (condition_arg0, condition_arg1) returns non-zero. 58 | *condition_arg_eq will not be invoked unless the "condition" pointers 59 | are equal, and the "condition_arg" pointers are unequal. 60 | 61 | If many waiters wait for distinct conditions simultaneously, condition 62 | variables may be faster. 63 | */ 64 | 65 | #include "nsync_cpp.h" 66 | #include "nsync_mu.h" 67 | #include "nsync_time.h" 68 | 69 | NSYNC_CPP_START_ 70 | 71 | struct nsync_note_s_; /* forward declaration for an nsync_note */ 72 | 73 | /* Return when (*condition) (condition_arg) is true. Perhaps unlock and relock 74 | *mu while blocked waiting for the condition to become true. nsync_mu_wait() 75 | is equivalent to nsync_mu_wait_with_deadline() with 76 | abs_deadline==nsync_time_no_deadline, and cancel_note==NULL. 77 | 78 | Requires that *mu be held on entry. 79 | See nsync_mu_wait_with_deadline() for more details on *condition and 80 | *condition_arg_eq. */ 81 | void nsync_mu_wait (nsync_mu *mu, int (*condition) (const void *condition_arg), 82 | const void *condition_arg, 83 | int (*condition_arg_eq) (const void *a, const void *b)); 84 | 85 | /* Return when at least one of: (*condition) (condition_arg) is true, the 86 | deadline expires, or *cancel_note is notified. Perhaps unlock and relock *mu 87 | while blocked waiting for one of these events, but always return with *mu 88 | held. Return 0 iff the (*condition) (condition_arg) is true on return, and 89 | otherwise either ETIMEDOUT or ECANCELED, depending on why the call returned 90 | early. Callers should use abs_deadline==nsync_time_no_deadline for no 91 | deadline, and cancel_note==NULL for no cancellation. 92 | 93 | Requires that *mu be held on entry. 94 | 95 | The implementation may call *condition from any thread using the mutex, and 96 | while holding *mu in either read or write mode; it guarantees that any 97 | thread calling *condition will hold *mu in some mode. 98 | Requires that (*condition) (condition_arg) neither modify state protected by 99 | *mu, nor return a value dependent on state not protected by *mu. To depend 100 | on time, use the abs_deadline parameter. 101 | (Conventional use of condition variables have the same restrictions on the 102 | conditions tested by the while-loop.) 103 | If non-null, condition_arg_eq should return whether two condition_arg 104 | calls with the same "condition" pointer are considered equivalent; it should 105 | have no side-effects. */ 106 | int nsync_mu_wait_with_deadline (nsync_mu *mu, 107 | int (*condition) (const void *condition_arg), 108 | const void *condition_arg, 109 | int (*condition_arg_eq) (const void *a, const void *b), 110 | nsync_time abs_deadline, 111 | struct nsync_note_s_ *cancel_note); 112 | 113 | /* Unlock *mu, which must be held in write mode, and wake waiters, if 114 | appropriate. Unlike nsync_mu_unlock(), this call is not required to wake 115 | nsync_mu_wait/nsync_mu_wait_with_deadline calls on conditions that were 116 | false before this thread acquired the lock. 
This call should be used only 117 | at the end of critical sections for which: 118 | - nsync_mu_wait and/or nsync_mu_wait_with_deadline are in use on the same 119 | mutex, 120 | - this critical section cannot make the condition true for any of those 121 | nsync_mu_wait/nsync_mu_wait_with_deadline waits, and 122 | - when performance is significantly improved by using this call. */ 123 | void nsync_mu_unlock_without_wakeup (nsync_mu *mu); 124 | 125 | NSYNC_MU_WAIT_CPP_OVERLOAD_ 126 | NSYNC_CPP_END_ 127 | 128 | #endif /*NSYNC_PUBLIC_NSYNC_MU_WAIT_H_*/ 129 | -------------------------------------------------------------------------------- /src/third_party/nsync_note.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_NOTE_H_ 16 | #define NSYNC_PUBLIC_NSYNC_NOTE_H_ 17 | 18 | #include "nsync_cpp.h" 19 | #include "nsync_time.h" 20 | 21 | NSYNC_CPP_START_ 22 | 23 | /* An nsync_note represents a single bit that can transition from 0 to 1 at 24 | most once. When 1, the note is said to be notified. There are operations 25 | to wait for the transition, which can be triggered either by an explicit 26 | call, or timer expiry. Notes can have parent notes; a note becomes notified 27 | if its parent becomes notified. */ 28 | typedef struct nsync_note_s_ *nsync_note; 29 | 30 | /* Return a freshly allocated nsync_note, or NULL if an nsync_note cannot be 31 | created. 32 | 33 | If parent!=NULL, the allocated nsync_note's parent will be parent. The 34 | newaly allocated note will be automatically notified at abs_deadline, and is 35 | notified at initialization if abs_deadline==nsync_zero_time. 36 | 37 | nsync_notes should be passed to nsync_note_free() when no longer needed. */ 38 | nsync_note nsync_note_new (nsync_note parent, nsync_time abs_deadline); 39 | 40 | /* Free resources associated with n. Requires that n was allocated by 41 | nsync_note_new(), and no concurrent or future operations are applied to n 42 | directly. 43 | It is legal to call nsync_note_free() on a node even if it has a parent or 44 | children that are in use; if n has both a parent and children, n's 45 | parent adopts its children. */ 46 | void nsync_note_free (nsync_note n); 47 | 48 | /* Notify n and all its descendants. */ 49 | void nsync_note_notify (nsync_note n); 50 | 51 | /* Return whether n has been notified. */ 52 | int nsync_note_is_notified (nsync_note n); 53 | 54 | /* Wait until n has been notified or abs_deadline is reached, and return 55 | whether n has been notified. If abs_deadline==nsync_time_no_deadline, 56 | the deadline is far in the future. */ 57 | int nsync_note_wait (nsync_note n, nsync_time abs_deadline); 58 | 59 | /* Return the expiry time associated with n. 60 | This is the minimum of the abs_deadline passed on creation and that of any 61 | of its ancestors. 
*/ 62 | nsync_time nsync_note_expiry (nsync_note n); 63 | 64 | NSYNC_NOTE_CPP_OVERLOAD_ 65 | NSYNC_CPP_END_ 66 | 67 | #endif /*NSYNC_PUBLIC_NSYNC_NOTE_H_*/ 68 | -------------------------------------------------------------------------------- /src/third_party/nsync_once.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_ONCE_H_ 16 | #define NSYNC_PUBLIC_NSYNC_ONCE_H_ 17 | 18 | #include 19 | #include "nsync_cpp.h" 20 | #include "nsync_atomic.h" 21 | 22 | NSYNC_CPP_START_ 23 | 24 | /* An nsync_once allows a function to be called exactly once, when first referenced. */ 25 | typedef nsync_atomic_uint32_ nsync_once; 26 | 27 | /* An initializer for nsync_once; it is guaranteed to be all zeroes. */ 28 | #define NSYNC_ONCE_INIT NSYNC_ATOMIC_UINT32_INIT_ 29 | 30 | /* The first time nsync_run_once() or nsync_run_once_arg() is applied to *once, 31 | the supplied function is run (with argument, in the case of nsync_run_once_arg()). 32 | Other callers will wait until the run of the function is complete, and then 33 | return without running the function again. */ 34 | void nsync_run_once (nsync_once *once, void (*f) (void)); 35 | void nsync_run_once_arg (nsync_once *once, void (*farg) (void *arg), void *arg); 36 | 37 | /* Same as nsync_run_once()/nsync_run_once_arg() but uses a spinloop. 38 | Can be used on the same nsync_once as nsync_run_once/nsync_run_once_arg(). 39 | 40 | These *_spin variants should be used only in contexts where normal blocking 41 | is disallowed, such as within user-space schedulers, when the runtime is 42 | not fully initialized, etc. They provide no significant performance benefit, 43 | and they should be avoided in normal code. */ 44 | void nsync_run_once_spin (nsync_once *once, void (*f) (void)); 45 | void nsync_run_once_arg_spin (nsync_once *once, void (*farg) (void *arg), void *arg); 46 | 47 | NSYNC_CPP_END_ 48 | 49 | #endif /*NSYNC_PUBLIC_NSYNC_ONCE_H_*/ 50 | -------------------------------------------------------------------------------- /src/third_party/nsync_time.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
*/ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_TIME_H_ 16 | #define NSYNC_PUBLIC_NSYNC_TIME_H_ 17 | 18 | #include "nsync_cpp.h" 19 | #include "nsync_time_internal.h" 20 | 21 | /* The type nsync_time represents the interval elapsed between two moments in 22 | time. Often the first such moment is an address-space-wide epoch, such as 23 | the Unix epoch, but clients should not rely on the epoch in one address 24 | space being the same as that in another. Intervals relative to the epoch 25 | are known as absolute times. 26 | 27 | The internals of nsync_time should be treated as opaque by clients. 28 | See nsync_time_internal.h. */ 29 | 30 | NSYNC_CPP_START_ 31 | 32 | extern const nsync_time nsync_time_no_deadline; /* A deadline infinitely far in the future. */ 33 | extern const nsync_time nsync_time_zero; /* The zero delay, or an expired deadline. */ 34 | 35 | nsync_time nsync_time_now (void); /* Return the current time since the epoch. */ 36 | 37 | /* Sleep for the specified delay. Returns the unslept time 38 | which may be non-zero if the call was interrupted. */ 39 | nsync_time nsync_time_sleep (nsync_time delay); 40 | 41 | /* Return a+b */ 42 | nsync_time nsync_time_add (nsync_time a, nsync_time b); 43 | 44 | /* Return a-b */ 45 | nsync_time nsync_time_sub (nsync_time a, nsync_time b); 46 | 47 | /* Return +ve, 0, or -ve according to whether a>b, a==b, or a 47 | NSYNC_CPP_START_ 48 | typedef struct { 49 | time_t seconds; 50 | unsigned nanoseconds; 51 | } nsync_time; 52 | #define NSYNC_TIME_SEC(t) ((t).seconds) 53 | #define NSYNC_TIME_NSEC(t) ((t).nanoseconds) 54 | NSYNC_CPP_END_ 55 | 56 | #elif defined(__cplusplus) && \ 57 | (NSYNC_USE_CPP11_TIMEPOINT || (__cplusplus >= 201103L) || (_MSC_VER >= 1700)) 58 | /* The inline functions below provide function overloads that accept the most 59 | likely C++11 time type(s). 60 | 61 | C++11 time types have many variations and subtleties: 62 | - There are multiple clocks with potentially differing epochs; these clocks 63 | are not necessarily phase-locked to the same rate, making conversion and 64 | comparison between clocks tricky. 65 | - Relative and absolute times are distinguished in the type system. 66 | - Either integral or floating point counters may be used to represent time 67 | intervals, and code valid with one may not be valid with the other 68 | (see std::chrono::treat_as_floating_point). 69 | - A counter increment of one can represent any rational number of seconds 70 | (for whatever "seconds" means for this clock). 71 | - Conversions between duration types may round or truncate at the 72 | implementation's discretion. 73 | - As mentioned above, common implementations of the default monotonic clock 74 | ("steady_clock") illegally allow a thread to observe time going backwards, 75 | especially in the face of scheduling on a different CPU, making its use 76 | misleading, at best. 77 | I've chosen to handle this complexity by doing a conversion to absolute 78 | timespec at the interface layer, so all the C++ complication is here, rather 79 | than spread throughout the library. 
*/ 80 | 81 | #include 82 | #include 83 | NSYNC_CPP_START_ 84 | typedef struct timespec nsync_time; 85 | #define NSYNC_TIME_SEC(t) ((t).tv_sec) 86 | #define NSYNC_TIME_NSEC(t) ((t).tv_nsec) 87 | 88 | typedef std::chrono::system_clock::time_point nsync_cpp_time_point_; 89 | nsync_time nsync_from_time_point_ (nsync_cpp_time_point_); 90 | nsync_cpp_time_point_ nsync_to_time_point_ (nsync_time); 91 | #define NSYNC_COUNTER_CPP_OVERLOAD_ \ 92 | static inline uint32_t nsync_counter_wait (nsync_counter c, \ 93 | nsync_cpp_time_point_ abs_deadline) { \ 94 | return (nsync_counter_wait (c, nsync_from_time_point_ (abs_deadline))); \ 95 | } 96 | #define NSYNC_CV_CPP_OVERLOAD_ \ 97 | static inline int nsync_cv_wait_with_deadline (nsync_cv *cv, nsync_mu *mu, \ 98 | nsync_cpp_time_point_ abs_deadline, struct nsync_note_s_ *cancel_note) { \ 99 | return (nsync_cv_wait_with_deadline (cv, mu, \ 100 | nsync_from_time_point_ (abs_deadline), \ 101 | cancel_note)); \ 102 | } \ 103 | static inline int nsync_cv_wait_with_deadline_generic (nsync_cv *cv, \ 104 | void *mu, void (*lock) (void *), void (*unlock) (void *), \ 105 | nsync_cpp_time_point_ abs_deadline, struct nsync_note_s_ *cancel_note) { \ 106 | return (nsync_cv_wait_with_deadline_generic (cv, mu, lock, unlock, \ 107 | nsync_from_time_point_ (abs_deadline), \ 108 | cancel_note)); \ 109 | } 110 | #define NSYNC_MU_WAIT_CPP_OVERLOAD_ \ 111 | static inline int nsync_mu_wait_with_deadline (nsync_mu *mu, \ 112 | int (*condition) (const void *condition_arg), const void *condition_arg, \ 113 | int (*condition_arg_eq) (const void *a, const void *b), \ 114 | nsync_cpp_time_point_ abs_deadline, struct nsync_note_s_ *cancel_note) { \ 115 | return (nsync_mu_wait_with_deadline (mu, condition, condition_arg, \ 116 | condition_arg_eq, \ 117 | nsync_from_time_point_ (abs_deadline), \ 118 | cancel_note)); \ 119 | } 120 | #define NSYNC_NOTE_CPP_OVERLOAD_ \ 121 | static inline nsync_note nsync_note_new (nsync_note parent, \ 122 | nsync_cpp_time_point_ abs_deadline) { \ 123 | return (nsync_note_new (parent, nsync_from_time_point_ (abs_deadline))); \ 124 | } \ 125 | static inline int nsync_note_wait (nsync_note n, nsync_cpp_time_point_ abs_deadline) { \ 126 | return (nsync_note_wait (n, nsync_from_time_point_ (abs_deadline))); \ 127 | } \ 128 | static inline nsync_cpp_time_point_ nsync_note_expiry_timepoint (nsync_note n) { \ 129 | return (nsync_to_time_point_ (nsync_note_expiry (n))); \ 130 | } 131 | #define NSYNC_WAITER_CPP_OVERLOAD_ \ 132 | static inline int nsync_wait_n (void *mu, void (*lock) (void *), \ 133 | void (*unlock) (void *), \ 134 | nsync_cpp_time_point_ abs_deadline, \ 135 | int count, struct nsync_waitable_s *waitable[]) { \ 136 | return (nsync_wait_n (mu, lock, unlock, \ 137 | nsync_from_time_point_ (abs_deadline), count, waitable)); \ 138 | } 139 | 140 | NSYNC_CPP_END_ 141 | 142 | #else 143 | /* Default is to use timespec. 
*/ 144 | #include 145 | NSYNC_CPP_START_ 146 | typedef struct timespec nsync_time; 147 | #define NSYNC_TIME_SEC(t) ((t).tv_sec) 148 | #define NSYNC_TIME_NSEC(t) ((t).tv_nsec) 149 | NSYNC_CPP_END_ 150 | 151 | #endif 152 | 153 | #if !defined(NSYNC_COUNTER_CPP_OVERLOAD_) 154 | #define NSYNC_COUNTER_CPP_OVERLOAD_ 155 | #define NSYNC_CV_CPP_OVERLOAD_ 156 | #define NSYNC_MU_WAIT_CPP_OVERLOAD_ 157 | #define NSYNC_NOTE_CPP_OVERLOAD_ 158 | #define NSYNC_WAITER_CPP_OVERLOAD_ 159 | #endif 160 | 161 | #endif /*NSYNC_PUBLIC_NSYNC_TIME_INTERNAL_H_*/ 162 | -------------------------------------------------------------------------------- /src/third_party/nsync_waiter.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 Google Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. */ 14 | 15 | #ifndef NSYNC_PUBLIC_NSYNC_WAITER_H_ 16 | #define NSYNC_PUBLIC_NSYNC_WAITER_H_ 17 | 18 | /* nsync_wait_n() allows the client to wait on multiple objects (condition 19 | variables, nsync_notes, nsync_counters, etc.) until at least one of them 20 | becomes ready, or a deadline expires. 21 | 22 | It can be thought of as rather like Unix's select() or poll(), 23 | except the the objects being waited for are synchronization 24 | data structures, rather than file descriptors. 25 | 26 | The client can construct new objects that can be waited for by implementing 27 | three routines. 28 | 29 | Examples: 30 | 31 | To wait on two nsync_notes n0, n1, and a nsync_counter c0, 32 | with a deadline of abs_deadline: 33 | 34 | // Form an array of struct nsync_waitable_s, identifying the 35 | // objects and the corresponding descriptors. (static initialization 36 | // syntax is used for brevity) 37 | static struct nsync_waitable_s w[] = { 38 | { &n0, &nsync_note_waitable_funcs }, 39 | { &n1, &nsync_note_waitable_funcs }, 40 | { &c0, &nsync_counter_waitable_funcs } 41 | }; 42 | static struct nsync_waitable_s *pw[] = { &w[0], &w[1], &w[2] }; 43 | int n = sizeof (w) / sizeof (w[0]); 44 | 45 | // Wait. The mu, lock, and unlock arguments are NULL because 46 | // no condition variables are invovled. 47 | int i = nsync_wait_n (NULL, NULL, NULL, abs_deadline, n, pw); 48 | if (i == n) { 49 | // timeout 50 | } else { 51 | // w[i].v became ready. 52 | } 53 | 54 | To wait on multiple condition variables, the mu/lock/unlock parameters are 55 | used. Imagine cv0 and cv1 are signalled when predicates pred0() (under 56 | lock mu0) and pred1() (under lock mu1) become true respectively. Assume 57 | that mu0 is acquired before mu1. 58 | static void lock2 (void *v) { // lock two mutexes in order 59 | nsync_mu **mu = (nsync_mu **) v; 60 | nsync_mu_lock (mu[0]); 61 | nsync_mu_lock (mu[1]); 62 | } 63 | static void unlock2 (void *v) { // unlock two mutexes. 64 | nsync_mu **mu = (nsync_mu **) v; 65 | nsync_mu_unlock (mu[1]); 66 | nsync_mu_unlock (mu[0]); 67 | } 68 | 69 | // Describe the condition variables and the locks. 
70 | static struct nsync_waitable_s w[] = { 71 | { &cv0, &nsync_cv_waitable_funcs }, 72 | { &cv1, &nsync_cv_waitable_funcs } 73 | }; 74 | static struct nsync_waitable_s *pw[] = { &w[0], &w[1] }; 75 | nsync_mu *lock_list[] = { &mu0, &mu1 }; 76 | int n = sizeof (w) / sizeof (w[0]); 77 | 78 | lock2 (list_list); 79 | while (!pred0 () && !pred1 ()) { 80 | // Wait for one of the condition variables to be signalled, 81 | // with no timeout. 82 | nsync_wait_n (lock_list, &lock2, &unlock2, 83 | nsync_time_no_deadline, n, pw); 84 | } 85 | if (pred0 ()) { ... } 86 | if (pred1 ()) { ... } 87 | unlock2 (list_list); 88 | 89 | */ 90 | 91 | #include 92 | #include 93 | #include "nsync_cpp.h" 94 | #include "nsync_atomic.h" 95 | #include "nsync_time.h" 96 | 97 | NSYNC_CPP_START_ 98 | 99 | struct nsync_waitable_funcs_s; /* forward declaration of struct that contains 100 | type dependent wait operations */ 101 | 102 | /* Clients wait on objects by forming an array of struct nsync_waitable_s. 103 | Each each element points to one object and its type-dependent functions. */ 104 | struct nsync_waitable_s { 105 | void *v; /* pointer to object */ 106 | /* pointer to type-dependent functions. Use 107 | &nsync_note_waitable_funcs for an nsync_note, 108 | &nsync_counternote_waitable_funcs for an nsync_counter, 109 | &nsync_cv_waitable_funcs for an nsync_cv. */ 110 | const struct nsync_waitable_funcs_s *funcs; 111 | }; 112 | 113 | /* Wait until at least one of *waitable[0,..,count-1] is has been notified, or 114 | abs_deadline is reached. Return the index of the notified element of 115 | waitable[], or count if no such element exists. 116 | If mu!=NULL, (*unlock)(mu) is called after the thread is queued on the 117 | various waiters, and (*lock)(mu) is called before return; mu/lock/unlock are 118 | used to acquire and release the relevant locks whan waiting on condition 119 | variables. */ 120 | int nsync_wait_n (void *mu, void (*lock) (void *), void (*unlock) (void *), 121 | nsync_time abs_deadline, int count, 122 | struct nsync_waitable_s *waitable[]); 123 | 124 | /* --------------------------------------------------- */ 125 | 126 | /* A "struct nsync_waitable_s" implementation must implement these functions. 127 | Clients should ignore the internals. */ 128 | struct nsync_waiter_s; 129 | struct nsync_waitable_funcs_s { 130 | /* Return the time when *v will be ready (max time if 131 | unknown), or 0 if it is already ready. The parameter nw may be 132 | passed as NULL, in which case the result should indicate whether the 133 | thread would block if it were to wait on *v. 134 | All calls with the same *v must report the same result until the 135 | object becomes ready, from which point calls must report 0. */ 136 | nsync_time (*ready_time) (void *v, struct nsync_waiter_s *nw); 137 | 138 | /* If *v is ready, return zero; otherwise enqueue *nw on *v and return 139 | non-zero. */ 140 | int (*enqueue) (void *v, struct nsync_waiter_s *nw); 141 | 142 | /* If nw has been previously dequeued, return zero; otherwise dequeue 143 | *nw from *v and return non-zero. */ 144 | int (*dequeue) (void *v, struct nsync_waiter_s *nw); 145 | }; 146 | 147 | /* The "struct nsync_waitable_s" for nsync_note, nsync_counter, and nsync_cv. 
*/ 148 | extern const struct nsync_waitable_funcs_s nsync_note_waitable_funcs; 149 | extern const struct nsync_waitable_funcs_s nsync_counter_waitable_funcs; 150 | extern const struct nsync_waitable_funcs_s nsync_cv_waitable_funcs; 151 | 152 | NSYNC_WAITER_CPP_OVERLOAD_ 153 | NSYNC_CPP_END_ 154 | 155 | #endif /*NSYNC_PUBLIC_NSYNC_WAITER_H_*/ 156 | -------------------------------------------------------------------------------- /src/thread_pool.cpp: -------------------------------------------------------------------------------- 1 | #include "thread_pool.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | class ThreadPool { 13 | private: 14 | typedef std::chrono::duration Duration; 15 | static __thread bool in_thread_pool; 16 | // c++ assigns random id to each thread. This is not a thread_id 17 | // it's a number inside this thread pool. 18 | static __thread int thread_number; 19 | 20 | bool should_terminate; 21 | std::mutex queue_mutex; 22 | std::condition_variable is_idle; 23 | int active_count; 24 | 25 | std::deque > work; 26 | std::vector pool; 27 | Duration between_queue_checks; 28 | 29 | void thread_body(int _thread_id); 30 | public: 31 | // Creates a thread pool composed of num_threads threads. 32 | // threads are started immediately and exit only once ThreadPool 33 | // goes out of scope. Threads periodically check for new work 34 | // and the frequency of those checks is at minimum between_queue_checks 35 | // (it can be higher due to thread scheduling). 36 | ThreadPool(int num_threads, Duration between_queue_checks=std::chrono::milliseconds(1)); 37 | 38 | // Run a function on a thread in pool. 39 | void run(std::function f); 40 | 41 | // Wait until queue is empty and all the threads have finished working. 42 | // If timeout is specified function waits at most timeout until the 43 | // threads are idle. If they indeed become idle returns true. 44 | bool wait_until_idle(Duration timeout); 45 | bool wait_until_idle(); 46 | 47 | // Retruns true if all the work is done. 48 | bool idle() const; 49 | // Return number of active busy workers. 50 | int active_workers(); 51 | ~ThreadPool(); 52 | }; 53 | 54 | __thread bool ThreadPool::in_thread_pool = false; 55 | __thread int ThreadPool::thread_number = -1; 56 | 57 | ThreadPool::ThreadPool(int num_threads, Duration between_queue_checks) : 58 | between_queue_checks(between_queue_checks), 59 | should_terminate(false), 60 | active_count(0) { 61 | // Thread pool inception is not supported at this time. 62 | assert(!in_thread_pool); 63 | 64 | ThreadPool::between_queue_checks = between_queue_checks; 65 | for (int thread_number = 0; thread_number < num_threads; ++thread_number) { 66 | pool.emplace_back(&ThreadPool::thread_body, this, thread_number); 67 | } 68 | } 69 | 70 | void ThreadPool::thread_body(int _thread_id) { 71 | in_thread_pool = true; 72 | thread_number = _thread_id; 73 | bool am_i_active = false; 74 | 75 | while (true) { 76 | std::function f; 77 | { 78 | std::lock_guard lock(queue_mutex); 79 | bool was_i_active = am_i_active; 80 | if (should_terminate && work.empty()) 81 | break; 82 | if (!work.empty()) { 83 | am_i_active = true; 84 | f = work.front(); 85 | work.pop_front(); 86 | } else { 87 | am_i_active = false; 88 | } 89 | 90 | if (am_i_active != was_i_active) { 91 | active_count += am_i_active ? 
1 : -1; 92 | if (active_count == 0) { 93 | // number of active workers decreased, so maybe all are idle 94 | is_idle.notify_all(); 95 | } 96 | } 97 | } 98 | // Function defines implicit conversion to bool 99 | // which is true only if call target was set. 100 | if (static_cast<bool>(f)) { 101 | f(); 102 | } else { 103 | std::this_thread::sleep_for(between_queue_checks); 104 | } 105 | std::this_thread::yield(); 106 | } 107 | } 108 | 109 | int ThreadPool::active_workers() { 110 | std::lock_guard<std::mutex> lock(queue_mutex); 111 | return active_count; 112 | } 113 | 114 | bool ThreadPool::wait_until_idle(Duration timeout) { 115 | std::unique_lock<std::mutex> lock(queue_mutex); 116 | is_idle.wait_for(lock, timeout, [this]{ 117 | return active_count == 0 && work.empty(); 118 | }); 119 | return idle(); 120 | } 121 | 122 | bool ThreadPool::wait_until_idle() { 123 | int retries = 3; 124 | while (retries--) { 125 | try { 126 | std::unique_lock<std::mutex> lock(queue_mutex); 127 | is_idle.wait(lock, [this]{ 128 | return active_count == 0 && work.empty(); 129 | }); 130 | return idle(); 131 | } catch (...) {} 132 | } 133 | throw std::runtime_error( 134 | "exceeded retries when waiting until idle." 135 | ); 136 | return false; 137 | } 138 | 139 | bool ThreadPool::idle() const { 140 | return active_count == 0 && work.empty(); 141 | } 142 | 143 | void ThreadPool::run(std::function<void()> f) { 144 | int retries = 3; 145 | while (retries--) { 146 | try { 147 | std::unique_lock<std::mutex> lock(queue_mutex); 148 | work.push_back(f); 149 | return; 150 | } catch (...) {} 151 | } 152 | throw std::runtime_error( 153 | "exceeded retries when trying to run operation on thread pool." 154 | ); 155 | } 156 | 157 | ThreadPool::~ThreadPool() { 158 | // Terminates the thread pool, making sure that all the work 159 | // is completed. 160 | should_terminate = true; 161 | for (auto& t : pool) 162 | t.join(); 163 | } 164 | 165 | 166 | void ParallelFor(int max_parallelism, int num_threads, int cost_per_unit, int total, 167 | std::function<void(int, int)> work) { 168 | cost_per_unit = std::max(1, cost_per_unit); 169 | // We shard [0, total) into "num_shards" shards. 170 | // 1 <= num_shards <= num worker threads 171 | // 172 | // If total * cost_per_unit is small, it is not worth sharding too 173 | // much. Let us assume each cost unit is 1ns, kMinCostPerShard=10000 174 | // is 10us. 175 | static const int kMinCostPerShard = 10000; 176 | const int num_shards = 177 | std::max(1, std::min(static_cast<int>(max_parallelism), 178 | total * cost_per_unit / kMinCostPerShard)); 179 | // Each shard contains up to "block_size" units. [0, total) is sharded 180 | // into: 181 | // [0, block_size), [block_size, 2*block_size), ... 182 | // The 1st shard is done by the caller thread and the other shards 183 | // are dispatched to the worker threads. The last shard may be smaller than 184 | // block_size. 185 | const int block_size = (total + num_shards - 1) / num_shards; 186 | if (block_size <= 0) { 187 | throw std::runtime_error("block_size must be > 0."); 188 | } 189 | if (block_size >= total) { 190 | work(0, total); 191 | return; 192 | } 193 | const int num_shards_used = (total + block_size - 1) / block_size; 194 | ThreadPool pool(num_threads); 195 | for (int start = block_size; start < total; start += block_size) { 196 | auto limit = std::min(start + block_size, total); 197 | pool.run([&work, start, limit]() { 198 | work(start, limit); // Compute the shard. 199 | }); 200 | } 201 | // Inline execute the 1st shard.
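// (Worked example, for illustration: with total = 10 and num_shards = 3,
//  block_size = (10 + 3 - 1) / 3 = 4, so the loop above handed shards
//  [4, 8) and [8, 10) to the pool and the caller now runs [0, 4) inline.)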
202 | work(0, std::min(block_size, total)); 203 | pool.wait_until_idle(); 204 | } 205 | 206 | 207 | void Shard(int max_parallelism, int num_threads, int total, 208 | int cost_per_unit, std::function<void(int, int)> work) { 209 | if (total < 0) { 210 | throw std::runtime_error("total must be >= 0."); 211 | } 212 | if (total == 0) { 213 | return; 214 | } 215 | if (max_parallelism <= 1) { 216 | // Just inline the whole work since we only have 1 thread (core). 217 | work(0, total); 218 | return; 219 | } 220 | ParallelFor(max_parallelism, num_threads, 221 | std::max(1, cost_per_unit), total, work); 222 | return; 223 | } 224 | -------------------------------------------------------------------------------- /src/thread_pool.h: -------------------------------------------------------------------------------- 1 | #ifndef THREAD_POOL_H 2 | #define THREAD_POOL_H 3 | 4 | #include <functional> 5 | 6 | void Shard(int max_parallelism, int num_threads, int total, 7 | int cost_per_unit, std::function<void(int, int)> work); 8 | 9 | #endif // #ifndef THREAD_POOL_H 10 | -------------------------------------------------------------------------------- /test/test_fo_pool.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tensorflow.python.client import device_lib 7 | from tensorflow.python.ops import gradient_checker 8 | from tensorflow.python.framework import constant_op 9 | 10 | from qrnn import fo_pool, time_major_fo_pool_unsliced, batch_major_fo_pool_unsliced 11 | 12 | def np_fo_pooling(x, forget, initial_state, time_major): 13 | if not time_major: 14 | return np.transpose(np_fo_pooling(np.transpose(x, (1, 0, 2)), 15 | np.transpose(forget, (1, 0, 2)), 16 | initial_state, 17 | time_major=True), (1, 0, 2)) 18 | timesteps, batch, hidden = x.shape 19 | dst = np.zeros((timesteps + 1, batch, hidden), dtype=x.dtype) 20 | dst[0] = initial_state 21 | for ts in range(1, timesteps + 1): 22 | dst[ts] = (forget[ts - 1] * x[ts - 1] + 23 | (1.0 - forget[ts - 1]) * dst[ts - 1]) 24 | return dst[1:] 25 | 26 | 27 | class TestFoPool(unittest.TestCase): 28 | """ Tests the FoPool operator """ 29 | 30 | def setUp(self): 31 | # Obtain a list of GPU device specifications ['/gpu:0', '/gpu:1', ...]
32 |         self.gpu_devs = [d.name for d in device_lib.list_local_devices()
33 |                          if d.device_type == 'GPU']
34 | 
35 |     def test_fo_pool(self):
36 |         """ Test the FoPool operator """
37 |         # List of type constraints for testing this operator
38 |         type_permutations = [np.float32, np.float64]
39 | 
40 |         # Run test with the type combinations above
41 |         for FT in type_permutations:
42 |             for time_major in [False]:
43 |                 self._impl_test_fo_pool(FT, time_major)
44 | 
45 |     def _impl_test_fo_pool(self, FT, time_major):
46 |         """ Implementation of the FoPool operator test """
47 |         # Create input variables
48 |         timesteps = 20
49 |         batch_size = 32
50 |         channels = 64
51 |         if time_major:
52 |             shape = (timesteps, batch_size, channels)
53 |         else:
54 |             shape = (batch_size, timesteps, channels)
55 |         x = np.random.random(size=shape).astype(FT)
56 |         forget = np.random.uniform(0, 1, size=shape).astype(FT)
57 |         initial_state = np.random.random(size=(batch_size, channels)).astype(FT)
58 | 
59 |         # Argument list
60 |         np_args = [x, forget, initial_state]
61 |         # Argument string name list
62 |         arg_names = ["x", "forget", "initial_state"]
63 |         # Construct tensorflow variables
64 |         tf_args = [tf.Variable(v, name=n) for v, n in zip(np_args, arg_names)]
65 | 
66 |         def _pin_op(device, *tf_args):
67 |             """ Pin operation to device """
68 |             with tf.device(device):
69 |                 return fo_pool(*tf_args, time_major=time_major)
70 | 
71 |         # Pin operation to CPU
72 |         cpu_op = _pin_op("/cpu:0", *tf_args)
73 | 
74 |         # Run the op on all GPUs
75 |         gpu_ops = [_pin_op(d, *tf_args) for d in self.gpu_devs]
76 | 
77 |         # Initialise variables
78 |         init_op = tf.global_variables_initializer()
79 | 
80 |         with tf.Session() as S:
81 |             S.run(init_op)
82 |             cpu_result = S.run(cpu_op)
83 |             self.assertEqual(cpu_result.shape, shape)
84 |             gpu_results = S.run(gpu_ops)
85 |             for gpu_result in gpu_results:
86 |                 self.assertEqual(gpu_result.shape, shape)
87 |             expected = np_fo_pooling(x, forget, initial_state, time_major=time_major)
88 |             self.assertTrue(np.allclose(cpu_result, expected))
89 |             for gpu_result in gpu_results:
90 |                 self.assertTrue(np.allclose(gpu_result, expected))
91 | 
92 |     def test_time_major_fo_pool_grad(self):
93 |         """ Test the FoPool Gradient operator """
94 |         # List of type constraints for testing this operator
95 |         type_permutations = [(np.float32, 1e-2), (np.float64, 1e-4)]
96 | 
97 |         # Run test with the type combinations above
98 |         for FT, tolerance in type_permutations:
99 |             self._impl_test_time_major_fo_pool_grad(FT, tolerance)
100 | 
101 | 
102 |     def _impl_test_time_major_fo_pool_grad(self, FT, tolerance):
103 |         shape = (5, 3, 2)
104 |         np_args = [np.random.random(size=shape).astype(FT),
105 |                    np.random.uniform(0, 1, size=shape).astype(FT),
106 |                    np.random.random(size=shape[1:]).astype(FT)]
107 |         with tf.Session() as S:
108 |             tf_args = [constant_op.constant(arg, shape=arg.shape, dtype=FT) for arg in np_args]
109 |             y = tf.reduce_sum(time_major_fo_pool_unsliced(*tf_args))
110 |             for d in ["cpu"] + self.gpu_devs:
111 |                 with tf.device(d):
112 |                     err = gradient_checker.compute_gradient_error(
113 |                         tf_args, [arg.shape for arg in np_args], y, [],
114 |                         x_init_value=np_args)
115 |                     self.assertLess(err, tolerance)
116 | 
117 |     def test_batch_major_fo_pool_grad(self):
118 |         """ Test the FoPool Gradient operator """
119 |         # List of type constraints for testing this operator
120 |         type_permutations = [(np.float32, 1e-2), (np.float64, 1e-4)]
121 | 
122 |         # Run test with the type combinations above
123 |         for FT, tolerance in type_permutations:
124 |             self._impl_test_batch_major_fo_pool_grad(FT, tolerance)
125 | 
126 | 
127 |     def _impl_test_batch_major_fo_pool_grad(self, FT, tolerance):
128 |         shape = (3, 5, 2)
129 |         np_args = [np.random.random(size=shape).astype(FT),
130 |                    np.random.uniform(0, 1, size=shape).astype(FT),
131 |                    np.random.random(size=(shape[0], shape[-1])).astype(FT)]
132 |         with tf.Session() as S:
133 |             tf_args = [constant_op.constant(arg, shape=arg.shape, dtype=FT) for arg in np_args]
134 |             y = tf.reduce_sum(batch_major_fo_pool_unsliced(*tf_args))
135 |             for d in ["cpu"] + self.gpu_devs:
136 |                 with tf.device(d):
137 |                     err = gradient_checker.compute_gradient_error(
138 |                         tf_args, [arg.shape for arg in np_args], y, [],
139 |                         x_init_value=np_args)
140 |                     self.assertLess(err, tolerance)
141 | 
142 | 
143 | if __name__ == "__main__":
144 |     unittest.main()
145 | 
--------------------------------------------------------------------------------
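
For reference, the recurrence exercised by `np_fo_pooling` in test/test_fo_pool.py (and implemented by the `fo_pool` op) is a forget-gate weighted running average, h[t] = forget[t] * x[t] + (1 - forget[t]) * h[t - 1]. The snippet below is a minimal NumPy-only sketch of that recurrence for time-major inputs of shape (timesteps, batch, channels); `fo_pool_reference` is a hypothetical helper name used for illustration and is not part of the package.

```
import numpy as np

def fo_pool_reference(x, forget, initial_state):
    # Time-major reference recurrence:
    #   h[t] = forget[t] * x[t] + (1 - forget[t]) * h[t - 1]
    timesteps, batch, hidden = x.shape
    h = np.empty((timesteps, batch, hidden), dtype=x.dtype)
    prev = initial_state
    for t in range(timesteps):
        h[t] = forget[t] * x[t] + (1.0 - forget[t]) * prev
        prev = h[t]
    return h

# Tiny example: 4 timesteps, batch of 2, 3 channels, zero initial state.
x = np.random.random((4, 2, 3)).astype(np.float32)
forget = np.random.uniform(0, 1, (4, 2, 3)).astype(np.float32)
h0 = np.zeros((2, 3), dtype=np.float32)
print(fo_pool_reference(x, forget, h0).shape)  # (4, 2, 3)
```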