├── .gitignore ├── .travis.yml ├── CMakeLists.txt ├── LICENSE ├── README.md ├── benchmarks └── allgpu-allreduce.lua ├── doc ├── BackgroundTask.md ├── BackgroundTaskPool.md ├── channel.md ├── index.md ├── map.md ├── marshal.md ├── mutex.md ├── sharedtable.md ├── spawn.md └── workqueue.md ├── examples ├── allreduce.lua ├── allreduce.sh ├── allreduce.slurm ├── allreduce_slurm.lua ├── allreduce_slurm.sh ├── client-server.lua ├── client-server.sh ├── map.lua ├── model-parallelism.lua └── workqueue.lua ├── ipc-scm-1.rockspec ├── lua ├── BackgroundTask.lua ├── BackgroundTaskPool.lua ├── DiscoveredTree.lua ├── LocalhostTree.lua ├── NullTree.lua ├── SlurmTree.lua ├── StaticTree.lua ├── Tree.lua └── utils.lua ├── src ├── channel.c ├── channel.h ├── cliser.c ├── cliser.h ├── error.h ├── flock.c ├── flock.h ├── generic │ └── cliser.c ├── ipc.c ├── map.c ├── map.h ├── marshal.c ├── marshal.h ├── mutex.c ├── mutex.h ├── ringbuffer.c ├── ringbuffer.h ├── serialize.c ├── serialize.h ├── sharedtable.c ├── sharedtable.h ├── spawn.c ├── spawn.h ├── workqueue.c └── workqueue.h └── test ├── test.lua ├── test_BackgroundTask.lua ├── test_BackgroundTaskPool.lua ├── test_Tree.lua ├── test_channel.lua ├── test_cliser.lua ├── test_flock.lua ├── test_map.lua ├── test_marshal.lua ├── test_mutex.lua ├── test_sharedtable.lua ├── test_spawn.lua └── test_workqueue.lua /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: c 3 | compiler: 4 | - gcc 5 | env: 6 | - TORCH_LUA_VERSION=LUAJIT21 7 | - TORCH_LUA_VERSION=LUA51 8 | - TORCH_LUA_VERSION=LUA52 9 | addons: 10 | apt: 11 | sources: 12 | - ubuntu-toolchain-r-test 13 | packages: 14 | - cmake 15 | - gfortran 16 | - gcc-multilib 17 | - gfortran-multilib 18 | - liblapack-dev 19 | - build-essential 20 | - gcc-4.8 21 | - g++-4.8 22 | - clang 23 | - curl 24 | - cmake 25 | - libreadline-dev 26 | - git-core 27 | - libjpeg-dev 28 | - libpng-dev 29 | - ncurses-dev 30 | - imagemagick 31 | - libzmq3-dev 32 | - gfortran 33 | - unzip 34 | - gnuplot 35 | - gnuplot-x11 36 | before_script: 37 | - export ROOT_TRAVIS_DIR=$(pwd) 38 | - export INSTALL_PREFIX=~/torch/install 39 | - cd /tmp/ 40 | - git clone https://github.com/xianyi/OpenBLAS.git -b master 41 | - cd OpenBLAS 42 | - make clean 43 | - make USE_THREAD=0 USE_THREADS=0 USE_OPENMP=0 NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 44 | - make PREFIX=$HOME/OpenBlasInstall install 45 | - if [ "$CC" = "gcc" ]; then export CXX="g++-4.8" CC="gcc-4.8"; fi 46 | - git clone https://github.com/torch/distro.git ~/torch --recursive 47 | - cd ~/torch && git submodule update --init --recursive 48 | - mkdir build && cd build 49 | - export CMAKE_LIBRARY_PATH=$HOME/OpenBlasInstall/include:$HOME/OpenBlasInstall/lib:$CMAKE_LIBRARY_PATH 50 | - cmake .. 
-DCMAKE_INSTALL_PREFIX="${INSTALL_PREFIX}" -DCMAKE_BUILD_TYPE=Release -DWITH_${TORCH_LUA_VERSION}=ON 51 | - make && make install 52 | - git clone https://github.com/torch/xlua && cd xlua && ${INSTALL_PREFIX}/bin/luarocks make xlua-1.1-0.rockspec 53 | - if [[ $TORCH_LUA_VERSION != 'LUAJIT21' && $TORCH_LUA_VERSION != 'LUAJIT20' ]]; then ${INSTALL_PREFIX}/bin/luarocks install luaffi; fi 54 | - cd $ROOT_TRAVIS_DIR 55 | - export LD_LIBRARY_PATH=${INSTALL_PREFIX}/lib:$LD_LIBRARY_PATH 56 | script: 57 | - ${INSTALL_PREFIX}/bin/luarocks make 58 | - export PATH=${INSTALL_PREFIX}/bin:$PATH 59 | - export OMP_NUM_THREADS=1 60 | - th test/test.lua -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | SET(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}) 2 | 3 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) 4 | CMAKE_POLICY(VERSION 2.6) 5 | 6 | FIND_PACKAGE(Torch REQUIRED) 7 | FIND_PACKAGE(CUDA 6.5) 8 | 9 | FIND_PACKAGE(OpenMP) 10 | IF(OPENMP_FOUND) 11 | SET(CMAKE_C_FLAGS "-D_OPENMP ${CMAKE_C_FLAGS}") 12 | ENDIF () 13 | 14 | SET(CMAKE_C_FLAGS "-std=c11 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}") 15 | SET(src 16 | src/ipc.c 17 | src/workqueue.c 18 | src/ringbuffer.c 19 | src/serialize.c 20 | src/cliser.c 21 | src/map.c 22 | src/spawn.c 23 | src/flock.c 24 | src/mutex.c 25 | src/sharedtable.c 26 | src/marshal.c 27 | src/channel.c 28 | ) 29 | SET(luasrc 30 | lua/Tree.lua 31 | lua/StaticTree.lua 32 | lua/DiscoveredTree.lua 33 | lua/LocalhostTree.lua 34 | lua/SlurmTree.lua 35 | lua/NullTree.lua 36 | lua/utils.lua 37 | lua/BackgroundTask.lua 38 | lua/BackgroundTaskPool.lua 39 | test/test.lua 40 | test/test_BackgroundTask.lua 41 | test/test_BackgroundTaskPool.lua 42 | test/test_cliser.lua 43 | test/test_map.lua 44 | test/test_spawn.lua 45 | test/test_Tree.lua 46 | test/test_workqueue.lua 47 | test/test_mutex.lua 48 | test/test_sharedtable.lua 49 | test/test_marshal.lua 50 | test/test_channel.lua 51 | ) 52 | 53 | ADD_TORCH_PACKAGE(ipc "${src}" "${luasrc}" "A set of primitives for ipc computation in Torch") 54 | 55 | IF (CUDA_FOUND AND NOT ("$ENV{CUDA}" STREQUAL "NO")) 56 | INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS}) 57 | INCLUDE_DIRECTORIES("${CUDA_SDK_ROOT_DIR}/common/inc") 58 | SET(CMAKE_C_FLAGS "-DUSE_CUDA ${CMAKE_C_FLAGS}") 59 | IF (NOT "$ENV{TH_ONLY_STATIC}" STREQUAL "YES") 60 | TARGET_LINK_LIBRARIES(ipc luaT TH THC ${CUDA_LIBRARIES}) 61 | ENDIF() 62 | ELSE() 63 | TARGET_LINK_LIBRARIES(ipc luaT TH) 64 | ENDIF() 65 | 66 | IF (BUILD_STATIC OR "$ENV{STATIC_TH}" STREQUAL "YES") 67 | SET_TARGET_PROPERTIES(ipc_static PROPERTIES COMPILE_FLAGS "-fPIC -DSTATIC_TH") 68 | ENDIF() 69 | 70 | INSTALL(FILES ${luasrc} DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/ipc") 71 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | IPC 2 | === 3 | 4 | A set of [primitives](doc/index.md) that extend Torch for high performance 5 | parallel computation across thread and process boundaries. 6 | 7 | Tree 8 | ---- 9 | 10 | Implements an AllReduce style binary tree of connected processes 11 | on top of Client-Server nodes. This enables AllReduce style operations 12 | on Tensors across a set of machines. The Tree is topology-aware 13 | such that it creates an optimal binary tree of processes 14 | where each link is the fastest possible communication means 15 | available. This allows user code to use one abstraction for 16 | parallel Tensor computation and get high performance without 17 | coding specifically for the underlying topology. 18 | 19 | ```lua 20 | -- Every node will end up with the sum of all nodes 21 | tree.allReduce(grads, function(a, b) 22 | return a:add(b) 23 | end) 24 | ``` 25 | 26 | See the [AllReduce example](examples/allreduce.lua) to try it out. 27 | 28 | SlurmTree 29 | --------- 30 | 31 | An implementation of Tree that integrates with the [Slurm cluster manager](https://slurm.schedmd.com/). 32 | It builds the communication tree by reading the Slurm variables that are 33 | specified via [SBATCH directives](https://slurm.schedmd.com/sbatch.html) 34 | (e.g. --nodes, --tasks-per-node, etc.) while minimizing inter-node 35 | communication (when there is more than one task per node). 36 | 37 | SlurmTree takes two optional arguments: 38 | 1. File path - For the file that coordinates the initial connection 39 | of processes. The file location has to be shared across nodes. 40 | (By default '~/.torch') 41 | 2. Tasks per gpu - Used to calculate the gpu id property (By default 1) 42 | 43 | See the [slurm script](examples/allreduce.slurm) for an example of how to 44 | start the processes. 45 | 46 | Client-Server 47 | ------------- 48 | 49 | A classic client-server implementation over TCP sockets, used for IPC. 50 | It can transfer all Lua primitives as well as Torch Tensor and Storage 51 | types across the network. It includes a very fast implementation for 52 | sending CUDA Tensors between machines and an even faster implementation 53 | for passing CUDA Tensors between GPUs on the same machine using 54 | the PCI-E bus via CUDA IPC. 55 | 56 | The implementation is not tied to any specific cluster or discovery 57 | mechanism. All you need to do is ensure your nodes can reach each 58 | other over TCP. 
59 | 60 | ```lua 61 | -- Create a server 62 | local server = ipc.server('127.0.0.1', 8080) 63 | -- Create a client and connect to the server 64 | local client = ipc.client('127.0.0.1', 8080) 65 | -- Say hello 66 | client:send('hi') 67 | -- Listen for any client to say something 68 | local msg = server:recvAny() 69 | assert(msg == 'hi') 70 | -- Disconnect and shutdown 71 | client:close() 72 | server:close() 73 | ``` 74 | 75 | Map 76 | --- 77 | 78 | A map function to spawn a set of worker threads and run a 79 | computation in the background. This is very handy for doing 80 | IO off of the main thread, as IO is usually blocked on a 81 | file or socket descriptor. 82 | 83 | ```lua 84 | -- See examples/map.lua for the complete listing 85 | -- Load 3 files in parallel 86 | local t1,t2,t3 = ipc.map(3, function(fileNames, mapid) 87 | return torch.load(fileNames[mapid]) 88 | end, {'f1.t7', 'f2.t7', 'f3.t7'}):join() 89 | ``` 90 | 91 | Read more complete documentation on [ipc.map](doc/map.md). 92 | 93 | Workqueue 94 | --------- 95 | 96 | A simple single writer multiple reader command queue. Really useful when 97 | combined with map to keep a bunch of background threads grabbing work 98 | off the queue, processing it and then returning answers back to the main 99 | thread. 100 | 101 | ```lua 102 | -- See examples/workqueue.lua for the complete listing 103 | -- Create a named workqueue 104 | local q = ipc.workqueue('my queue') 105 | 106 | -- Create 2 background workers that read from the named workqueue 107 | local workers = ipc.map(2, function() 108 | -- This function is not a closure, it is a totally clean Lua environment 109 | local ipc = require 'libipc' 110 | -- Open the queue by name (the main thread already created it) 111 | local q = ipc.workqueue('my queue') 112 | repeat 113 | -- Read the next file name off the workqueue 114 | local fileName = q:read() 115 | if fileName then 116 | -- Load the file and write its contents back into the workqueue 117 | q:write(torch.load(fileName)) 118 | end 119 | until fileName == nil 120 | end) 121 | 122 | -- Write the file names into the workqueue 123 | q:write('f1.t7') 124 | q:write('f2.t7') 125 | q:write('f3.t7') 126 | 127 | -- Read back the 3 answers and print them 128 | print(q:read()) 129 | print(q:read()) 130 | print(q:read()) 131 | ``` 132 | 133 | Read more complete documentation on [ipc.workqueue](doc/workqueue.md). 134 | 135 | Channels 136 | -------- 137 | 138 | Channels are a thread synchronization primitive based on 139 | message-passing. Threads communicate via channels by writing messages 140 | onto them and reading messages out of them, in FIFO order. There is no 141 | restriction on which threads or how many threads can read or write 142 | from a channel. This allows one to define concurrent workflows easily. 143 | 144 | Channels can also be closed, which prevents further writes to it. Once 145 | all items are read from a closed channel, that channel becomes drained 146 | and nothing further can be read from it. DAGs of computation made up 147 | of channels can be shut down via cascading closing/draining of 148 | channels. 149 | 150 | The 151 | [producer-consumer example](doc/channel.md/#producer-consumer-example) 152 | shows a group of producer threads and a group of consumer threads 153 | being set up to communicate via a channel. The main thread tears the 154 | entire setup down by closing the channel. 
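A minimal sketch of the basic channel API (illustrative only; see [ipc.channel](doc/channel.md) for the full semantics, including the status values returned by reads and writes):

```lua
-- A minimal sketch: one channel, write then read in FIFO order
local ipc = require 'libipc'
local c = ipc.channel()
c:write('hello')
c:write(42)
local status, msg = c:read(false) -- blocking read; 'hello' comes out first
assert(msg == 'hello')
local _, num = c:read(false)
assert(num == 42)
c:close() -- closed and already empty, so the channel is now drained
```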
155 | 156 | The 157 | [local model parallelism for forward inference example](examples/model-parallelism.lua) shows 158 | how to set up a `nn.Sequential`-based model so that each of its 159 | submodules can execute forward inference in parallel. 160 | 161 | The full documentation can be found at [ipc.channel](doc/channel.md). 162 | 163 | Examples 164 | -------- 165 | 166 | Simple scripts you can run locally can be found [here](examples/). 167 | See the unit tests for a ton more detailed examples. 168 | 169 | Documentation 170 | ------------- 171 | 172 | Full API documentation can be found [here](doc/index.md). 173 | 174 | License 175 | ------- 176 | 177 | Licensed under the Apache License, Version 2.0. 178 | [See LICENSE file](LICENSE). 179 | -------------------------------------------------------------------------------- /benchmarks/allgpu-allreduce.lua: -------------------------------------------------------------------------------- 1 | local opt = lapp [[ 2 | Options: 3 | -d,--dimensions (default '1000,1000') comma delimited tensor dimensions 4 | -i,--iterations (default 1000) number of send/recv iterations 5 | -v,--verify verify the results (affects performance) 6 | ]] 7 | 8 | local ipc = require 'libipc' 9 | 10 | -- Use a child to get the device count (can't fork after CUDA init) 11 | local ppid = ipc.getppid() 12 | local pid = ipc.fork() 13 | if pid == 0 then 14 | local cutorch = require 'cutorch' 15 | os.exit(cutorch.getDeviceCount()) 16 | end 17 | local deviceCount = ipc.waitpid(pid) 18 | 19 | -- Fork a new process per device 20 | print('Found '..deviceCount..' GPUs, forking children...') 21 | local device = 1 22 | for i = 2,deviceCount do 23 | local pid = ipc.fork() 24 | if pid == 0 then 25 | device = i 26 | break 27 | end 28 | end 29 | 30 | -- This is the forked child process 31 | local cutorch = require 'cutorch' 32 | local sys = require 'sys' 33 | local LocalhostTree = require 'ipc.LocalhostTree' 34 | 35 | -- grab a GPU 36 | cutorch.setDevice(device) 37 | 38 | -- Create the tree of nodes (one per GPU) 39 | local tree = LocalhostTree(device, deviceCount, ppid) 40 | 41 | -- Create a big tensor 42 | local dimensions = string.split(opt.dimensions, ",") 43 | for i = 1,#dimensions do 44 | dimensions[i] = tonumber(dimensions[i]) 45 | end 46 | local unpack = unpack or table.unpack 47 | local t0 = torch.randn(unpack(dimensions)):cuda() 48 | 49 | -- Iterate! 50 | sys.tic() 51 | for i = 1,opt.iterations do 52 | if opt.verify then 53 | t0:fill(i / opt.iterations):pow(device) 54 | end 55 | tree.allReduce(t0, function(a, b) return a:add(b) end) 56 | if opt.verify then 57 | local expected = 0 58 | for d = 1,deviceCount do 59 | expected = expected + math.pow(i / opt.iterations, d) 60 | end 61 | t0:add(-expected) 62 | local err = math.max(math.abs(t0:min()), math.abs(t0:max())) 63 | assert(err < 1e-6, 'err too high: '..err) 64 | end 65 | end 66 | if device == 1 then 67 | print('did '..opt.iterations..' in '..sys.toc()..' seconds') 68 | tree.netStats() 69 | end 70 | -------------------------------------------------------------------------------- /doc/BackgroundTask.md: -------------------------------------------------------------------------------- 1 | # ipc.BackgroundTask # 2 | 3 | BackgroundTask combines [ipc.map](map.md) and [ipc.workqueue](workqueue.md) to implement 4 | a very simple, pollable, way to run a Lua function as background task. 5 | Internally it is a [ipc.BackgroundTaskPool](BackgroundTaskPool.md) with a pool size of 1 6 | and a single task added to the pool. 
7 | 8 | You construct a BackgroundTask with a function and a set of arguments 9 | to be passed to the function. At any point you can check if the task 10 | is done running with __isDone__ or get all of the return values using 11 | __getResults__. 12 | 13 | BackgroundTask is very useful for doing periodic backups during long 14 | running jobs. For example, saving out your Torch model every so often 15 | during a training run that takes many hours or days. It is quite common 16 | for data centers to only provide ephemeral storage on compute nodes, which 17 | means that if a job dies its results are lost forever. Saving to something 18 | more persistent like S3 or HDFS is slow, so we can use a BackgroundTask 19 | to hide that from the main training loop. In the example below we 20 | use curl to upload the saved file to a hypothetical website. 21 | 22 | ```lua 23 | local BackgroundTask = require 'ipc.BackgroundTask' 24 | local sys = require 'sys' 25 | -- Make a random "model" to save 26 | local model = { x = torch.randn(13, 42), y = math.random() } 27 | 28 | -- Write it out to a temp file 29 | -- This is quick and the easiest way to marshal 30 | -- the model into a background task 31 | local tempFilename = os.tmpname() 32 | torch.save(tempFilename, model) 33 | print('temporarily saved model to '..tempFilename) 34 | 35 | -- Create a background task to upload the model 36 | local background = BackgroundTask(function(filename, url) 37 | -- This function runs in a clean Lua environment 38 | local ipc = require 'libipc' 39 | local sys = require 'sys' 40 | -- Time the upload 41 | sys.tic() 42 | -- Spawn curl to upload the file 43 | local p = ipc.spawn({ 44 | file = 'curl', 45 | args = { 46 | '-i', 47 | '-F name=saveme', 48 | '-F filedata=@'..filename, 49 | url, 50 | } 51 | }) 52 | -- Wait on the upload, return curl's exit code and how long it took 53 | return p:wait(), sys.toc() 54 | end, tempFilename, "http://yourserver.com/here") 55 | 56 | -- Do something while the save happens 57 | while not background.isDone() do 58 | print('still saving...') 59 | sys.sleep(1) 60 | end 61 | 62 | -- It's done, so check the results 63 | local ret,t = background.getResults() 64 | print('curl returned '..ret..' in '..t..' seconds') 65 | ``` 66 | 67 | Note that __ipc.BackgroundTask__ will throw an error when attempting to serialize closures/upvalues. 68 | However [ipc.workqueue](workqueue.md) provides __:writeup()__ for serializing closures/upvalues. 69 | 70 | -------------------------------------------------------------------------------- /doc/BackgroundTaskPool.md: -------------------------------------------------------------------------------- 1 | # ipc.BackgroundTaskPool # 2 | 3 | BackgroundTaskPool combines [ipc.map](map.md) and [ipc.workqueue](workqueue.md) to implement 4 | a very simple, pollable way to run a set of arbitrary Lua functions as background tasks. 5 | You construct a BackgroundTaskPool with the size of the thread pool 6 | you would like to run your tasks on. 7 | 8 | ```lua 9 | local BackgroundTaskPool = require 'ipc.BackgroundTaskPool' 10 | local pool = BackgroundTaskPool(13) -- create a pool with 13 worker threads 11 | ``` 12 | 13 | Call addTask with a function and a set of arguments to add a task into the pool. 14 | 15 | ```lua 16 | for i = 1,42 do 17 | pool.addTask(function(i) 18 | return math.sqrt(i) 19 | end, i) 20 | end 21 | ``` 22 | You can optionally poll for the completion of all tasks or just one task. 23 | 24 | ```lua 25 | -- is the task with id 7 done? 
26 | if pool.isDone(7) then 27 | print('the 7th is done!') 28 | end 29 | -- are all the tasks done? 30 | if pool.isDone() then 31 | print('all tasks are done!') 32 | end 33 | ``` 34 | 35 | When you want the result for a task, call getResult with the task id. 36 | If the task function threw an error then the getResult call will throw that same 37 | error. 38 | 39 | ```lua 40 | for i = 1,42 do 41 | assert(pool.getResult(i) == math.sqrt(i)) 42 | end 43 | ``` 44 | 45 | Note that __ipc.BackgroundTaskPool__ will throw an error when attempting to serialize closures/upvalues. 46 | However [ipc.workqueue](workqueue.md) provides __:writeup()__ for serializing closures/upvalues. 47 | -------------------------------------------------------------------------------- /doc/channel.md: -------------------------------------------------------------------------------- 1 | # ipc.channel # 2 | Channels are a thread synchronization primitive based on 3 | message-passing. Threads communicate via channels by writing messages 4 | onto them and reading messages out of them, in FIFO order. There is no 5 | restriction on which threads or how many threads can read or write 6 | from a channel. This allows one to define concurrent workflows easily. 7 | 8 | Channels can also be closed, which prevents further writes to it. Once 9 | all items are read from a closed channel, that channel becomes drained 10 | and nothing further can be read from it. DAGs of computation made up 11 | of channels can be shut down via cascading closing/draining of 12 | channels. 13 | 14 | ``` lua 15 | local c = ipc.channel([]) 16 | ``` 17 | 18 | The constructor does not take any arguments. However, this will most 19 | likely change in the near-future as additional functionality is added. 20 | 21 | The following methods are defined on the channel. 22 | * __:write()__ 23 | * __:read()__ 24 | * __:num_items()__ 25 | * __:close()__ 26 | * __:closed()__ 27 | * __:drained()__ 28 | 29 | ## Lifecycle of channels 30 | Channels can be in one of three states: `ipc.channel.OPEN`, 31 | `ipc.channel.CLOSED` or `ipc.channel.DRAINED`. A newly-created channel 32 | is open and will accept reads and writes. A channel can be closed 33 | using the __:close()__ method. A closed channel will no longer accept 34 | writes, but any items that remain on the channel can be read. Once all 35 | of the items on a closed channel are read, that channel becomes 36 | drained. Empty channels that are closed will go into the drained state 37 | immediately. 38 | 39 | The state of the channel is returned as a `status` return variable in 40 | calls to __:write()__ and __:read()__. The state of the channel can 41 | also be queried using the __:closed()__ and __:drained()__ methods. 42 | 43 | ## Reading and writing from channels 44 | Any thread can write values into a channel and any thread can read 45 | those values out of the channel. Writes onto an open channel should 46 | always succeed, assuming that no errors occurred. Reads on an empty 47 | and non-drained channel can either cause the thread to block (for 48 | blocking reads) or return nil (for non-blocking reads). Reads on a 49 | drained channel return immediately with the `ipc.channel.DRAINED` 50 | status. 51 | 52 | ### Producer-consumer example 53 | The following example illustrates using a channel to send items from 54 | one group of threads to another. A group of producer threads and a 55 | group of consumer threads are set up to communicate via a channel. 
The 56 | main thread tears the entire setup down by closing the channel. 57 | 58 | ``` lua 59 | local ipc = require 'libipc' 60 | local c = ipc.channel() -- create channel 61 | 62 | -- Spawn producer threads that write items to channel and checks the 63 | -- returned status. If the status is not ipc.channel.OPEN, then the 64 | -- channel has been closed and the producers should terminate. 65 | local nproducers = 3 66 | local producers = ipc.map(nproducers, function(c, tid) 67 | local ipc = require 'libipc' 68 | local sys = require 'sys' 69 | while true do 70 | local x = {tid, math.floor(torch.random(10))} -- generate item 71 | local status = c:write(x) -- write item onto channel 72 | if status ~= ipc.channel.OPEN then 73 | break -- channel is no longer open, so terminate 74 | end 75 | sys.sleep(0.1) -- don't generate too fast 76 | end 77 | end, c) 78 | 79 | -- Spawn consumer threads that read items from the channel and checks 80 | -- the returned status. If the status is ipc.channel.DRAINED, then 81 | -- there will not be any more items to read and the consumers should 82 | -- terminate. 83 | local consumers = ipc.map(1, function(c) 84 | local ipc = require 'libipc' 85 | local nonblocking = false 86 | while true do 87 | local status, item = c:read(nonblocking) -- read item from channel 88 | if status == ipc.channel.DRAINED then 89 | break -- channel has been drained, so terminate 90 | else 91 | print('tid: '..item[1]..' r: '..item[2]) -- do the thing 92 | end 93 | end 94 | end, c) 95 | 96 | -- It is possible to write to the channel from any thread, including 97 | -- this one. 98 | c:write({0, 'from main thread'}) 99 | 100 | sys.sleep(5) -- wait 5 secs so producers and consumers can run 101 | 102 | -- Close the channel so producers and consumers will terminate. 103 | c:close() 104 | producers:join() 105 | consumers:join() 106 | assert(c:num_items() == 0) 107 | ``` 108 | 109 | ### Multi-write example 110 | The following example shows how to write multiple values into a 111 | channel with a single __:write()__ call. __:write()__ can accept 112 | multiple arguments. Each of these arguments is written to the channel. 113 | 114 | ``` lua 115 | local c = ipc.channel() 116 | local data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,} 117 | local unpack = unpack or table.unpack 118 | local status = c:write(unpack(data)) 119 | assert(status == ipc.channel.OPEN) 120 | assert(c:num_items() == 12, 'number of items in channel is incorrect') 121 | local nonblocking = false 122 | for i=1,#data do 123 | local status, readData = c:read(nonblocking) 124 | assert(status == ipc.channel.OPEN) 125 | assert(readData == i) 126 | end 127 | ``` 128 | 129 | ## Closing and draining channels 130 | Open channels can be closed, which changes its state from 131 | `ipc.channel.OPEN` to `ipc.channel.CLOSED`. Closed channels will no 132 | longer accept writes, but any items that are still on the channel can 133 | be read. 134 | 135 | Once all of the items on a closed channel are read, its state is 136 | changed from `ipc.channel.CLOSED` to `ipc.channel.DRAINED`. Further 137 | reads will return immediately with the drained status. An empty 138 | channel which is closed will immediately be put into the drained 139 | state. 140 | 141 | Closing and draining channels can be used to signal to downstream 142 | threads that there will no longer be anything to read on the 143 | channel. These threads might close the channels that they are writing 144 | to, resulting in a cascading teardown of channels further downstream. 
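As a small, self-contained sketch of these state transitions (illustrative only; it only uses the constructor, __:write()__, __:read()__, __:close()__, __:closed()__ and __:drained()__ as documented above):

```lua
local ipc = require 'libipc'
local c = ipc.channel()
c:write('last item')
c:close()                             -- no further writes are accepted
assert(c:closed())
local _, item = c:read(false)         -- the remaining item can still be read
assert(item == 'last item')
assert(c:drained())                   -- closed and now empty, so the channel is drained
local status = c:read(false)
assert(status == ipc.channel.DRAINED) -- further reads return immediately with DRAINED
```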
145 | 146 | See the [Producer-consumer example](### Producer-consumer example) to 147 | see an example of threads checking statuses of __:read()__ and 148 | __:write()__ calls to determine whether they should terminate and the 149 | main thread closing a channel to teardown a collection of threads 150 | operating on a channel. 151 | 152 | ## Behaviors not yet implemented 153 | 1. It is not possible to specify the max number of items that can be 154 | written to a channel. The channel just grows to allow the write, 155 | instead of blocking on write. Therefore, just 156 | like [ipc.workqueue](workqueue.md), there is no backpressure 157 | mechanism. 158 | 2. There is no select call to select between a number of channels. 159 | 3. The __:read()__ call, when called in non-blocking mode, does not 160 | allow one to distinguish between reading a nil from the channel and 161 | not reading an item at all. 162 | 163 | ## Examples 164 | The 165 | [ipc.channel unit tests](../test/test_channel.lua) 166 | provide a rich set of examples, in addition to the 167 | [local model parallelism for forward inference example](../examples/model-parallelism.lua). 168 | 169 | Two examples are described in detail here. 170 | 171 | ### Building workqueues with channels 172 | The `channelsAsWorkQueue` unit test shows how to build a workqueue as 173 | described in [ipc.workqueue](workqueue.md) using channels, while 174 | allowing for more than one owner thread. 175 | 176 | Multiple threads can write onto the channel that is used to send work 177 | items to the workers. They can write onto this channel until it is 178 | closed. Multiple workers read from the work item channel. Each worker 179 | only knows how to read a work item from the channel, process it and 180 | then write it into a results channel. Each worker terminates as soon 181 | as it sees that either the work item channel has been drained or the 182 | results channel has been closed. 183 | 184 | ### Local model parallelism for forward inference 185 | The 186 | [local model parallelism example](../examples/model-parallelism.lua) 187 | shows how to set up a `nn.Sequential`-based model so that each of its 188 | submodules can execute forward inference in parallel. 189 | 190 | The code runs forward inference on the following model: 191 | ``` lua 192 | [nn.Sequential { 193 | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> output] 194 | (1): nn.TemporalConvolution(20 -> 10, 5) 195 | (2): nn.Tanh 196 | (3): nn.TemporalConvolution(10 -> 5, 5) 197 | (4): nn.Tanh 198 | (5): nn.TemporalConvolution(5 -> 1, 5) 199 | (6): nn.Tanh 200 | (7): nn.Max 201 | } 202 | ``` 203 | 204 | The unit test first measures the time taken to perform forward 205 | inference on a single thread. Then the unit test measures the time 206 | taken to perform forward inference when the model is split and run in 207 | parallel across multiple threads. 208 | 209 | Each model is distributed across multiple threads as follows: 210 | 211 | 1. For i = `1` to `3`, submodules `2i-1` (`nn.TemporalConvolution`) 212 | and `2i` (`nn.Tanh`) are instantiated on thread `i`. 213 | 2. The `nn.Max` submodule is instantiated on thread `4`. 214 | 215 | Each thread has an input channel and an output channel. The output 216 | tensor of the previous submodule will become available on the input 217 | channel. 
The thread reads this tensor from the input channel and then 218 | executes the __:forward()__ call on its layer with the tensor as input 219 | and writes the resulting output tensor onto its output channel. 220 | 221 | The unit test shows that splitting up the model and running each part 222 | of it in parallel on separate threads is faster across the generated 223 | workload than running the entire model sequentially on a single 224 | thread. However, multiple threads running their own copies of the 225 | entire model (data-parallel) should be as fast as the model-parallel 226 | version. 227 | 228 | ## Background on channels 229 | The semantics of [ipc.channel](channel.md) is based on the following 230 | resources: 231 | 232 | 1. https://gobyexample.com/channels 233 | 2. https://tour.golang.org/concurrency/2 234 | 3. https://github.com/clojure/core.async/blob/master/examples/walkthrough.clj 235 | 236 | Channels are a simplification of the workqueues provided 237 | by [ipc.workqueue](workqueue.md). A workqueue has two queues - one 238 | that is used to send items to workers and the other that is used to 239 | send back results to the single owner thread. A channel just has a 240 | single queue that any thread can enqueue or dequeue from. 241 | -------------------------------------------------------------------------------- /doc/index.md: -------------------------------------------------------------------------------- 1 | # IPC Package Reference Manual # 2 | 3 | * [ipc.spawn](spawn.md) a more modern replacement for Lua's io.popen function. 4 | * [ipc.map](map.md) maps a Lua function onto a set of threads. 5 | * [ipc.mutex](mutex.md) provides utilities for locking resources and creating synchronization barriers. 6 | * [ipc.workqueue](workqueue.md) allows the main thread to communicate with a set of worker threads. 7 | * [ipc.channel](channel.md) allows threads to communicate with one another via message-passing. 8 | * [ipc.sharedtable](sharedtable.md) provides a table that can be shared between threads. 9 | * [ipc.marshal](marshal.md) serializes objects into a compact `userdata` instance. 10 | * [ipc.BackgroundTask](BackgroundTask.md) provides a simple, pollable interface to a single Lua function run in the background. 11 | * [ipc.BackgroundTaskPool](BackgroundTaskPool.md) provides a simple, pollable interface to a set of arbitrary Lua functions run in the background. 12 | -------------------------------------------------------------------------------- /doc/map.md: -------------------------------------------------------------------------------- 1 | # ipc.map # 2 | 3 | ```lua 4 | local map = ipc.map(nThread, threadFunc, [...]) 5 | ``` 6 | 7 | Maps a Lua function onto a set of `nThread` threads. 8 | The `threadFunc` function is run in an entirely new and clean Lua environment. 9 | The optional variable length arguments `...` are passed to the function: `threadFunc(...)`. 10 | 11 | No upvalues of the function are marshaled into the thread. 12 | Torch Tensor and Storage objects are shared across the thread boundary; no copies are made. 13 | That said, we also do not provide thread-safe access to those Tensors. 14 | It is up to the programmer to implement their own lockless system. 15 | 16 | You can pass as many arguments as you wish to the function and 17 | the function can return as many values as it wants. 
18 | 19 | ```lua 20 | local ipc = require 'libipc' 21 | local m = ipc.map(2, function(a, b, c, threadId) 22 | -- the last argument passed to the function is the ID of the thread 23 | assert((threadId == 1) or (threadId == 2)) 24 | return math.sqrt(a*a + b*b + c*c), "hi" 25 | end, 1, 2, 3) 26 | local p1,s1,p2,s2 = m:join() 27 | print(p1) -- will print 3.7416573867739 28 | print(s1) -- will print "hi" 29 | print(p2) -- will print 3.7416573867739 30 | print(s2) -- will print "hi" 31 | ``` 32 | 33 | Note how an additional hidden argument is passed to the mapped function: `threadId`. 34 | This `threadId` is a number identifying the thread (from 1 to `nThread`). 35 | It is always passed as the last argument to the function. 36 | 37 | Similar to POSIX threads, when you want to wait for all the child 38 | threads to end and get their return values you must call __:join()__. 39 | If any child threads had errors, these will bubble up via a call 40 | to __:join()__. You can wrap this in a pcall if you think the 41 | error is recoverable (generally, it is not). 42 | 43 | At any point the parent thread can check if any of the child 44 | threads have died. The __:checkErrors()__ function will bubble 45 | up any errors. You can wrap this in a pcall if you think the 46 | error is recoverable (generally, it is not). 47 | 48 | Note that __ipc.map()__ will throw an error when attempting to serialize closures/upvalues. 49 | However [ipc.workqueue](workqueue.md) provides __:writeup()__ for serializing closures/upvalues. 50 | 51 | Most often [ipc.map](map.md) is combined with an [ipc.workqueue](workqueue.md) 52 | in order to distribute work across the threads. 53 | A more concrete example of combining [ipc.map](map.md) and [ipc.workqueue](workqueue.md) 54 | can be found in [ipc.BackgroundTask](BackgroundTask.md). 55 | 56 | -------------------------------------------------------------------------------- /doc/marshal.md: -------------------------------------------------------------------------------- 1 | # ipc.marshal # 2 | 3 | ```lua 4 | local m = ipc.marshal(obj[, upval, size, size_increment]) 5 | ``` 6 | 7 | The `ipc.marshal` constructor serializes a Lua object (`obj` above). 8 | It returns an `ipc.marshal` instance (`m` above). 9 | The resulting `m` instance provides a simple __:read__ method for deserializing the object: 10 | 11 | ```lua 12 | local obj = m:read() 13 | ``` 14 | 15 | When `upval` is true, the serialized `obj` can contain upvalues/closures (the default is false). 16 | The `size` and `size_increment` are used to incrementally grow the write buffer (see `ipc.workqueue`). 17 | 18 | ## Example 19 | 20 | More concretely, suppose we want to serialize the following table: 21 | 22 | ```lua 23 | local obj = {1,2,3,v=4,g=5} 24 | ``` 25 | 26 | We can marshal (that is, serialize) the data into `m`: 27 | 28 | ```lua 29 | local m = ipc.marshal(obj) 30 | ``` 31 | 32 | Contrary to normal serialization, `m` is not a string, but a `userdata` object. 33 | The most useful thing about `m` is that it can be read multiple times (like a string): 34 | 35 | ```lua 36 | local obj2 = m:read() 37 | local obj3 = m:read() 38 | ``` 39 | 40 | This is useful when you want to serialize something only once (say, in the main thread), 41 | and unserialize it many times (say, in worker threads). 
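As a small illustration of the `upval` flag described above (a sketch only; the variable names are arbitrary), a closure can be marshaled together with its upvalue and then deserialized as many times as needed:

```lua
local ipc = require 'libipc'

local upvalue = 40
local closure = function(x) return upvalue + x end

-- upval=true allows the closure (and its upvalue) to be serialized
local m = ipc.marshal(closure, true)

local f1 = m:read()
local f2 = m:read() -- the same marshal instance can be read again
assert(f1(2) == 42 and f2(2) == 42)
```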
42 | 43 | -------------------------------------------------------------------------------- /doc/mutex.md: -------------------------------------------------------------------------------- 1 | # ipc.mutex # 2 | 3 | ```lua 4 | local mutex = ipc.mutex() 5 | ``` 6 | A `mutex` is used to lock resources and create barriers. 7 | It is used to coordinate and synchronize threads. 8 | 9 | ## Locking 10 | 11 | For locking, consider the following example: 12 | 13 | ```lua 14 | local mutex = ipc.mutex() 15 | local shared = torch.FloatTensor(10) 16 | shared:fill(0) 17 | local m = ipc.map(3, function(mutex, shared, mapid) 18 | local ipc = require 'libipc' 19 | mutex:lock() 20 | shared:fill(mapid) 21 | mutex:unlock() 22 | end, mutex, shared) 23 | ``` 24 | 25 | In the above example, the 3 worker threads use the mutex to protect access to the `shared` tensor. 26 | Only one thread can call `:lock()` at any given time. 27 | The remaining threads will block on that call until the locking thread calls `:unlock()` on that same `mutex`. 28 | The last thread to call `:lock()` will fill the `shared` tensor with its `mapid`, i.e. the id of the thread. 29 | 30 | ## Barrier 31 | 32 | Barriers are used to synchronize threads. 33 | A call to `mutex:barrier(nThread)` blocks until `nThread` threads have called `:barrier()`. 34 | When the call returns, all threads are synchronized at this point in the code. 35 | 36 | Consider the following example: 37 | 38 | ```lua 39 | local shared = torch.FloatTensor(1) 40 | shared:fill(0) 41 | local m = ipc.map(3, function(mutex, shared, mapid) 42 | local ipc = require 'libipc' 43 | local sys = require 'sys' 44 | assert(shared[1] == 0) 45 | mutex:barrier(4) -- first barrier 46 | -- main thread updates shared[1] 47 | mutex:barrier(4) -- second barrier 48 | assert(shared[1] ~= 0) 49 | end, mutex, shared) 50 | 51 | assert(shared[1] == 0) 52 | mutex:barrier(4) -- first barrier 53 | shared[1] = 1000 54 | mutex:barrier(4) -- second barrier 55 | assert(shared[1] ~= 0) 56 | m:join() 57 | ``` 58 | 59 | The example uses 3 worker threads. With the main thread, we have a total of 4 threads. 60 | We want all 4 threads to synchronize using 2 barriers. 61 | The first `:barrier()` is to make certain everyone executed `assert(shared[1] == 0)`. 62 | Afterwards, the main thread updates `shared[1] = 1000`. 63 | After the second `:barrier()` returns, everyone executes `assert(shared[1] ~= 0)`. 64 | 65 | -------------------------------------------------------------------------------- /doc/sharedtable.md: -------------------------------------------------------------------------------- 1 | # sharedtable # 2 | 3 | ```lua 4 | local t = ipc.sharedtable([tbl, move]) 5 | ``` 6 | 7 | A `sharedtable` is a table that can be shared between threads. 8 | The constructor can take an optional `tbl` argument, which is a Lua table. 9 | When provided, `tbl` is used to initialize the `sharedtable`. 10 | 11 | The optional `move` argument specifies that elements should be moved from the original table to the new one (defaults to `false`). 12 | This is useful in case the table is very big and we don't want to have 2 copies of the full data in memory. 13 | 14 | A `sharedtable` can be shared between threads, either by passing it through 15 | an `ipc.workqueue` or as an argument to `ipc.map`. 
16 | Here is an example using `ipc.map`: 17 | 18 | ```lua 19 | local t = ipc.sharedtable({0}) 20 | ipc.map(2, function(tab, threadid) 21 | for i=1,100 do 22 | if (i % 2) + 1 == threadid then 23 | tab[i] = i 24 | end 25 | end 26 | end, t):join() 27 | 28 | for i=1,100 do 29 | assert(t[i] == i) 30 | end 31 | ``` 32 | 33 | Indeed, the `sharedtable` instance `t` is shared between 2 worker threads and the main thread. 34 | 35 | Internally, the `sharedtable` is implemented by storing the table in a separate `lua_State`. 36 | Recall that each thread also has its own `lua_State`. 37 | And the way data is passed between threads is by serializing/deserializing the data from one `lua_State` to another. 38 | The `sharedtable` uses the same principle. 39 | By this I mean that all reads and writes to the table (get/set) are implemented by (de)serializing between the calling thread's and the shared table's `lua_State`. 40 | 41 | The `sharedtable` is subject to some caveats. 42 | Some use-cases will still require a lock. 43 | For example: 44 | 45 | ```lua 46 | 47 | local t = ipc.sharedtable() 48 | local m = ipc.mutex() 49 | ipc.map(10, function(tab, mutex) 50 | for i=1,100 do 51 | mutex:lock() 52 | tab[i] = (tab[i] or 0) + 1 53 | mutex:unlock() 54 | end 55 | end, t, m):join() 56 | assert(#t == 100) 57 | for i=1,100 do 58 | assert(t[i] == 10) 59 | end 60 | ``` 61 | 62 | In the above example, we want to atomically add 1 to each of the first 100 elements of the table. 63 | Doing this requires a `mutex`. 64 | 65 | ## Nested tables 66 | 67 | Another very big caveat of `sharedtable` should be highlighted. 68 | When you de-serialize, you're actually creating an unnamed variable in the thread's state. 69 | This causes some confusion, as in: 70 | 71 | ```lua 72 | local ipc = require 'libipc' 73 | local tbl = {key={subkey=1}} 74 | local shared_tbl = ipc.sharedtable(tbl) 75 | shared_tbl.key.subkey = 2 76 | assert(shared_tbl.key.subkey == 1) 77 | ``` 78 | 79 | The reason for this is that when accessing the sub-table pointed to by `shared_tbl.key`, the entire sub-table is unserialized. 80 | That is, we are writing directly to that unserialized table without re-serializing it into the shared table. 81 | A solution could be to re-write the entire modified sub-table: 82 | 83 | ```lua 84 | local sub_tbl = shared_tbl.key 85 | sub_tbl.subkey = 2 86 | shared_tbl.key = sub_tbl 87 | assert(shared_tbl.key.subkey == 2) 88 | ``` 89 | 90 | But that is inefficient. A more interesting alternative is to nest shared tables: 91 | 92 | ```lua 93 | local shared_tbl = ipc.sharedtable{key=ipc.sharedtable{subkey=1}} 94 | shared_tbl.key.subkey = 2 95 | assert(shared_tbl.key.subkey == 2) 96 | ``` 97 | 98 | -------------------------------------------------------------------------------- /doc/spawn.md: -------------------------------------------------------------------------------- 1 | # ipc.spawn # 2 | 3 | The IPC library provides a more modern replacement for Lua's built-in 4 | io.popen functionality. You can use the ipc.spawn function to create 5 | child processes, read from stdout, and write to stdin. 6 | 7 | ```lua 8 | local ipc = require 'libipc' 9 | local exitCode = ipc.spawn({ file = 'which', args = { 'nvcc' } }):wait() 10 | if exitCode == 0 then 11 | print('nvcc is present') 12 | end 13 | ``` 14 | 15 | The function ipc.spawn takes a table of options. 16 | 17 | * __file__ is the name of the executable you wish to run. Finding the executable 18 | follows the rules laid down in [posix_spawn](http://linux.die.net/man/3/posix_spawn). 
19 | In short, if __file__ contains a slash '/' then it is used as a direct path, 20 | relative or absolute, to the executable. If __file__ does not contain a slash '/' 21 | then the environment variable __PATH__ is used to search. 22 | 23 | * __args__ is an optional table of arguments to be passed to the executable. You do 24 | not need to worry about quoting or escaping arguments that contain spaces, as 25 | these args are passed directly to the executable. 26 | 27 | ```lua 28 | local ipc = require 'libipc' 29 | local p = ipc.spawn({ file = 'echo', args = { 'there -is- no need\nto "escape"' } }) 30 | print(p:stdout('*line')) -- will print 'there -is- no need' 31 | print(p:stdout('*line')) -- will print 'to "escape"' 32 | p:wait() 33 | ``` 34 | 35 | * __env__ is an optional table of environment variables to pass to the executable. 36 | By default, if __env__ is not specified, the executable inherits the exact same 37 | environment as the spawning process. The elements of the __env__ table take the 38 | form of __VAR=VALUE__. 39 | 40 | ```lua 41 | local ipc = require 'libipc' 42 | local p = ipc.spawn({ file = 'printenv', args = { 'SOME_VAR' }, env = { 'SOME_VAR=42' } }) 43 | print(p:stdout('*all')) -- will print '42' 44 | p:wait() 45 | ``` 46 | 47 | Once the child executable is spawned, ipc.spawn will return a Lua object that 48 | can be used to control the child process. 49 | 50 | * __:pid()__ returns the system process ID of the child process. 51 | 52 | ```lua 53 | local ipc = require 'libipc' 54 | local p = ipc.spawn({ file = 'pwd' }) 55 | print('the child process pid is '..p:pid()) 56 | p:wait() 57 | ``` 58 | 59 | * __:running()__ returns true if the child process is still running, false if it is done. 60 | 61 | ```lua 62 | local ipc = require 'libipc' 63 | local p = ipc.spawn({ file = 'sleep', args = { 1 } }) 64 | while p:running() do 65 | print('still sleeping!') 66 | end 67 | p:wait() 68 | ``` 69 | 70 | * __:wait(optional string)__ waits for the child process to complete and returns its exit code. 71 | The __:wait()__ function is blocking; if the child process never ends then __:wait()__ 72 | will never return. You can optionally send a signal to the child process by passing 73 | in an extra string argument to __:wait()__. The supported signals are "TERM" and "KILL". 74 | 75 | ```lua 76 | local ipc = require 'libipc' 77 | local p1 = ipc.spawn({ file = 'sleep', args = { 1 } }) 78 | p1:wait() -- this will block for 1 second 79 | local p2 = ipc.spawn({ file = 'sleep', args = { 1 } }) 80 | p2:wait("TERM") -- this will return immediately since we sent a SIGTERM 81 | ``` 82 | 83 | * __:stdin(optional string)__ will write a string to the child process's stdin. 84 | You can call this as many times as you need to. If you wish to close the stdin pipe 85 | then call __:stdin()__ with no arguments. Some processes will not quit until stdin 86 | is closed. Calling __:wait()__ will also close the stdin pipe. 87 | 88 | ```lua 89 | local ipc = require 'libipc' 90 | local p = ipc.spawn({ file = 'tee' }) 91 | p:stdin('hi\n') 92 | p:stdin() -- close the stdin pipe so tee will terminate 93 | print(p:stdout('*all')) -- read all of tee's stdout and print it 94 | p:wait() 95 | ``` 96 | 97 | * __:stdout(arg)__ will read some amount of data from the child process's stdout pipe. 98 | The stdout function supports the same arguments as Lua files. There are 3 ways to 99 | read stdout. You can pass in __*line__ to read a single line (the new line will not be 100 | returned with the resulting string). 
You can pass in __*all__ to read all of stdout.
101 | Finally, you can pass in a number indicating how many bytes of stdout you would like to read.
102 | When there is no more stdout to read, due to EOF, the function will return __nil__.
103 |
104 | ```lua
105 | local ipc = require 'libipc'
106 | local p = ipc.spawn({ file = 'tee' })
107 | p:stdin('a\n')
108 | p:stdin('b\n')
109 | p:stdin('c')
110 | p:stdin('d\n')
111 | p:stdin() -- close the stdin pipe so tee will terminate
112 | -- read and print stdout until we see nil
113 | repeat
114 |    local line = p:stdout('*line')
115 |    if line then
116 |       print(line)
117 |    end
118 | until not line
119 | p:wait()
120 | ```
121 |
122 | Orphaned ipc.spawn objects that are eventually garbage collected by Lua will automatically
123 | send SIGTERM to the child process and then wait on it to exit gracefully. This could lead
124 | to pauses or infinite hangs during garbage collection. It is therefore **highly** recommended
125 | that you always call __:wait()__ on spawned child processes before you let go of the
126 | reference to the ipc.spawn object.
127 |
--------------------------------------------------------------------------------
/doc/workqueue.md:
--------------------------------------------------------------------------------
1 | # ipc.workqueue #
2 |
3 | Creating an [ipc.workqueue](workqueue.md) allows a thread to communicate with a set of
4 | worker threads created by [ipc.map](map.md). The queue is bidirectional between
5 | an owner thread and the workers. All native Lua types can be quickly marshaled
6 | by value across thread boundaries (a copy is made).
7 | Torch [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor)
8 | and Storage objects are marshaled by reference (no copy is made).
9 |
10 |
11 | ```lua
12 | local q = ipc.workqueue([name, size, size_increment])
13 | ```
14 |
15 | The constructor takes the following optional arguments:
16 | * `name` is a string identifying the queue within the process (defaults to nil);
17 | * `size` is the initial size in bytes of the workqueue (defaults to 1024*16);
18 | * `size_increment` is the size in bytes by which the workqueue is incremented when it runs out of memory (defaults to `size`).
19 |
20 | The two main methods are __:write()__ and __:read()__. Their usage depends
21 | on the perspective of the caller. From the owner thread's perspective,
22 | __:write()__ will put a *question* on the queue for one of the workers to process.
23 | Whenever the owner thread would like to get the *answer*, it can call __:read()__.
24 | From the perspective of the worker thread, the functions are reversed.
25 | A worker can call __:read()__ to get the next *question* off the queue to process and it
26 | can call __:write()__ to return the *answer* to the owner thread.
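In its simplest form, the round trip looks like the minimal sketch below (the queue name `squares` and the squaring task are made up purely for illustration):

```lua
local ipc = require 'libipc'

-- Owner: create a named queue and one worker that answers questions
local q = ipc.workqueue('squares')
local worker = ipc.map(1, function()
   local ipc = require 'libipc'
   local q = ipc.workqueue('squares')   -- open the same queue by name
   repeat
      local question = q:read()         -- next *question* from the owner
      if question then
         q:write(question * question)   -- the *answer*
      end
   until question == nil
end)

q:write(7)        -- ask a question
print(q:read())   -- prints 49
q:write(nil)      -- a nil question tells the worker to stop
worker:join()
q:close()
```

A fuller example, with two background workers loading files off a shared named queue, follows: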
27 |
28 | ```lua
29 | local ipc = require 'libipc'
30 |
31 | -- Create a named workqueue
32 | local q = ipc.workqueue('my queue')
33 |
34 | -- Create 2 background workers that read from the named workqueue
35 | local workers = ipc.map(2, function(threadId)
36 |    -- the last argument passed to the function is the ID of the thread
37 |    assert((threadId == 1) or (threadId == 2))
38 |    -- This function is not a closure, it is a totally clean Lua environment
39 |    local ipc = require 'libipc'
40 |    -- Open the queue by name (the main thread already created it)
41 |    local q = ipc.workqueue('my queue')
42 |    repeat
43 |       -- Read the next file name off the workqueue
44 |       local fileName = q:read()
45 |       if fileName then
46 |          -- Load the file and write its contents back into the workqueue
47 |          q:write(torch.load(fileName))
48 |       end
49 |    until fileName == nil
50 | end)
51 |
52 | -- Write the file names into the workqueue
53 | q:write('f1.t7')
54 | q:write('f2.t7')
55 | q:write('f3.t7')
56 |
57 | -- Read back the 3 answers and print them
58 | print(q:read())
59 | print(q:read())
60 | print(q:read())
61 | ```
62 |
63 | The owner thread can also do non-blocking reads from the queue.
64 | This is useful to poll for *answers* while doing other work.
65 | Passing *true* into __:read(true)__ will check the queue for
66 | an answer; if one is ready it will return it, otherwise __:read(true)__
67 | will return nil, indicating no *answer* is currently ready.
68 |
69 | ### `writeup()`
70 |
71 | Lua supports closures. These are functions with upvalues, i.e. non-global variables defined outside the scope of the function:
72 |
73 | ```lua
74 | local upvalue = 1
75 | local closure = function()
76 |    return upvalue
77 | end
78 | ```
79 |
80 | By default __:write()__ doesn't support closures. Calling __:write(closure)__ will throw an error.
81 | The reason for this is that we want to discourage users from serializing upvalues unless they absolutely need to.
82 | However, we do provide __:writeup()__ for serializing closures/upvalues.
83 | So calling __:writeup(closure)__ will not fail.
84 |
85 | ### Closure caveats
86 |
87 | In Lua, almost everything is an upvalue.
88 | For example, the `nn` variable is an upvalue in the following `heavyclosure()`:
89 |
90 | ```lua
91 | local nn = require 'nn' -- upvalue
92 | local heavyclosure = function(input)
93 |    return nn.Linear:forward(input)
94 | end
95 | ```
96 |
97 | Calling __:writeup(heavyclosure)__ will attempt to serialize the entire `nn` package.
98 | To avoid this kind of mistake, we recommend calling require from inside the closure:
99 |
100 | ```lua
101 | local lightfunction = function(input)
102 |    local nn = require 'nn'
103 |    return nn.Linear:forward(input)
104 | end
105 | ```
106 |
107 | Calling __:write(lightfunction)__ will be much more efficient than calling __:writeup(heavyclosure)__.
108 |
109 | As a final note for the power users out there, know that __:writeup()__ does not serialize the `_ENV` upvalue of closures.
110 | Typically `_ENV = _G` in the writing thread, which would be too heavy to serialize.
111 | Instead we set the `_ENV` upvalue of the deserialized closure to the reading thread's `_G`.
112 | So if you don't see any `_ENV` in your code, you should be fine.
113 |
114 | ### multi-write
115 |
116 | The __:write()__ and __:writeup()__ methods can be used to write multiple objects into the queue at once.
117 | For example, `q:write(1, 2, 3)` is equivalent to `q:write(1);q:write(2);q:write(3)`.
118 | As such, each argument passed to __:write()__ and __:writeup()__ 119 | will require their own `q:read()` to be read. 120 | 121 | 122 | A more concrete example of combining [ipc.map](map.md) and [ipc.workqueue](workqueue.md) 123 | can be found in [ipc.BackgroundTask](BackgroundTask.md) 124 | -------------------------------------------------------------------------------- /examples/allreduce.lua: -------------------------------------------------------------------------------- 1 | local opt = lapp [[ 2 | Options: 3 | -h,--host (default '127.0.0.1') host name of the server 4 | -p,--port (default 8080) port number of the server 5 | -n,--numNodes (default 1) number of nodes 6 | -x,--node (default 1) which node index is this? 7 | -b,--base (default 2) power of 2 base of the tree of nodes 8 | -d,--dimensions (default '1000,1000') comma delimited tensor dimensions 9 | -i,--iterations (default 1000) number of send/recv iterations 10 | --verify verify contents of transmission (slows things down) 11 | --verbose print lots of network stats 12 | --cuda use CUDA tensors 13 | ]] 14 | 15 | -- Load our requires 16 | local ipc = require 'libipc' 17 | local sys = require 'sys' 18 | local Tree = require 'ipc.Tree' 19 | 20 | -- Load cutorch if CUDA was requested 21 | if opt.cuda then 22 | print('loading cutorch...') 23 | local ok = pcall(require, 'cutorch') 24 | if ok then 25 | print('cutorch loaded ok.') 26 | end 27 | end 28 | 29 | -- Create a big tensor 30 | local dimensions = string.split(opt.dimensions, ",") 31 | for i = 1,#dimensions do 32 | dimensions[i] = tonumber(dimensions[i]) 33 | end 34 | local unpack = unpack or table.unpack 35 | local t0 = torch.randn(unpack(dimensions)):float() 36 | if opt.cuda then 37 | t0 = t0:cuda() 38 | end 39 | 40 | -- Create the tree of nodes 41 | local client,server 42 | if opt.node == 1 then 43 | server = ipc.server(opt.host, opt.port) 44 | server:clients(opt.numNodes - 1, function(client) end) 45 | else 46 | client = ipc.client(opt.host, opt.port) 47 | end 48 | local tree = Tree(opt.node, opt.numNodes, opt.base, server, client, opt.host, opt.port + opt.node) 49 | 50 | -- Iterate! 51 | sys.tic() 52 | for i = 1,opt.iterations do 53 | tree.allReduce(t0, function(a, b) return a:add(b) end) 54 | end 55 | print('did '..opt.iterations..' in '..sys.toc()..' 
seconds') 56 | if opt.verbose then 57 | tree.netStats() 58 | end 59 | -------------------------------------------------------------------------------- /examples/allreduce.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # run 4 nodes in a slow mode 4 | th allreduce.lua --numNodes 4 --node 1 --base 4 & 5 | th allreduce.lua --numNodes 4 --node 2 --base 4 & 6 | th allreduce.lua --numNodes 4 --node 3 --base 4 & 7 | th allreduce.lua --numNodes 4 --node 4 --base 4 & 8 | 9 | # wait for them all 10 | wait 11 | 12 | # run 4 nodes in a fast mode 13 | th allreduce.lua --numNodes 4 --node 1 & 14 | th allreduce.lua --numNodes 4 --node 2 & 15 | th allreduce.lua --numNodes 4 --node 3 & 16 | th allreduce.lua --numNodes 4 --node 4 & 17 | 18 | # wait for them all 19 | wait 20 | -------------------------------------------------------------------------------- /examples/allreduce.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=TorchIPC 4 | #SBATCH --nodes=2 5 | #SBATCH --ntasks-per-node=2 6 | 7 | srun th allreduce_slurm.lua 8 | 9 | wait 10 | -------------------------------------------------------------------------------- /examples/allreduce_slurm.lua: -------------------------------------------------------------------------------- 1 | local opt = lapp [[ 2 | Options: 3 | -d,--dimensions (default '1000,1000') comma delimited tensor dimensions 4 | -i,--iterations (default 1000) number of send/recv iterations 5 | --verify verify contents of transmission (slows things down) 6 | --verbose print lots of network stats 7 | --cuda use CUDA tensors 8 | ]] 9 | 10 | -- Load our requires 11 | local ipc = require 'libipc' 12 | local sys = require 'sys' 13 | -- Build the AllReduce tree 14 | local tree = require 'ipc.SlurmTree'() 15 | local node = tree.nodeIndex 16 | local numNodes = tree.numNodes 17 | local gpu = tree.gpu 18 | 19 | -- Requires 20 | if opt.cuda then 21 | print('loading cutorch...') 22 | local ok = pcall(require, 'cutorch') 23 | if ok then 24 | print('cutorch loaded ok.') 25 | end 26 | cutorch.setDevice(gpu) 27 | end 28 | 29 | -- Load cutorch if CUDA was requested 30 | if opt.cuda then 31 | print('loading cutorch...') 32 | local ok = pcall(require, 'cutorch') 33 | if ok then 34 | print('cutorch loaded ok.') 35 | end 36 | end 37 | 38 | -- Create a big tensor 39 | local dimensions = string.split(opt.dimensions, ",") 40 | for i = 1,#dimensions do 41 | dimensions[i] = tonumber(dimensions[i]) 42 | end 43 | local unpack = unpack or table.unpack 44 | local t0 = torch.randn(unpack(dimensions)):float() 45 | if opt.cuda then 46 | t0 = t0:cuda() 47 | end 48 | 49 | -- Iterate! 50 | sys.tic() 51 | for i = 1,opt.iterations do 52 | tree.allReduce(t0, function(a, b) return a:add(b) end) 53 | end 54 | print('Task '..node..' did '..opt.iterations..' in '..sys.toc()..' 
seconds') 55 | if opt.verbose then 56 | tree.netStats() 57 | end 58 | -------------------------------------------------------------------------------- /examples/allreduce_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | sbatch allreduce.slurm 4 | -------------------------------------------------------------------------------- /examples/client-server.lua: -------------------------------------------------------------------------------- 1 | local opt = lapp [[ 2 | Options: 3 | -h,--host (default '127.0.0.1') host name of the server 4 | -p,--port (default 8080) port number of the server 5 | -s,--server (default 0) number of clients the server should expect to connect 6 | -d,--dimensions (default '1000,1000') comma delimited tensor dimensions 7 | -i,--iterations (default 1000) number of send/recv iterations 8 | --verify verify contents of transmission (slows things down) 9 | --verbose print lots of network stats 10 | --cuda use CUDA tensors 11 | ]] 12 | 13 | -- Load our requires 14 | local ipc = require 'libipc' 15 | local sys = require 'sys' 16 | 17 | -- Load cutorch if CUDA was requested 18 | if opt.cuda then 19 | print('loading cutorch...') 20 | local ok = pcall(require, 'cutorch') 21 | if ok then 22 | print('cutorch loaded ok.') 23 | end 24 | end 25 | 26 | -- Compute the total size of the tensor 27 | local total = 4 28 | local dimensions = string.split(opt.dimensions, ",") 29 | for i = 1,#dimensions do 30 | dimensions[i] = tonumber(dimensions[i]) 31 | total = total * dimensions[i] 32 | end 33 | 34 | -- Print the network performance so far 35 | local function printStat(n, t, i) 36 | if t > 0 then 37 | print('Did '..i..' '..n..' in '..t..' seconds') 38 | print('\t'..math.floor(i/t)..' ops per second') 39 | print('\t'..math.floor(i*total/(t*1024*1024))..' 
MB/s') 40 | end 41 | end 42 | 43 | if opt.server > 0 then 44 | local server = ipc.server(opt.host, opt.port) 45 | local unpack = unpack or table.unpack 46 | local t0 = torch.randn(unpack(dimensions)):float() 47 | if opt.cuda then 48 | t0 = t0:cuda() 49 | end 50 | local t1 = t0:clone() 51 | local t2 = t0:clone():mul(2) 52 | local totalSend = 0 53 | local totalRecv = 0 54 | for i = 1,opt.iterations do 55 | server:clients(opt.server, function(client) 56 | local s0 = sys.clock() 57 | client:send(t0) 58 | local received = client:recv() 59 | if received ~= "received" then 60 | error('bad received string') 61 | end 62 | local s1 = sys.clock() 63 | totalSend = totalSend + (s1 - s0) 64 | client:recv(t1) 65 | local s2 = sys.clock() 66 | totalRecv = totalRecv + (s2 - s1) 67 | if opt.verify then 68 | if torch.all(torch.eq(t1, t2)) == false then 69 | error('tensors dont match.') 70 | end 71 | end 72 | end) 73 | if i % 100 == 0 then 74 | printStat('sends', totalSend, i) 75 | printStat('recvs', totalRecv, i) 76 | if opt.verbose then 77 | print(server:netStats()) 78 | end 79 | end 80 | end 81 | server:close() 82 | else 83 | local client = ipc.client(opt.host, opt.port) 84 | local unpack = unpack or table.unpack 85 | local t0 = torch.randn(unpack(dimensions)):float() 86 | if opt.cuda then 87 | t0 = t0:cuda() 88 | end 89 | for i = 1,opt.iterations do 90 | client:recv(t0) 91 | if opt.verify then 92 | t0:mul(2) 93 | end 94 | client:send("received") 95 | client:send(t0) 96 | end 97 | client:close() 98 | end 99 | -------------------------------------------------------------------------------- /examples/client-server.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # run a server 4 | th client-server.lua --server 4 & 5 | 6 | # run 4 clients 7 | th client-server.lua & 8 | th client-server.lua & 9 | th client-server.lua & 10 | th client-server.lua & 11 | 12 | # wait for them all 13 | wait 14 | -------------------------------------------------------------------------------- /examples/map.lua: -------------------------------------------------------------------------------- 1 | local ipc = require 'libipc' 2 | 3 | -- Make 3 files 4 | torch.save('f1.t7', torch.randn(1, 2)) 5 | torch.save('f2.t7', torch.randn(2, 2)) 6 | torch.save('f3.t7', torch.randn(3, 2)) 7 | 8 | -- Load 3 files in ipc 9 | local t1,t2,t3 = ipc.map(3, function(fileNames, mapid) 10 | return torch.load(fileNames[mapid]) 11 | end, {'f1.t7', 'f2.t7', 'f3.t7'}):join() 12 | 13 | -- Show what we loaded 14 | print(t1) 15 | print(t2) 16 | print(t3) 17 | 18 | -- Cleanup 19 | os.execute('rm f1.t7') 20 | os.execute('rm f2.t7') 21 | os.execute('rm f3.t7') 22 | -------------------------------------------------------------------------------- /examples/model-parallelism.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | local ipc = require 'libipc' 3 | 4 | local layerFrameSizes = {20, 10, 5, 1} 5 | local datasetSize = 10000 6 | 7 | -- Generate data 8 | local inputSize = layerFrameSizes[1] 9 | local dataset = {} 10 | for i=1,datasetSize do 11 | local timesteps = 100+torch.random(6) 12 | table.insert(dataset, torch.rand(timesteps, inputSize)) 13 | end 14 | 15 | local function makeLayer(seq, inputFrameSize, outputFrameSize) 16 | local kW = 5 17 | seq:add(nn.TemporalConvolution(inputFrameSize, outputFrameSize, kW)) 18 | seq:add(nn.Tanh()) 19 | end 20 | 21 | local function pairseq(xs) 22 | return tablex.zip(tablex.sub(xs,1,-2), tablex.sub(xs,2,-1)) 
23 | end 24 | 25 | -- Compute results for single-threaded model 26 | local single = nn.Sequential() 27 | for i,x in ipairs(pairseq(layerFrameSizes)) do 28 | local inputFrameSize = x[1] 29 | local outputFrameSize = x[2] 30 | makeLayer(single, inputFrameSize, outputFrameSize) 31 | end 32 | single:add(nn.Max(1)) 33 | print(single) 34 | 35 | sys.tic() 36 | local results = {} 37 | for i,x in ipairs(dataset) do 38 | table.insert(results, single:forward(x)[1]) 39 | end 40 | local t = sys.toc() 41 | print('single-threaded took '..t..'s') 42 | 43 | -- Compute results for multi-threaded model 44 | local stages = {} 45 | for i,x in ipairs(pairseq(layerFrameSizes)) do 46 | local inputFrameSize = x[1] 47 | local outputFrameSize = x[2] 48 | local layer = { 49 | layerType='nn.TemporalConvolution', 50 | inputFrameSize=inputFrameSize, 51 | outputFrameSize=outputFrameSize, 52 | kW=5, 53 | } 54 | table.insert(stages, {layer=layer, output=ipc.channel()}) 55 | end 56 | table.insert(stages, {layer={layerType='nn.Max',dimension=1}, output=ipc.channel()}) 57 | 58 | local input = ipc.channel() 59 | 60 | -- set up input channels in each stage 61 | for i,stage in ipairs(stages) do 62 | if i == 1 then 63 | stage.input = input 64 | else 65 | stage.input = stages[i-1].output 66 | end 67 | end 68 | 69 | -- Set stages up with workers 70 | local stageWorkers = {} 71 | for i,stage in ipairs(stages) do 72 | local worker = ipc.map(1, function(layerSpec, input, output) 73 | local nn = require 'nn' 74 | local ipc = require 'libipc' 75 | local layer 76 | if layerSpec.layerType == 'nn.Max' then 77 | layer = nn.Max(layerSpec.dimension) 78 | else 79 | layer = nn.TemporalConvolution( 80 | layerSpec.inputFrameSize, 81 | layerSpec.outputFrameSize, 82 | layerSpec.kW 83 | ) 84 | end 85 | while true do 86 | local nonblocking = false 87 | local status, x = input:read(nonblocking) 88 | if status == ipc.channel.DRAINED then 89 | output:close() 90 | break 91 | else 92 | output:write(layer:forward(x)) 93 | end 94 | end 95 | end, stage.layer, stage.input, stage.output) 96 | table.insert(stageWorkers, worker) 97 | end 98 | 99 | sys.tic() 100 | for i,x in ipairs(dataset) do 101 | input:write(x) 102 | end 103 | input:close() 104 | for i,worker in ipairs(stageWorkers) do 105 | worker:join() 106 | end 107 | local output = stages[#stages].output 108 | local multithreadedResults = {} 109 | while true do 110 | local nonblocking = false 111 | local status, x = output:read(nonblocking) 112 | if status == ipc.channel.DRAINED then 113 | break 114 | else 115 | table.insert(multithreadedResults, x[1]) 116 | end 117 | end 118 | local t = sys.toc() 119 | print('multi-threaded took '..t..'s') 120 | -------------------------------------------------------------------------------- /examples/workqueue.lua: -------------------------------------------------------------------------------- 1 | local ipc = require 'libipc' 2 | 3 | -- Make 3 files 4 | torch.save('f1.t7', torch.randn(1, 2)) 5 | torch.save('f2.t7', torch.randn(2, 2)) 6 | torch.save('f3.t7', torch.randn(3, 2)) 7 | 8 | -- Create a named workqueue 9 | local q = ipc.workqueue('my queue') 10 | 11 | -- Create 2 background workers that read from the named workqueue 12 | local workers = ipc.map(2, function() 13 | -- This function is not a closure, its a totally clean Lua environment 14 | local ipc = require 'libipc' 15 | -- Open the queue by name (the main thread already created it) 16 | local q = ipc.workqueue('my queue') 17 | repeat 18 | -- Read the next file name off the workqueue 19 | local fileName = 
q:read() 20 | if fileName then 21 | -- Load the file and write its contents back into the workqueue 22 | q:write(torch.load(fileName)) 23 | end 24 | until fileName == nil 25 | end) 26 | 27 | -- Write the file names into the workqueue 28 | q:write('f1.t7') 29 | q:write('f2.t7') 30 | q:write('f3.t7') 31 | 32 | -- Read back the 3 answers and print them 33 | print(q:read()) 34 | print(q:read()) 35 | print(q:read()) 36 | 37 | -- Write nil 2X to tell both workers to finish 38 | q:write(nil) 39 | q:write(nil) 40 | 41 | -- Wait for the workers to finish up 42 | workers:join() 43 | 44 | -- Shutdown the workqueue 45 | q:close() 46 | 47 | -- Cleanup 48 | os.execute('rm f1.t7') 49 | os.execute('rm f2.t7') 50 | os.execute('rm f3.t7') 51 | -------------------------------------------------------------------------------- /ipc-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "ipc" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/twitter/torch-ipc.git", 6 | } 7 | 8 | description = { 9 | summary = "A set of primitives for ipc computation in Torch", 10 | homepage = "-", 11 | license = "MIT" 12 | } 13 | 14 | dependencies = { 15 | "torch >= 7.0", 16 | "regress", 17 | } 18 | 19 | build = { 20 | type = "command", 21 | build_command = [[ 22 | cmake -E make_directory build; 23 | cd build; 24 | cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" -DCMAKE_C_FLAGS=-fPIC -DCMAKE_CXX_FLAGS=-fPIC; 25 | $(MAKE) 26 | ]], 27 | install_command = "cd build && $(MAKE) install" 28 | } 29 | -------------------------------------------------------------------------------- /lua/BackgroundTask.lua: -------------------------------------------------------------------------------- 1 | local BackgroundTaskPool = require 'ipc.BackgroundTaskPool' 2 | 3 | local function BackgroundTask(func, ...) 4 | local pool = BackgroundTaskPool(1, { closeOnLastTask = true }) 5 | pool.addTask(func, ...) 6 | local function getResult() 7 | return pool.getResult(1) 8 | end 9 | return { 10 | isDone = pool.isDone, 11 | getResult = getResult, 12 | } 13 | end 14 | 15 | return BackgroundTask 16 | -------------------------------------------------------------------------------- /lua/BackgroundTaskPool.lua: -------------------------------------------------------------------------------- 1 | local ipc = require 'libipc' 2 | 3 | local function BackgroundTaskPool(poolSize, opt) 4 | 5 | -- Options 6 | opt = opt or { } 7 | local closeOnLastTask = opt.closeOnLastTask or false 8 | 9 | -- Keep track of some stuff 10 | local numTasks = 0 11 | local numResults = 0 12 | local successes = { } 13 | local failures = { } 14 | 15 | -- Create a shared queue with a random name 16 | local name = os.tmpname() 17 | local q = ipc.workqueue(name) 18 | 19 | -- Create a pool of workers 20 | local m = ipc.map(poolSize, function(name) 21 | local ipc = require 'libipc' 22 | local q = ipc.workqueue(name) 23 | while true do 24 | local task = q:read() 25 | if task then 26 | local unpack = unpack or table.unpack 27 | local ok,ret = pcall(function() return {task.func(unpack(task.args))} end) 28 | if ok then 29 | q:write({ id = task.id, success = ret }) 30 | else 31 | q:write({ id = task.id, failure = ret }) 32 | end 33 | else 34 | break 35 | end 36 | end 37 | end, name) 38 | 39 | -- Is this task already complete? 
40 | local function hasResult(id) 41 | return id and (successes[id] or failures[id]) 42 | end 43 | 44 | -- Check if one or all the tasks are finished 45 | local function isDone(id, shouldBlock) 46 | -- Get all the completed tasks off the queue 47 | while numResults < numTasks and not hasResult(id) do 48 | local result = q:read(shouldBlock ~= true) 49 | if result then 50 | if result.failure then 51 | failures[result.id] = result.failure 52 | else 53 | successes[result.id] = result.success 54 | end 55 | numResults = numResults + 1 56 | else 57 | -- No more results pending 58 | break 59 | end 60 | end 61 | -- Did we see results for one or all? 62 | return numResults == numTasks or hasResult(id) 63 | end 64 | 65 | -- Add a task to the queue 66 | local function addTask(func, ...) 67 | assert(type(func) == 'function') 68 | local args = {...} 69 | -- Keep the queue moving by reading some results 70 | isDone() 71 | -- Add the new task 72 | numTasks = numTasks + 1 73 | q:write({ id = numTasks, func = func, args = args }) 74 | -- Return the task's id 75 | return numTasks 76 | end 77 | 78 | -- Is there a task in flight? 79 | local function hasTask() 80 | return numResults < numTasks 81 | end 82 | 83 | -- Get the result for one of the tasks 84 | local function getResult(id) 85 | -- Make sure the task is done 86 | isDone(id, true) 87 | -- If it is the last result then cleanup 88 | if closeOnLastTask and numResults == numTasks and m then 89 | -- Shutdown everything down 90 | for _ = 1,poolSize do 91 | q:write(nil) 92 | end 93 | m:join() 94 | m = nil 95 | end 96 | -- Return the task result 97 | id = id or numResults 98 | if successes[id] then 99 | local ret = successes[id] 100 | successes[id] = nil 101 | return (unpack or table.unpack)(ret) 102 | elseif failures[id] then 103 | local ret = failures[id] 104 | failures[id] = nil 105 | error(ret) 106 | end 107 | end 108 | 109 | return { 110 | addTask = addTask, 111 | hasTask = hasTask, 112 | isDone = isDone, 113 | getResult = getResult, 114 | } 115 | end 116 | 117 | return BackgroundTaskPool 118 | -------------------------------------------------------------------------------- /lua/DiscoveredTree.lua: -------------------------------------------------------------------------------- 1 | local ipc = require 'libipc' 2 | local Tree = require 'ipc.Tree' 3 | local NullTree = require 'ipc.NullTree' 4 | 5 | local function DiscoveredTree(nodeIndex, numNodes, nodeHost, nodePort, publish, query) 6 | if numNodes == 1 then 7 | return NullTree() 8 | end 9 | if nodeIndex == 1 then 10 | local server,nodePort = ipc.server(nodeHost, nodePort) 11 | publish(nodeHost, nodePort) 12 | return Tree(nodeIndex, numNodes, 2, server, nil, nodeHost, nodePort) 13 | else 14 | local rootHost,rootPort = query() 15 | local client = ipc.client(rootHost, rootPort) 16 | return Tree(nodeIndex, numNodes, 2, nil, client, nodeHost, nodePort) 17 | end 18 | end 19 | 20 | return DiscoveredTree 21 | -------------------------------------------------------------------------------- /lua/LocalhostTree.lua: -------------------------------------------------------------------------------- 1 | local DiscoveredTree = require 'ipc.DiscoveredTree' 2 | local ipc = require 'libipc' 3 | 4 | local function LocalhostTree(nodeIndex, numNodes, ppid) 5 | local fn = '/tmp/'..(ppid or ipc.getppid())..'.localhost' 6 | local function publish(host, port) 7 | local f = io.open(fn, 'w') 8 | f:write(host..':'..port) 9 | f:close() 10 | end 11 | local function query() 12 | while true do 13 | local f = io.open(fn, 'r') 14 | if f 
then 15 | local s = f:read('*all') 16 | if type(s) == 'string' then 17 | local p = s:split(':') 18 | if type(p) == 'table' and #p == 2 then 19 | return p[1], tonumber(p[2]) 20 | end 21 | end 22 | f:close() 23 | end 24 | end 25 | end 26 | return DiscoveredTree(nodeIndex, numNodes, '127.0.0.1', nil, publish, query) 27 | end 28 | 29 | return LocalhostTree 30 | -------------------------------------------------------------------------------- /lua/NullTree.lua: -------------------------------------------------------------------------------- 1 | local walkTable = require 'ipc.utils'.walkTable 2 | 3 | local function NullTree() 4 | return { 5 | nodeIndex = 1, 6 | numNodes = 1, 7 | walkTable = walkTable, 8 | allReduce = function(value) return value, 1 end, 9 | scatter = function(value) return value end, 10 | netStats = function() end, 11 | } 12 | end 13 | 14 | return NullTree 15 | -------------------------------------------------------------------------------- /lua/SlurmTree.lua: -------------------------------------------------------------------------------- 1 | local ipc = require 'libipc' 2 | local Tree = require 'ipc.Tree' 3 | local NullTree = require 'ipc.NullTree' 4 | 5 | local function SlurmTree(fn, tasksPerGpu) 6 | tasksPerGpu = tasksPerGpu or 1 7 | local slurmProcId = tonumber(os.getenv("SLURM_PROCID")) 8 | local numNodes = tonumber(os.getenv("SLURM_NTASKS")) 9 | local slurmNNodes = tonumber(os.getenv("SLURM_JOB_NUM_NODES")) 10 | local tasksPerHost = math.ceil(numNodes / slurmNNodes) 11 | local nodeIndex = slurmProcId + 1 12 | 13 | fn = fn or os.getenv("HOME")..'/.torch' 14 | fpath = fn..'/slurm.'..os.getenv("SLURM_JOBID")..'.server' 15 | local function publish(host, port) 16 | os.execute('mkdir -p '..fn) 17 | local f = io.open(fpath, 'w') 18 | f:write(host..':'..port) 19 | f:close() 20 | end 21 | local function query() 22 | while true do 23 | local f = io.open(fpath, 'r') 24 | if f then 25 | local s = f:read('*all') 26 | if type(s) == 'string' then 27 | local p = s:split(':') 28 | if type(p) == 'table' and #p == 2 then 29 | return p[1], tonumber(p[2]) 30 | end 31 | end 32 | f:close() 33 | end 34 | end 35 | end 36 | local function rcsvAllPairs(base, numNodes, index, depth, linkFunc) 37 | local function link(a, b, d) 38 | if a <= numNodes and b <= numNodes then 39 | linkFunc(a, b, d) 40 | end 41 | end 42 | if depth == 0 then 43 | local skip = math.pow(base, depth + 1) 44 | for j = index + 2, index + skip do 45 | link(index + 1, j, depth) 46 | end 47 | else 48 | local skip = math.pow(base, depth) 49 | link(index + 1, index + skip + 1, depth) 50 | for c = 0, base - 1 do 51 | rcsvAllPairs(base, numNodes, index + (c * skip), depth - 1, linkFunc) 52 | end 53 | end 54 | end 55 | -- cluster the tasks on each host then build tree between hosts 56 | local function buildTree(base, numNodes, index, depth, linkFunc) 57 | local numHosts = math.ceil(numNodes / tasksPerHost) 58 | --build local trees between tasks on each host 59 | rcsvAllPairs(base, tasksPerHost, index, depth, function(to, from) 60 | for hostIdx = 0, numHosts-1 do 61 | local startIdx = hostIdx * tasksPerHost 62 | linkFunc(startIdx + to, startIdx + from) 63 | end 64 | end) 65 | --build tree between the primary tasks on each host 66 | rcsvAllPairs(base, numHosts, index, depth, function(to, from) 67 | linkFunc((to-1) * tasksPerHost + 1, (from-1) * tasksPerHost + 1) 68 | end) 69 | end 70 | 71 | local tree = nil 72 | if numNodes == 1 then 73 | tree = NullTree() 74 | else 75 | local nodeHost = sys.execute('/bin/hostname') 76 | local nodePort = 
nil 77 | if nodeIndex == 1 then 78 | local server,nodePort = ipc.server(nodeHost, nodePort) 79 | publish(nodeHost, nodePort) 80 | tree = Tree(nodeIndex, numNodes, 2, server, nil, nodeHost, nodePort, buildTree) 81 | else 82 | local rootHost,rootPort = query() 83 | local client = ipc.client(rootHost, rootPort) 84 | tree = Tree(nodeIndex, numNodes, 2, nil, client, nodeHost, nodePort, buildTree) 85 | end 86 | end 87 | tree['gpu'] = math.floor((slurmProcId % tasksPerHost) / tasksPerGpu) + 1 88 | return tree 89 | end 90 | 91 | return SlurmTree 92 | -------------------------------------------------------------------------------- /lua/StaticTree.lua: -------------------------------------------------------------------------------- 1 | local ipc = require 'libipc' 2 | local Tree = require 'ipc.Tree' 3 | local NullTree = require 'ipc.NullTree' 4 | 5 | local function StaticTree(nodeIndex, numNodes, nodeHost, nodePort, rootHost, rootPort) 6 | if numNodes == 1 then 7 | return NullTree() 8 | end 9 | if nodeIndex == 1 then 10 | local server = ipc.server(nodeHost, nodePort) 11 | return Tree(nodeIndex, numNodes, 2, server, nil, nodeHost, nodePort) 12 | else 13 | local client = ipc.client(rootHost, rootPort) 14 | return Tree(nodeIndex, numNodes, 2, nil, client, nodeHost, nodePort) 15 | end 16 | end 17 | 18 | return StaticTree 19 | -------------------------------------------------------------------------------- /lua/Tree.lua: -------------------------------------------------------------------------------- 1 | local ipc = require 'libipc' 2 | local walkTable = require 'ipc.utils'.walkTable 3 | 4 | local function rcsvAllPairs(base, numNodes, index, depth, linkFunc) 5 | local function link(a, b, d) 6 | if a <= numNodes and b <= numNodes then 7 | linkFunc(a, b, d) 8 | end 9 | end 10 | if depth == 0 then 11 | local skip = math.pow(base, depth + 1) 12 | for j = index + 2, index + skip do 13 | link(index + 1, j, depth) 14 | end 15 | else 16 | local skip = math.pow(base, depth) 17 | link(index + 1, index + skip + 1, depth) 18 | for c = 0, base - 1 do 19 | rcsvAllPairs(base, numNodes, index + (c * skip), depth - 1, linkFunc) 20 | end 21 | end 22 | end 23 | 24 | local function Tree(nodeIndex, numNodes, base, server, client, host, port, buildTree) 25 | buildTree = buildTree or rcsvAllPairs 26 | 27 | local maxDepth = math.ceil(math.log(numNodes) / math.log(base)) 28 | 29 | local function initialServer() 30 | -- Get every node's address and nodeIndex 31 | local addresses = { } 32 | addresses[nodeIndex] = { 33 | nodeIndex = nodeIndex, 34 | host = host, 35 | port = port, 36 | } 37 | server:clients(numNodes - 1, function(client) 38 | client:send({ q = "address?" 
}) 39 | local msg = client:recv() 40 | assert(msg.q == "address") 41 | local clientNodeIndex = msg.nodeIndex or #addresses + 1 42 | addresses[clientNodeIndex] = { 43 | nodeIndex = clientNodeIndex, 44 | host = msg.host, 45 | port = msg.port 46 | } 47 | client:send({ 48 | q = "clientIndex", 49 | clientIndex = clientNodeIndex 50 | }) 51 | end) 52 | -- Build a tree of connections to establish 53 | local tree = { } 54 | buildTree(base, numNodes, 0, maxDepth - 1, function(to, from, depth) 55 | tree[from] = tree[from] or { } 56 | tree[from].connect = addresses[to] 57 | tree[to] = tree[to] or { } 58 | tree[to].listen = (tree[to].listen or 0) + 1 59 | end) 60 | -- Broadcast the tree of connections all nodes 61 | server:broadcast(tree) 62 | -- Order the nodes that stay connected to this server 63 | server:clients(function(client) 64 | -- Expect some clients to disconnect 65 | local ok,msg = pcall(function() 66 | client:send("order?") 67 | return client:recv() 68 | end) 69 | if ok then 70 | client:id(msg.order) 71 | else 72 | -- Node has a new parent, so close it 73 | client:close() 74 | end 75 | end) 76 | -- Make sure the entire tree is ready for action 77 | server:clients(function(client) 78 | client:send("start?") 79 | assert(client:recv() == "start") 80 | end) 81 | end 82 | 83 | local function initialClient() 84 | -- Open a new server, we may end up a parent (reuse the same server upvalue) 85 | server, port = ipc.server(host, tonumber(port)) 86 | -- Register our address and nodeIndex 87 | local msg = client:recv() 88 | assert(msg.q == "address?") 89 | client:send({ 90 | q = "address", 91 | nodeIndex = nodeIndex, 92 | host = host, 93 | port = port 94 | }) 95 | msg = client:recv() 96 | assert(msg.q == "clientIndex") 97 | nodeIndex = msg.clientIndex 98 | -- Get the tree of connections 99 | local tree = client:recv() 100 | local node = tree[nodeIndex] 101 | if node.listen and node.listen > 0 then 102 | -- If we are a parent, connect the children and order them 103 | server:clients(node.listen, function(client) 104 | client:send("order?") 105 | local msg = client:recv() 106 | client:id(msg.order) 107 | end) 108 | else 109 | -- Just a leaf 110 | server:close() 111 | server = nil 112 | end 113 | if node.connect then 114 | if node.connect.nodeIndex == 1 then 115 | -- Already connnected to our parent 116 | assert(client:recv() == "order?") 117 | client:send({ 118 | order = nodeIndex, 119 | }) 120 | else 121 | -- A new parent is required (reuse the same client upvalue) 122 | client:close() 123 | client = ipc.client(node.connect.host, node.connect.port) 124 | assert(client:recv() == "order?") 125 | client:send({ 126 | order = nodeIndex, 127 | }) 128 | end 129 | end 130 | -- Wait for the start message 131 | assert(client:recv() == "start?") 132 | if server then 133 | -- Get our subtree ready to start 134 | server:clients(function(client) 135 | client:send("start?") 136 | assert(client:recv() == "start") 137 | end) 138 | end 139 | -- This subtree is ready 140 | client:send("start") 141 | end 142 | 143 | -- Establish the tree structure 144 | if server then 145 | initialServer() 146 | else 147 | initialClient() 148 | end 149 | 150 | -- We need temp space to receive tensors 151 | local tempValue 152 | local function getTempValue(value) 153 | if torch.isTensor(value) then 154 | if tempValue then 155 | tempValue = tempValue:typeAs(value):resizeAs(value) 156 | else 157 | tempValue = value:clone() 158 | end 159 | return tempValue 160 | end 161 | end 162 | 163 | -- Not the prettiest but it conserves memory when 164 | 
-- ending the allReduce on uneven # of steps per node 165 | local lastValue 166 | 167 | local function allReduceInner(value, reduce, zero) 168 | -- Handle uneven endings 169 | if zero then 170 | -- Restore the last value if a zero function is supplied 171 | value = (value[1] ~= nil and value) or lastValue 172 | local i = 0 173 | walkTable(value, function(valuei) 174 | i = i + 1 175 | return zero(valuei, i) 176 | end) 177 | else 178 | -- Save the last value we saw (for uneven ending) 179 | lastValue = value 180 | end 181 | -- Keep track of the number of done nodes 182 | local numDone = zero and 1 or 0 183 | -- Reduce the value up to the root 184 | if server then 185 | -- Recv from the shortest branch first 186 | server:clients(function(client) 187 | walkTable(value, function(valuei) 188 | return reduce(valuei, client:recv(getTempValue(valuei))) 189 | end) 190 | numDone = numDone + client:recv() 191 | end) 192 | end 193 | if client then 194 | walkTable(value, function(valuei) 195 | client:send(valuei) 196 | end) 197 | client:send(numDone) 198 | end 199 | -- Map the root value back down the tree 200 | if client then 201 | walkTable(value, function(valuei) 202 | return client:recv(valuei) 203 | end) 204 | numDone = client:recv() 205 | end 206 | if server then 207 | -- Send the longest branch first 208 | server:clients(function(client) 209 | walkTable(value, function(valuei) 210 | client:send(valuei) 211 | end) 212 | client:send(numDone) 213 | end, 1) -- Magic bit to invert the client order (longest branch first) 214 | end 215 | if zero and numDone < numNodes then 216 | -- If we are done, but not everyone else is, then do it again 217 | return allReduceInner(value, reduce, zero) 218 | else 219 | -- Return the final value and how many nodes contributed 220 | return value, numNodes - numDone 221 | end 222 | end 223 | 224 | -- Classic MPI style all reduce (reduce where all nodes get the final value) 225 | local function allReduce(value, reduce, zero) 226 | -- Support tables of values (as multiple sequential transfers) 227 | local isTable = type(value) == 'table' 228 | value = (isTable and value) or { value } 229 | local finalValue, numNodes = allReduceInner(value, reduce, zero) 230 | return (isTable and finalValue) or finalValue[1], numNodes 231 | end 232 | 233 | -- Classic MPI style scatter (root value to all nodes) 234 | local function scatter(value) 235 | -- Support tables of tensors 236 | local isTable = type(value) == 'table' 237 | value = (isTable and value) or { value } 238 | -- Map the root value back down the tree 239 | if client then 240 | walkTable(value, function(valuei) 241 | return client:recv(valuei) 242 | end) 243 | end 244 | if server then 245 | -- Send the longest branch first 246 | server:clients(function(client) 247 | walkTable(value, function(valuei) 248 | client:send(valuei) 249 | end) 250 | end, 1) -- Magic bit to invert the client order (longest branch first) 251 | end 252 | return (isTable and value) or value[1] 253 | end 254 | 255 | -- Handy debug info on network performance 256 | local function netStats() 257 | if server then 258 | print(server:netStats()) 259 | end 260 | if client then 261 | print(client:netStats()) 262 | end 263 | end 264 | 265 | return { 266 | nodeIndex = nodeIndex, 267 | numNodes = numNodes, 268 | walkTable = walkTable, 269 | allReduce = allReduce, 270 | scatter = scatter, 271 | netStats = netStats, 272 | } 273 | end 274 | 275 | return Tree 276 | -------------------------------------------------------------------------------- /lua/utils.lua: 
-------------------------------------------------------------------------------- 1 | 2 | -- Walk a table in a deterministic order 3 | local function walkTable(t, f) 4 | local kk = { } 5 | for k,_ in pairs(t) do 6 | table.insert(kk, k) 7 | end 8 | table.sort(kk) 9 | for _,k in ipairs(kk) do 10 | local tk = t[k] 11 | if type(tk) == 'table' then 12 | walkTable(tk, f) 13 | else 14 | local tk1 = f(tk) 15 | if tk1 ~= nil then 16 | t[k] = tk1 17 | end 18 | end 19 | end 20 | end 21 | 22 | return { 23 | walkTable = walkTable 24 | } 25 | -------------------------------------------------------------------------------- /src/channel.c: -------------------------------------------------------------------------------- 1 | #include "TH.h" 2 | #include "luaT.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "ringbuffer.h" 9 | #include "serialize.h" 10 | #include "error.h" 11 | #include "channel.h" 12 | 13 | #define DEFAULT_CHANNEL_SIZE (16*1024) 14 | 15 | #define CHANNEL_VERBOSE (0) 16 | 17 | typedef struct channel_t { 18 | struct ringbuffer_t* rb; 19 | pthread_mutex_t mutex; 20 | pthread_cond_t read_avail_cond; 21 | int closed; 22 | int drained; 23 | uint32_t num_items; 24 | int refcount; 25 | size_t size_increment; 26 | } channel_t; 27 | 28 | static void channel_init_queue(channel_t *channel, size_t size) { 29 | // init queue mutex 30 | pthread_mutexattr_t mutex_attr; 31 | pthread_mutexattr_init(&mutex_attr); 32 | pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE); 33 | pthread_mutex_init(&channel->mutex, &mutex_attr); 34 | 35 | // init condition variables 36 | pthread_cond_init(&channel->read_avail_cond, NULL); 37 | 38 | // init ring buffer 39 | channel->rb = ringbuffer_create(size); 40 | } 41 | 42 | int channel_create(lua_State *L) { 43 | channel_t *channel = calloc(1, sizeof(channel_t)); 44 | channel->refcount = 1; 45 | channel->closed = 0; 46 | channel->drained = 0; 47 | channel_t **ud = (channel_t **)lua_newuserdata(L, sizeof(channel_t*)); 48 | channel_init_queue(channel, DEFAULT_CHANNEL_SIZE); 49 | channel->size_increment = DEFAULT_CHANNEL_SIZE; 50 | *ud = channel; 51 | luaL_getmetatable(L, "ipc.channel"); 52 | lua_setmetatable(L, -2); 53 | return 1; 54 | } 55 | 56 | int channel_close(lua_State *L) { 57 | channel_t *channel = *(channel_t **)lua_touserdata(L, 1); 58 | if (!channel) return LUA_HANDLE_ERROR_STR(L, "invalid channel"); 59 | pthread_mutex_lock(&channel->mutex); 60 | if (!channel->closed) { 61 | channel->closed = 1; 62 | if (channel->num_items == 0) { 63 | channel->drained = 1; 64 | } 65 | pthread_cond_broadcast(&channel->read_avail_cond); 66 | } 67 | pthread_mutex_unlock(&channel->mutex); 68 | return 0; 69 | } 70 | 71 | int channel_closed(lua_State *L) { 72 | channel_t *channel = *(channel_t **)lua_touserdata(L, 1); 73 | if (!channel) return LUA_HANDLE_ERROR_STR(L, "invalid channel"); 74 | pthread_mutex_lock(&channel->mutex); 75 | lua_pushboolean(L, channel->closed); 76 | pthread_mutex_unlock(&channel->mutex); 77 | return 1; 78 | } 79 | 80 | int channel_drained(lua_State *L) { 81 | channel_t *channel = *(channel_t **)lua_touserdata(L, 1); 82 | if (!channel) return LUA_HANDLE_ERROR_STR(L, "invalid channel"); 83 | pthread_mutex_lock(&channel->mutex); 84 | lua_pushboolean(L, channel->drained); 85 | pthread_mutex_unlock(&channel->mutex); 86 | return 1; 87 | } 88 | 89 | int channel_read(lua_State *L) { 90 | channel_t *channel = *(channel_t **)lua_touserdata(L, 1); 91 | if (!channel) return LUA_HANDLE_ERROR_STR(L, "invalid channel"); 92 | int 
doNotBlock = luaT_optboolean(L, 2, 0); 93 | pthread_mutex_lock(&channel->mutex); 94 | while (1) { 95 | if (channel->num_items) { 96 | if (channel->closed && channel->num_items == 1) { 97 | channel->drained = 1; 98 | pthread_cond_broadcast(&channel->read_avail_cond); 99 | } 100 | if (channel->closed) { 101 | lua_pushinteger(L, STATUS_CLOSED); 102 | } else { 103 | lua_pushinteger(L, STATUS_OPEN); 104 | } 105 | int ret = rb_load(L, channel->rb); 106 | channel->num_items--; 107 | pthread_mutex_unlock(&channel->mutex); 108 | if (ret < 0) return LUA_HANDLE_ERROR(L, ret); 109 | return ret + 1; 110 | } else if (channel->drained) { 111 | pthread_mutex_unlock(&channel->mutex); 112 | lua_pushinteger(L, STATUS_DRAINED); 113 | return 1; 114 | } else if (doNotBlock) { 115 | break; 116 | } else { 117 | pthread_cond_wait(&channel->read_avail_cond, &channel->mutex); 118 | } 119 | } 120 | if (channel->drained) { 121 | lua_pushinteger(L, STATUS_DRAINED); 122 | } else if (channel->closed) { 123 | lua_pushinteger(L, STATUS_CLOSED); 124 | } else { 125 | lua_pushinteger(L, STATUS_OPEN); 126 | } 127 | pthread_mutex_unlock(&channel->mutex); 128 | return 1; 129 | } 130 | 131 | // TODO: Blocking writes should also be supported. This allows for 132 | // backpressure to work. The current implementation grows the 133 | // underlying ringbuffer if it is full. 134 | int channel_write(lua_State *L) { 135 | channel_t *channel = *(channel_t **)lua_touserdata(L, 1); 136 | if (!channel) return LUA_HANDLE_ERROR_STR(L, "invalid channel"); 137 | pthread_mutex_lock(&channel->mutex); 138 | if (channel->drained) { 139 | lua_pushinteger(L, STATUS_DRAINED); 140 | pthread_mutex_unlock(&channel->mutex); 141 | return 1; 142 | } else if (channel->closed) { 143 | lua_pushinteger(L, STATUS_CLOSED); 144 | pthread_mutex_unlock(&channel->mutex); 145 | return 1; 146 | } 147 | int upval = 0; 148 | int index = 2; 149 | int top = lua_gettop(L); 150 | while (index <= top) { 151 | ringbuffer_push_write_pos(channel->rb); 152 | int ret = rb_save(L, index, channel->rb, 0, upval); 153 | if (ret == -ENOMEM) { 154 | ringbuffer_pop_write_pos(channel->rb); 155 | ringbuffer_grow_by(channel->rb, channel->size_increment); 156 | #if CHANNEL_VERBOSE 157 | fprintf(stderr, "INFO: ipc.channel grew to %zu bytes\n", channel->rb->cb); 158 | #endif 159 | } else if (ret) { 160 | ringbuffer_pop_write_pos(channel->rb); 161 | pthread_mutex_unlock(&channel->mutex); 162 | return LUA_HANDLE_ERROR(L, -ret); 163 | } else { 164 | index++; 165 | channel->num_items++; 166 | } 167 | } 168 | pthread_cond_signal(&channel->read_avail_cond); 169 | lua_pushinteger(L, STATUS_OPEN); 170 | pthread_mutex_unlock(&channel->mutex); 171 | return 1; 172 | } 173 | 174 | int channel_num_items(lua_State *L) { 175 | channel_t *channel = *(channel_t **)lua_touserdata(L, 1); 176 | if (!channel) return LUA_HANDLE_ERROR_STR(L, "invalid channel"); 177 | pthread_mutex_lock(&channel->mutex); 178 | lua_pushinteger(L, channel->num_items); 179 | pthread_mutex_unlock(&channel->mutex); 180 | return 1; 181 | } 182 | 183 | int channel_gc(lua_State *L) { 184 | channel_t **ud = (channel_t **)lua_touserdata(L, 1); 185 | channel_t *channel = *ud; 186 | if (!channel) return LUA_HANDLE_ERROR_STR(L, "invalid channel"); 187 | pthread_mutex_lock(&channel->mutex); 188 | if (THAtomicDecrementRef(&channel->refcount)) { 189 | pthread_cond_destroy(&channel->read_avail_cond); 190 | ringbuffer_destroy(channel->rb); 191 | pthread_mutex_unlock(&channel->mutex); 192 | pthread_mutex_destroy(&channel->mutex); 193 | free(channel); 
194 | *ud = NULL; 195 | } else { 196 | pthread_mutex_unlock(&channel->mutex); 197 | } 198 | return 0; 199 | } 200 | 201 | int channel_retain(lua_State *L) { 202 | channel_t *channel = *(channel_t **)lua_touserdata(L, 1); 203 | if (!channel) return LUA_HANDLE_ERROR_STR(L, "invalid channel"); 204 | pthread_mutex_lock(&channel->mutex); 205 | THAtomicIncrementRef(&channel->refcount); 206 | pthread_mutex_unlock(&channel->mutex); 207 | return 0; 208 | } 209 | 210 | int channel_metatablename(lua_State *L) { 211 | lua_pushstring(L, "ipc.channel"); 212 | return 1; 213 | } 214 | -------------------------------------------------------------------------------- /src/channel.h: -------------------------------------------------------------------------------- 1 | #ifndef _CHANNEL_H_ 2 | #define _CHANNEL_H_ 3 | 4 | #include "luaT.h" 5 | 6 | #define STATUS_OPEN 0 7 | #define STATUS_CLOSED 1 8 | #define STATUS_DRAINED 2 9 | 10 | int channel_create(lua_State *L); 11 | int channel_close(lua_State *L); 12 | int channel_closed(lua_State *L); 13 | int channel_drained(lua_State *L); 14 | int channel_read(lua_State *L); 15 | int channel_write(lua_State *L); 16 | int channel_num_items(lua_State *L); 17 | int channel_gc(lua_State *L); 18 | int channel_retain(lua_State *L); 19 | int channel_metatablename(lua_State *L); 20 | 21 | #endif 22 | -------------------------------------------------------------------------------- /src/cliser.h: -------------------------------------------------------------------------------- 1 | #ifndef _CLISER_H_ 2 | #define _CLISER_H_ 3 | 4 | #include "luaT.h" 5 | 6 | int cliser_server(lua_State *L); 7 | int cliser_server_close(lua_State *L); 8 | int cliser_server_clients(lua_State *L); 9 | int cliser_server_tag(lua_State *L); 10 | int cliser_server_id(lua_State *L); 11 | int cliser_server_client_close(lua_State *L); 12 | int cliser_server_client_address(lua_State *L); 13 | int cliser_server_broadcast(lua_State *L); 14 | int cliser_server_recv_any(lua_State *L); 15 | int cliser_server_send(lua_State *L); 16 | int cliser_server_recv(lua_State *L); 17 | int cliser_server_net_stats(lua_State *L); 18 | 19 | int cliser_client(lua_State *L); 20 | int cliser_client_close(lua_State *L); 21 | int cliser_client_send(lua_State *L); 22 | int cliser_client_recv(lua_State *L); 23 | int cliser_client_recv_async(lua_State *L); 24 | int cliser_client_retain(lua_State *L); 25 | int cliser_client_metatablename(lua_State *L); 26 | int cliser_client_net_stats(lua_State *L); 27 | 28 | void Lcliser_CharInit(lua_State *L); 29 | void Lcliser_ByteInit(lua_State *L); 30 | void Lcliser_ShortInit(lua_State *L); 31 | void Lcliser_IntInit(lua_State *L); 32 | void Lcliser_LongInit(lua_State *L); 33 | void Lcliser_FloatInit(lua_State *L); 34 | void Lcliser_DoubleInit(lua_State *L); 35 | #ifdef USE_CUDA 36 | void Lcliser_CudaInit(lua_State *L); 37 | #endif 38 | 39 | #endif 40 | -------------------------------------------------------------------------------- /src/error.h: -------------------------------------------------------------------------------- 1 | #ifndef _ERROR_H_ 2 | #define _ERROR_H_ 3 | 4 | #include "luaT.h" 5 | #include 6 | 7 | static inline int _lua_error(lua_State *L, int ret, const char* file, int line) { 8 | int pos_ret = ret >= 0 ? 
ret : -ret; 9 | return luaL_error(L, "ERROR: (%s, %d): (%d, %s)\n", file, line, pos_ret, strerror(pos_ret)); 10 | } 11 | 12 | static inline int _lua_error_str(lua_State *L, const char *str, const char* file, int line) { 13 | return luaL_error(L, "ERROR: (%s, %d): (%s)\n", file, line, str); 14 | } 15 | 16 | #define LUA_HANDLE_ERROR(L, ret) _lua_error(L, ret, __FILE__, __LINE__) 17 | #define LUA_HANDLE_ERROR_STR(L, str) _lua_error_str(L, str, __FILE__, __LINE__) 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /src/flock.c: -------------------------------------------------------------------------------- 1 | #include "flock.h" 2 | #include 3 | #include 4 | #include 5 | #include "error.h" 6 | 7 | #ifndef O_CLOEXEC 8 | #define O_CLOEXEC (0) 9 | #endif 10 | 11 | int flock_open(lua_State *L) { 12 | const char *file_name = lua_tostring(L, 1); 13 | int no_block = luaT_optboolean(L, 2, 0); 14 | int flags = O_CLOEXEC | O_RDWR; 15 | if (!no_block) { 16 | flags |= O_CREAT; 17 | } 18 | int fd = open(file_name, flags, S_IRUSR | S_IWUSR); 19 | if (fd < 0) { 20 | if (errno == ENOENT || errno == EACCES) return 0; 21 | return LUA_HANDLE_ERROR(L, errno); 22 | } 23 | flags = LOCK_EX; 24 | if (no_block) { 25 | flags |= LOCK_NB; 26 | } 27 | int ret = flock(fd, flags); 28 | if (ret < 0) { 29 | close(fd); 30 | if ((flags & LOCK_NB) && (errno == EWOULDBLOCK)) return 0; 31 | return LUA_HANDLE_ERROR(L, errno); 32 | } 33 | int *handle = lua_newuserdata(L, sizeof(int)); 34 | *handle = fd; 35 | luaL_getmetatable(L, "ipc.flock"); 36 | lua_setmetatable(L, -2); 37 | return 1; 38 | } 39 | 40 | int flock_close(lua_State *L) { 41 | int *handle = lua_touserdata(L, 1); 42 | int fd = *handle; 43 | if (fd) { 44 | int ret = flock(fd, LOCK_UN); 45 | close(fd); 46 | *handle = 0; 47 | if (ret < 0) return LUA_HANDLE_ERROR(L, errno); 48 | } 49 | return 0; 50 | } 51 | -------------------------------------------------------------------------------- /src/flock.h: -------------------------------------------------------------------------------- 1 | #ifndef _FLOCK_H_ 2 | #define _FLOCK_H_ 3 | 4 | #include "luaT.h" 5 | 6 | int flock_open(lua_State *L); 7 | int flock_close(lua_State *L); 8 | int flock_read(lua_State *L); 9 | int flock_write(lua_State *L); 10 | int flock_truncate(lua_State *L); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /src/ipc.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "luaT.h" 9 | #include "workqueue.h" 10 | #include "cliser.h" 11 | #include "map.h" 12 | #include "error.h" 13 | #include "spawn.h" 14 | #include "flock.h" 15 | #include "mutex.h" 16 | #include "sharedtable.h" 17 | #include "marshal.h" 18 | #include "channel.h" 19 | 20 | int ipc_getpid(lua_State *L) { 21 | pid_t pid = getpid(); 22 | lua_pushinteger(L, pid); 23 | return 1; 24 | } 25 | 26 | int ipc_getppid(lua_State *L) { 27 | pid_t pid = getppid(); 28 | lua_pushinteger(L, pid); 29 | return 1; 30 | } 31 | 32 | int ipc_gettid(lua_State *L) { 33 | pthread_t tid = pthread_self(); 34 | lua_pushinteger(L, (intptr_t)tid); 35 | return 1; 36 | } 37 | 38 | int ipc_fork(lua_State *L) { 39 | pid_t pid = fork(); 40 | lua_pushinteger(L, pid); 41 | return 1; 42 | } 43 | 44 | int ipc_waitpid(lua_State *L) { 45 | int status; 46 | pid_t pid = lua_tointeger(L, 1); 47 | do { 48 | int ret = waitpid(pid, &status, WUNTRACED | 
WCONTINUED); 49 | if (ret < 0) { 50 | return LUA_HANDLE_ERROR(L, errno); 51 | } 52 | if (WIFEXITED(status)) { 53 | lua_pushinteger(L, WEXITSTATUS(status)); 54 | return 1; 55 | } 56 | } while (!WIFEXITED(status) && !WIFSIGNALED(status)); 57 | return 0; 58 | } 59 | 60 | int ipc_is_osx(lua_State *L) { 61 | #ifdef __APPLE__ 62 | lua_pushboolean(L, 1); 63 | return 1; 64 | #else 65 | (void)L; 66 | return 0; 67 | #endif 68 | } 69 | 70 | int ipc_is_devel(lua_State *L) { 71 | #ifdef __APPLE__ 72 | int is_devel = 1; 73 | #else 74 | char *devel_mode = getenv("CKOIA_DEVEL_MODE"); 75 | int is_devel = devel_mode && devel_mode[0] == '1'; 76 | #endif 77 | lua_pushboolean(L, is_devel); 78 | return 1; 79 | } 80 | 81 | static const struct luaL_Reg ipc_routines[] = { 82 | {"isOSX", ipc_is_osx}, 83 | {"workqueue", workqueue_open}, 84 | {"server", cliser_server}, 85 | {"client", cliser_client}, 86 | {"getpid", ipc_getpid}, 87 | {"getppid", ipc_getppid}, 88 | {"gettid", ipc_gettid}, 89 | {"fork", ipc_fork}, 90 | {"waitpid", ipc_waitpid}, 91 | {"map", map_open}, 92 | {"map_extended", map_extended_open}, 93 | {"spawn", spawn_open}, 94 | {"flock", flock_open}, 95 | {"mutex", mutex_create}, 96 | {"sharedtable", sharedtable_create}, 97 | {"sharedtable_size", sharedtable_size}, 98 | {"marshal", marshal_open}, 99 | {"isDevel", ipc_is_devel}, 100 | {"channel", channel_create}, 101 | {NULL, NULL} 102 | }; 103 | 104 | static const struct luaL_Reg workqueue_routines[] = { 105 | {"close", workqueue_close}, 106 | {"read", workqueue_read}, 107 | {"write", workqueue_write}, 108 | {"writeup", workqueue_writeup}, 109 | {"drain", workqueue_drain}, 110 | {"retain", workqueue_retain}, 111 | {"metatablename", workqueue_metatablename}, 112 | {"__gc", workqueue_gc}, 113 | {NULL, NULL} 114 | }; 115 | 116 | static const struct luaL_Reg server_routines[] = { 117 | {"close", cliser_server_close}, 118 | {"clients", cliser_server_clients}, 119 | {"broadcast", cliser_server_broadcast}, 120 | {"recvAny", cliser_server_recv_any}, 121 | {"netStats", cliser_server_net_stats}, 122 | {NULL, NULL} 123 | }; 124 | 125 | static const struct luaL_Reg server_client_routines[] = { 126 | {"send", cliser_server_send}, 127 | {"recv", cliser_server_recv}, 128 | {"tag", cliser_server_tag}, 129 | {"id", cliser_server_id}, 130 | {"close", cliser_server_client_close}, 131 | {"address", cliser_server_client_address}, 132 | {NULL, NULL} 133 | }; 134 | 135 | static const struct luaL_Reg client_routines[] = { 136 | {"close", cliser_client_close}, 137 | {"__gc", cliser_client_close}, 138 | {"send", cliser_client_send}, 139 | {"recv", cliser_client_recv}, 140 | {"recvAsync", cliser_client_recv_async}, 141 | {"retain", cliser_client_retain}, 142 | {"metatablename", cliser_client_metatablename}, 143 | {"netStats", cliser_client_net_stats}, 144 | {NULL, NULL} 145 | }; 146 | 147 | static const struct luaL_Reg map_routines[] = { 148 | {"join", map_join}, 149 | {"checkErrors", map_check_errors}, 150 | {NULL, NULL} 151 | }; 152 | 153 | static const struct luaL_Reg spawn_routines[] = { 154 | {"stdin", spawn_stdin}, 155 | {"stdout", spawn_stdout}, 156 | {"stdoutFileId", spawn_stdout_file_id}, 157 | {"wait", spawn_wait}, 158 | {"pid", spawn_pid}, 159 | {"running", spawn_running}, 160 | {"__gc", spawn_gc}, 161 | {NULL, NULL} 162 | }; 163 | 164 | static const struct luaL_Reg flock_routines[] = { 165 | {"close", flock_close}, 166 | {"__gc", flock_close}, 167 | {NULL, NULL} 168 | }; 169 | 170 | static const struct luaL_Reg mutex_routines[] = { 171 | {"lock", mutex_lock}, 172 | 
{"unlock", mutex_unlock}, 173 | {"barrier", mutex_barrier}, 174 | {"retain", mutex_retain}, 175 | {"metatablename", mutex_metatablename}, 176 | {"__gc", mutex_gc}, 177 | {NULL, NULL} 178 | }; 179 | 180 | static const struct luaL_Reg sharedtable_routines[] = { 181 | {"retain", sharedtable_retain}, 182 | {"metatablename", sharedtable_metatablename}, 183 | {"__gc", sharedtable_gc}, 184 | {"__index", sharedtable_read}, 185 | {"__newindex", sharedtable_write}, 186 | {"__len", sharedtable_len}, 187 | {"__pairs", sharedtable_pairs}, 188 | {NULL, NULL} 189 | }; 190 | 191 | static const struct luaL_Reg marshal_routines[] = { 192 | {"close", marshal_close}, 193 | {"read", marshal_read}, 194 | {"retain", marshal_retain}, 195 | {"metatablename", marshal_metatablename}, 196 | {"__gc", marshal_gc}, 197 | {NULL, NULL} 198 | }; 199 | 200 | static const struct luaL_Reg channel_routines[] = { 201 | {"close", channel_close}, 202 | {"closed", channel_closed}, 203 | {"drained", channel_drained}, 204 | {"read", channel_read}, 205 | {"write", channel_write}, 206 | {"num_items", channel_num_items}, 207 | {"retain", channel_retain}, 208 | {"metatablename", channel_metatablename}, 209 | {"__gc", channel_gc}, 210 | {NULL, NULL} 211 | }; 212 | 213 | static void set_channel_table(lua_State *L) { 214 | const char* statusNames[] = { 215 | "OPEN", "CLOSED", "DRAINED" 216 | }; 217 | const int statusValues[] = { 218 | STATUS_OPEN, STATUS_CLOSED, STATUS_DRAINED 219 | }; 220 | lua_createtable(L, 0, 4); 221 | for (int i = 0; i < 3; ++i) { 222 | lua_pushstring(L, statusNames[i]); 223 | lua_pushinteger(L, statusValues[i]); 224 | lua_settable(L, -3); 225 | } 226 | lua_createtable(L, 0, 1); 227 | lua_pushstring(L, "__call"); 228 | lua_pushcfunction(L, channel_create); 229 | lua_settable(L, -3); 230 | lua_setmetatable(L, -2); 231 | lua_pushstring(L, "channel"); 232 | lua_pushvalue(L, -2); 233 | lua_settable(L, -4); 234 | lua_pop(L, 1); 235 | } 236 | 237 | DLL_EXPORT int luaopen_libipc(lua_State *L) { 238 | signal(SIGPIPE, SIG_IGN); // don't die for SIGPIPE 239 | luaL_newmetatable(L, "ipc.workqueue"); 240 | lua_pushstring(L, "__index"); 241 | lua_pushvalue(L, -2); 242 | lua_settable(L, -3); 243 | luaT_setfuncs(L, workqueue_routines, 0); 244 | lua_pop(L, 1); 245 | luaL_newmetatable(L, "ipc.server"); 246 | lua_pushstring(L, "__index"); 247 | lua_pushvalue(L, -2); 248 | lua_settable(L, -3); 249 | luaT_setfuncs(L, server_routines, 0); 250 | lua_pop(L, 1); 251 | luaL_newmetatable(L, "ipc.server.client"); 252 | lua_pushstring(L, "__index"); 253 | lua_pushvalue(L, -2); 254 | lua_settable(L, -3); 255 | luaT_setfuncs(L, server_client_routines, 0); 256 | lua_pop(L, 1); 257 | luaL_newmetatable(L, "ipc.client"); 258 | lua_pushstring(L, "__index"); 259 | lua_pushvalue(L, -2); 260 | lua_settable(L, -3); 261 | luaT_setfuncs(L, client_routines, 0); 262 | lua_pop(L, 1); 263 | luaL_newmetatable(L, "ipc.map"); 264 | lua_pushstring(L, "__index"); 265 | lua_pushvalue(L, -2); 266 | lua_settable(L, -3); 267 | luaT_setfuncs(L, map_routines, 0); 268 | lua_pop(L, 1); 269 | luaL_newmetatable(L, "ipc.spawn"); 270 | lua_pushstring(L, "__index"); 271 | lua_pushvalue(L, -2); 272 | lua_settable(L, -3); 273 | luaT_setfuncs(L, spawn_routines, 0); 274 | lua_pop(L, 1); 275 | luaL_newmetatable(L, "ipc.flock"); 276 | lua_pushstring(L, "__index"); 277 | lua_pushvalue(L, -2); 278 | lua_settable(L, -3); 279 | luaT_setfuncs(L, flock_routines, 0); 280 | lua_pop(L, 1); 281 | luaL_newmetatable(L, "ipc.mutex"); 282 | lua_pushstring(L, "__index"); 283 | lua_pushvalue(L, -2); 
284 | lua_settable(L, -3); 285 | luaT_setfuncs(L, mutex_routines, 0); 286 | lua_pop(L, 1); 287 | luaL_newmetatable(L, "ipc.marshal"); 288 | lua_pushstring(L, "__index"); 289 | lua_pushvalue(L, -2); 290 | lua_settable(L, -3); 291 | luaT_setfuncs(L, marshal_routines, 0); 292 | lua_pop(L, 1); 293 | luaL_newmetatable(L, "ipc.sharedtable"); 294 | luaT_setfuncs(L, sharedtable_routines, 0); 295 | lua_pop(L, 1); 296 | luaL_newmetatable(L, "ipc.channel"); 297 | lua_pushstring(L, "__index"); 298 | lua_pushvalue(L, -2); 299 | lua_settable(L, -3); 300 | luaT_setfuncs(L, channel_routines, 0); 301 | lua_pop(L, 1); 302 | Lcliser_CharInit(L); 303 | Lcliser_ByteInit(L); 304 | Lcliser_ShortInit(L); 305 | Lcliser_IntInit(L); 306 | Lcliser_LongInit(L); 307 | Lcliser_FloatInit(L); 308 | Lcliser_DoubleInit(L); 309 | #ifdef USE_CUDA 310 | Lcliser_CudaInit(L); 311 | #endif 312 | lua_newtable(L); 313 | luaT_setfuncs(L, ipc_routines, 0); 314 | set_channel_table(L); 315 | return 1; 316 | } 317 | -------------------------------------------------------------------------------- /src/map.c: -------------------------------------------------------------------------------- 1 | #include "luaT.h" 2 | #include "ringbuffer.h" 3 | #include "serialize.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "error.h" 9 | #ifdef _OPENMP 10 | #include 11 | #endif 12 | #include 13 | #include 14 | 15 | #define MAX_ARG_SIZE (16*1024) 16 | 17 | typedef struct map_thread_t { 18 | pthread_t thread; 19 | ringbuffer_t *rb; 20 | int ret; 21 | } map_thread_t; 22 | 23 | typedef struct map_t { 24 | map_thread_t *threads; 25 | uint32_t num_threads; 26 | } map_t; 27 | 28 | typedef int (*ThreadInitFunc) (lua_State *L); 29 | ThreadInitFunc _ipc_static_init_thread = NULL; 30 | 31 | static int rb_save_with_growth(lua_State *L, int index, struct ringbuffer_t *rb) { 32 | while (1) { 33 | ringbuffer_push_write_pos(rb); 34 | int startsize = lua_gettop(L); 35 | int ret = rb_save(L, index, rb, 0, 0); // map doesn't support upvalues 36 | if (ret == -ENOMEM) { 37 | int top = lua_gettop(L); 38 | if (top > startsize) { 39 | lua_pop(L, top-startsize); 40 | } else if (top < startsize) { 41 | LUA_HANDLE_ERROR_STR(L, "too many items popped during serialization"); 42 | } 43 | ringbuffer_pop_write_pos(rb); 44 | ringbuffer_grow_by(rb, MAX_ARG_SIZE); 45 | } else { 46 | return ret; 47 | } 48 | } 49 | } 50 | 51 | static int start_thread(void* arg, map_thread_t **map_thread, lua_State **L, int *top, const char *name) { 52 | #ifdef _OPENMP 53 | // prevent MKL/BLAS from crashing on the reader threads 54 | // its use of open-mp eats up way too many threads 55 | omp_set_num_threads(1); 56 | #endif 57 | 58 | *map_thread = (map_thread_t *)arg; 59 | *L = luaL_newstate(); 60 | if (_ipc_static_init_thread) { 61 | _ipc_static_init_thread(*L); 62 | } else { 63 | luaL_openlibs(*L); 64 | } 65 | // in order to deserialize arguments we need torch and libipc 66 | // TODO: detect these on the main thread when serializing arguments 67 | *top = lua_gettop(*L); 68 | if (luaL_loadstring(*L, "require 'torch'; require 'libipc'; pcall(function() require 'twutil' end)")) { 69 | lua_close(*L); 70 | return 0; 71 | } 72 | (*map_thread)->ret = lua_pcall(*L, 0, 0, 0); 73 | if ((*map_thread)->ret) { 74 | fprintf(stderr, "WARN1: ipc.%s thread pcall failed: %s\n", name, lua_tostring(*L, -1)); 75 | return 0; 76 | } else { 77 | return 1; 78 | } 79 | } 80 | 81 | static void end_thread(map_thread_t *map_thread, lua_State *L, int top, const char *name) { 82 | int k = lua_gettop(L) - top; 83 | 
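/* k counts the values the thread function left on this thread's Lua stack; each one is serialized into the thread's ringbuffer so that map_join() on the owner thread can rb_load() them back as results */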
for (int i = 1; i <= k; i++) { 84 | int ret = rb_save_with_growth(L, top + i, map_thread->rb); 85 | if (ret) { 86 | fprintf(stderr, "WARN: ipc.%s thread failed to write results: %s\n", name, strerror(-ret)); 87 | map_thread->ret = ret; 88 | break; 89 | } 90 | } 91 | lua_close(L); 92 | } 93 | 94 | static void core_thread(map_thread_t *map_thread, lua_State *L, const char *name) { 95 | int i = 0; 96 | while (ringbuffer_peek(map_thread->rb)) { 97 | int ret = rb_load(L, map_thread->rb); 98 | if (ret < 0) { 99 | LUA_HANDLE_ERROR_STR(L, "thread Lua data wasn't loaded correctly"); 100 | }; 101 | i++; 102 | } 103 | map_thread->ret = lua_pcall(L, i - 1, LUA_MULTRET, 0); 104 | if (map_thread->ret) { 105 | fprintf(stderr, "WARN2: ipc.%s thread pcall failed: %s\n", name, lua_tostring(L, -1)); 106 | } 107 | } 108 | 109 | static void* thread_func(void *arg) { 110 | map_thread_t *map_thread; 111 | lua_State *L; 112 | int top; 113 | if (start_thread(arg, &map_thread, &L, &top, "map")) { 114 | top = lua_gettop(L); 115 | core_thread(map_thread, L, "map"); 116 | } 117 | end_thread(map_thread, L, top, "map"); 118 | return 0; 119 | } 120 | 121 | static int core_map(lua_State *L, void* (*func)(void*)) { 122 | uint32_t num_threads = lua_tonumber(L, 1); 123 | map_thread_t *threads = (map_thread_t *)calloc(num_threads, sizeof(map_thread_t)); 124 | int k = lua_gettop(L); 125 | for (uint32_t i = 0; i < num_threads; i++) { 126 | threads[i].rb = ringbuffer_create(MAX_ARG_SIZE); 127 | for (int j = 2; j <= k; j++) { // save function and arguments 128 | int ret = rb_save_with_growth(L, j, threads[i].rb); 129 | if (ret) return LUA_HANDLE_ERROR(L, ret); 130 | } 131 | lua_pushinteger(L, i + 1); // mapid is the last argument (id of the thread) 132 | int ret = rb_save_with_growth(L, k + 1, threads[i].rb); 133 | if (ret) return LUA_HANDLE_ERROR(L, ret); 134 | lua_pop(L, 1); 135 | ret = pthread_create(&threads[i].thread, NULL, func, &threads[i]); 136 | if (ret) return LUA_HANDLE_ERROR(L, ret); 137 | } 138 | map_t *map = (map_t *)lua_newuserdata(L, sizeof(map_t)); 139 | map->num_threads = num_threads; 140 | map->threads = threads; 141 | luaL_getmetatable(L, "ipc.map"); 142 | lua_setmetatable(L, -2); 143 | return 1; 144 | } 145 | 146 | int map_open(lua_State *L) { 147 | if (lua_type(L, 2) != LUA_TFUNCTION) return LUA_HANDLE_ERROR_STR(L, "map arg #2 expected a function"); 148 | return core_map(L, thread_func); 149 | } 150 | 151 | static void* thread_extended_func(void *arg) { 152 | map_thread_t *map_thread; 153 | lua_State *L; 154 | int top; 155 | if (start_thread(arg, &map_thread, &L, &top, "map_extended")) { 156 | top = lua_gettop(L); 157 | rb_load(L, map_thread->rb); 158 | if (lua_type(L, top+1) == LUA_TSTRING) { 159 | size_t str_len; 160 | const char *str = lua_tolstring(L, top+1, &str_len); 161 | char *other_str = malloc(str_len+1); 162 | memcpy((void*)other_str, (void*)str, str_len); 163 | other_str[str_len] = 0; 164 | lua_pop(L, 1); 165 | luaL_loadstring(L, other_str); 166 | } 167 | if (lua_isnil(L, top+1)) { 168 | map_thread->ret = 0; 169 | lua_pop(L, 1); 170 | } else { 171 | map_thread->ret = lua_pcall(L, 0, 0, 0); 172 | } 173 | if (map_thread->ret) { 174 | fprintf(stderr, "WARN: ipc.map_extended thread pcall failed: %s\n", lua_tostring(L, -1)); 175 | } else { 176 | core_thread(map_thread, L, "map_extended"); 177 | } 178 | } 179 | end_thread(map_thread, L, top, "map_extended"); 180 | return 0; 181 | } 182 | 183 | int map_extended_open(lua_State *L) { 184 | if (lua_type(L, 2) != LUA_TFUNCTION 185 | && lua_type(L, 2) != 
LUA_TSTRING 186 | && lua_type(L, 2) != LUA_TNIL) 187 | return LUA_HANDLE_ERROR_STR(L, "map_extended arg #2 expected a function, string or nil"); 188 | if (lua_type(L, 3) != LUA_TFUNCTION) return LUA_HANDLE_ERROR_STR(L, "map_extended arg #3 expected a function"); 189 | return core_map(L, thread_extended_func); 190 | } 191 | 192 | int map_join(lua_State *L) { 193 | int rc = 0; 194 | int err_rc = -1; 195 | map_t *map = (map_t *)lua_touserdata(L, 1); 196 | for (uint32_t i = 0; i < map->num_threads; i++) { 197 | if (map->threads[i].rb) { 198 | int ret = pthread_join(map->threads[i].thread, NULL); 199 | if (ret) return LUA_HANDLE_ERROR(L, ret); 200 | if (map->threads[i].ret) { 201 | err_rc = rc; 202 | } 203 | while (ringbuffer_peek(map->threads[i].rb)) { 204 | rb_load(L, map->threads[i].rb); 205 | rc++; 206 | } 207 | ringbuffer_destroy(map->threads[i].rb); 208 | } 209 | } 210 | free(map->threads); 211 | map->threads = NULL; 212 | map->num_threads = 0; 213 | if (err_rc >= 0) { 214 | return LUA_HANDLE_ERROR_STR(L, lua_tostring(L, err_rc - rc)); 215 | } 216 | return rc; 217 | } 218 | 219 | int map_check_errors(lua_State *L) { 220 | map_t *map = (map_t *)lua_touserdata(L, 1); 221 | for (uint32_t i = 0; i < map->num_threads; i++) { 222 | if (map->threads[i].ret) { 223 | pthread_join(map->threads[i].thread, NULL); 224 | while (ringbuffer_peek(map->threads[i].rb)) { 225 | rb_load(L, map->threads[i].rb); 226 | } 227 | ringbuffer_destroy(map->threads[i].rb); 228 | map->threads[i].rb = NULL; 229 | return LUA_HANDLE_ERROR_STR(L, lua_tostring(L, -1)); 230 | } 231 | } 232 | return 0; 233 | } 234 | -------------------------------------------------------------------------------- /src/map.h: -------------------------------------------------------------------------------- 1 | #ifndef _MAP_H_ 2 | #define _MAP_H_ 3 | 4 | #include "luaT.h" 5 | 6 | int map_open(lua_State *L); 7 | int map_extended_open(lua_State *L); 8 | int map_join(lua_State *L); 9 | int map_check_errors(lua_State *L); 10 | 11 | #endif 12 | -------------------------------------------------------------------------------- /src/marshal.c: -------------------------------------------------------------------------------- 1 | #include "TH.h" 2 | #include "luaT.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "ringbuffer.h" 8 | #include "serialize.h" 9 | #include "error.h" 10 | 11 | #define DEFAULT_MARSHAL_SIZE (1024*16) 12 | 13 | typedef struct marshal_t { 14 | struct ringbuffer_t* rb; 15 | int refcount; 16 | size_t size_increment; 17 | int empty; 18 | } marshal_t; 19 | 20 | static int marshal_write(lua_State *L, int index, marshal_t *marshal, int upval) { 21 | while (1) { 22 | ringbuffer_push_write_pos(marshal->rb); 23 | int ret = rb_save(L, index, marshal->rb, 0, upval); 24 | if (ret == -ENOMEM) { 25 | ringbuffer_pop_write_pos(marshal->rb); 26 | ringbuffer_grow_by(marshal->rb, marshal->size_increment); 27 | } else if (ret) { 28 | ringbuffer_pop_write_pos(marshal->rb); 29 | return LUA_HANDLE_ERROR(L, -ret); 30 | } else { 31 | break; 32 | } 33 | } 34 | return 0; 35 | } 36 | 37 | int marshal_open(lua_State *L) { 38 | int upval = lua_toboolean(L, 2); 39 | size_t size = luaL_optnumber(L, 3, DEFAULT_MARSHAL_SIZE); 40 | size_t size_increment = luaL_optnumber(L, 4, size); 41 | marshal_t *marshal = (marshal_t *)calloc(1, sizeof(marshal_t)); 42 | marshal->refcount = 1; 43 | marshal->size_increment = size_increment; 44 | marshal->rb = ringbuffer_create(size); 45 | marshal_t** ud = (marshal_t **)lua_newuserdata(L, sizeof(marshal_t*)); 46 | *ud = 
marshal; 47 | luaL_getmetatable(L, "ipc.marshal"); 48 | lua_setmetatable(L, -2); 49 | 50 | if (lua_isnil(L, 1) == 1) { 51 | LUA_HANDLE_ERROR_STR(L, "must provide object to serialize at arg 1"); 52 | } else { 53 | marshal_write(L, 1, marshal, upval); 54 | } 55 | return 1; 56 | } 57 | 58 | int marshal_read(lua_State *L) { 59 | marshal_t *marshal = *(marshal_t **)lua_touserdata(L, 1); 60 | if (!marshal) return LUA_HANDLE_ERROR_STR(L, "marshal is not open"); 61 | ringbuffer_t* rb = ringbuffer_clone(marshal->rb); 62 | int ret = rb_load(L, rb); 63 | free(rb); 64 | if (ret < 0) return LUA_HANDLE_ERROR(L, ret); 65 | return 1; 66 | } 67 | 68 | int marshal_close(lua_State *L) { 69 | marshal_t **ud = (marshal_t **)lua_touserdata(L, 1); 70 | marshal_t *marshal = *ud; 71 | if (!marshal) return LUA_HANDLE_ERROR_STR(L, "marshal is already closed"); 72 | if (THAtomicDecrementRef(&marshal->refcount)) { 73 | ringbuffer_destroy(marshal->rb); 74 | free(marshal); 75 | } 76 | *ud = NULL; 77 | return 0; 78 | } 79 | 80 | int marshal_gc(lua_State *L) { 81 | marshal_t *marshal = *(marshal_t **)lua_touserdata(L, 1); 82 | if (marshal) { 83 | marshal_close(L); 84 | } 85 | return 0; 86 | } 87 | 88 | int marshal_retain(lua_State *L) { 89 | marshal_t *marshal = *(marshal_t **)lua_touserdata(L, 1); 90 | if (!marshal) return LUA_HANDLE_ERROR_STR(L, "marshal is not open"); 91 | THAtomicIncrementRef(&marshal->refcount); 92 | return 0; 93 | } 94 | 95 | int marshal_metatablename(lua_State *L) { 96 | lua_pushstring(L, "ipc.marshal"); 97 | return 1; 98 | } 99 | -------------------------------------------------------------------------------- /src/marshal.h: -------------------------------------------------------------------------------- 1 | #ifndef _MARSHAL_H_ 2 | #define _MARSHAL_H_ 3 | 4 | #include "luaT.h" 5 | #include "ringbuffer.h" 6 | 7 | int marshal_open(lua_State *L); 8 | int marshal_close(lua_State *L); 9 | int marshal_read(lua_State *L); 10 | int marshal_gc(lua_State *L); 11 | int marshal_retain(lua_State *L); 12 | int marshal_metatablename(lua_State *L); 13 | 14 | #endif 15 | -------------------------------------------------------------------------------- /src/mutex.c: -------------------------------------------------------------------------------- 1 | #include "mutex.h" 2 | #include "error.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "TH.h" 8 | #include "luaT.h" 9 | 10 | typedef struct mutex_t { 11 | int ref_count; 12 | pthread_mutex_t mutex; 13 | pthread_cond_t cond; 14 | int64_t barrier; 15 | } mutex_t; 16 | 17 | int mutex_create(lua_State *L) { 18 | mutex_t *mutex = calloc(1, sizeof(mutex_t)); 19 | pthread_mutexattr_t mutex_attr; 20 | pthread_mutexattr_init(&mutex_attr); 21 | pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE); 22 | int ret = pthread_mutex_init(&mutex->mutex, &mutex_attr); 23 | if (ret) { 24 | free(mutex); 25 | return LUA_HANDLE_ERROR(L, errno); 26 | } 27 | ret = pthread_cond_init(&mutex->cond, NULL); 28 | if (ret) { 29 | pthread_mutex_destroy(&mutex->mutex); 30 | free(mutex); 31 | return LUA_HANDLE_ERROR(L, errno); 32 | } 33 | mutex_t **umutex = lua_newuserdata(L, sizeof(mutex_t *)); 34 | *umutex = mutex; 35 | mutex->ref_count = 1; 36 | luaL_getmetatable(L, "ipc.mutex"); 37 | lua_setmetatable(L, -2); 38 | return 1; 39 | } 40 | 41 | int mutex_lock(lua_State *L) { 42 | mutex_t *mutex = *(mutex_t **)lua_touserdata(L, 1); 43 | int ret = pthread_mutex_lock(&mutex->mutex); 44 | if (ret) return LUA_HANDLE_ERROR(L, ret); 45 | return 0; 46 | } 47 | 48 | int 
mutex_unlock(lua_State *L) { 49 | mutex_t *mutex = *(mutex_t **)lua_touserdata(L, 1); 50 | int ret = pthread_mutex_unlock(&mutex->mutex); 51 | if (ret) return LUA_HANDLE_ERROR(L, ret); 52 | return 0; 53 | } 54 | 55 | int mutex_barrier(lua_State *L) { 56 | mutex_t *mutex = *(mutex_t **)lua_touserdata(L, 1); 57 | int64_t count = lua_tointeger(L, 2); 58 | int ret = pthread_mutex_lock(&mutex->mutex); 59 | if (ret) return LUA_HANDLE_ERROR(L, ret); 60 | mutex->barrier++; 61 | if (mutex->barrier == count) { 62 | ret = pthread_cond_broadcast(&mutex->cond); 63 | if (ret) return LUA_HANDLE_ERROR(L, ret); 64 | mutex->barrier = 0; 65 | } else { 66 | ret = pthread_cond_wait(&mutex->cond, &mutex->mutex); 67 | if (ret) return LUA_HANDLE_ERROR(L, ret); 68 | } 69 | ret = pthread_mutex_unlock(&mutex->mutex); 70 | if (ret) return LUA_HANDLE_ERROR(L, ret); 71 | return 0; 72 | } 73 | 74 | int mutex_retain(lua_State *L) { 75 | mutex_t *mutex = *(mutex_t **)lua_touserdata(L, 1); 76 | THAtomicIncrementRef(&mutex->ref_count); 77 | return 0; 78 | } 79 | 80 | int mutex_metatablename(lua_State *L) { 81 | lua_pushstring(L, "ipc.mutex"); 82 | return 1; 83 | } 84 | 85 | int mutex_gc(lua_State *L) { 86 | mutex_t *mutex = *(mutex_t **)lua_touserdata(L, 1); 87 | if (THAtomicDecrementRef(&mutex->ref_count)) { 88 | pthread_mutex_destroy(&mutex->mutex); 89 | pthread_cond_destroy(&mutex->cond); 90 | free(mutex); 91 | } 92 | return 0; 93 | } 94 | -------------------------------------------------------------------------------- /src/mutex.h: -------------------------------------------------------------------------------- 1 | #ifndef _MUTEX_H_ 2 | #define _MUTEX_H_ 3 | 4 | #include "luaT.h" 5 | 6 | int mutex_create(lua_State *L); 7 | int mutex_lock(lua_State *L); 8 | int mutex_unlock(lua_State *L); 9 | int mutex_barrier(lua_State *L); 10 | int mutex_retain(lua_State *L); 11 | int mutex_metatablename(lua_State *L); 12 | int mutex_gc(lua_State *L); 13 | 14 | #endif 15 | -------------------------------------------------------------------------------- /src/ringbuffer.c: -------------------------------------------------------------------------------- 1 | #include "ringbuffer.h" 2 | #include 3 | #include 4 | 5 | ringbuffer_t* ringbuffer_create(size_t cb) { 6 | ringbuffer_t* rb = malloc(sizeof(ringbuffer_t)); 7 | rb->buf = malloc(cb); 8 | rb->cb = cb; 9 | rb->rp = 0; 10 | rb->wp = 0; 11 | rb->rcb = 0; 12 | rb->saved_wp = 0; 13 | rb->saved_rcb = 0; 14 | return rb; 15 | } 16 | 17 | void ringbuffer_destroy(ringbuffer_t* rb) { 18 | free(rb->buf); 19 | free(rb); 20 | } 21 | 22 | void ringbuffer_grow_by(ringbuffer_t *rb, size_t cb) { 23 | size_t new_cb = rb->cb + cb; 24 | uint8_t *new_buf = malloc(new_cb); 25 | size_t rcb = ringbuffer_read(rb, new_buf, new_cb); 26 | free(rb->buf); 27 | rb->buf = new_buf; 28 | rb->cb = new_cb; 29 | rb->rp = 0; 30 | rb->wp = rcb; 31 | rb->rcb = rcb; 32 | rb->saved_wp = 0; 33 | rb->saved_rcb = 0; 34 | } 35 | 36 | static size_t min(size_t a, size_t b) { 37 | if (a < b) { 38 | return a; 39 | } 40 | return b; 41 | } 42 | 43 | size_t ringbuffer_write(ringbuffer_t* rb, const void* in, size_t cb) { 44 | size_t i = min(cb, rb->cb - rb->rcb); 45 | if (rb->wp + i < rb->cb) { 46 | if (in) { 47 | uint8_t* in8 = (uint8_t *)in; 48 | memcpy(&rb->buf[rb->wp], in8, i); 49 | } 50 | rb->wp += i; 51 | } 52 | else { 53 | size_t size2 = (rb->wp + i) % rb->cb; 54 | size_t size1 = i - size2; 55 | if (in) { 56 | uint8_t* in8 = (uint8_t *)in; 57 | memcpy(&rb->buf[rb->wp], in8, size1); 58 | memcpy(rb->buf, &in8[size1], size2); 59 | } 60 | 
rb->wp = size2; 61 | } 62 | rb->rcb += i; 63 | return i; 64 | } 65 | 66 | size_t ringbuffer_read(ringbuffer_t* rb, void* out, size_t cb) { 67 | size_t i = min(cb, rb->rcb); 68 | if (rb->rp + i < rb->cb) { 69 | uint8_t* out8 = (uint8_t *)out; 70 | memcpy(out8, &rb->buf[rb->rp], i); 71 | rb->rp += i; 72 | } 73 | else { 74 | size_t size2 = (rb->rp + i) % rb->cb; 75 | size_t size1 = i - size2; 76 | uint8_t* out8 = (uint8_t *)out; 77 | memcpy(out8, &rb->buf[rb->rp], size1); 78 | memcpy(&out8[size1], rb->buf, size2); 79 | rb->rp = size2; 80 | } 81 | rb->rcb -= i; 82 | return i; 83 | } 84 | 85 | 86 | size_t ringbuffer_peek(struct ringbuffer_t* rb) { 87 | return rb->rcb; 88 | } 89 | 90 | void ringbuffer_push_write_pos(struct ringbuffer_t* rb) { 91 | rb->saved_wp = rb->wp; 92 | rb->saved_rcb = rb->rcb; 93 | } 94 | 95 | void ringbuffer_pop_write_pos(struct ringbuffer_t* rb) { 96 | rb->wp = rb->saved_wp; 97 | rb->rcb = rb->saved_rcb; 98 | } 99 | 100 | void ringbuffer_reset_read_pos(struct ringbuffer_t* rb) { 101 | rb->rp = 0; 102 | } 103 | 104 | void* ringbuffer_buf_ptr(struct ringbuffer_t* rb) { 105 | return rb->buf; 106 | } 107 | 108 | // clone everthing except buf (the buf is shared) 109 | ringbuffer_t* ringbuffer_clone(ringbuffer_t* rb) { 110 | ringbuffer_t* crb = malloc(sizeof(ringbuffer_t)); 111 | memcpy(crb, rb, sizeof(ringbuffer_t)); 112 | return crb; 113 | } 114 | -------------------------------------------------------------------------------- /src/ringbuffer.h: -------------------------------------------------------------------------------- 1 | #ifndef _RINGBUFFER_H_ 2 | #define _RINGBUFFER_H_ 3 | 4 | #include 5 | #include 6 | 7 | typedef struct ringbuffer_t { 8 | uint8_t* buf; 9 | size_t cb; 10 | size_t rp; 11 | size_t wp; 12 | size_t rcb; 13 | size_t saved_wp; 14 | size_t saved_rcb; 15 | } ringbuffer_t; 16 | 17 | ringbuffer_t* ringbuffer_create(size_t cb); 18 | void ringbuffer_destroy(ringbuffer_t* rb); 19 | void ringbuffer_grow_by(ringbuffer_t *rb, size_t cb); 20 | size_t ringbuffer_write(ringbuffer_t* rb, const void* in, size_t cb); 21 | size_t ringbuffer_read(ringbuffer_t* rb, void* out, size_t cb); 22 | size_t ringbuffer_peek(ringbuffer_t* rb); 23 | void ringbuffer_push_write_pos(ringbuffer_t* rb); 24 | void ringbuffer_pop_write_pos(ringbuffer_t* rb); 25 | void ringbuffer_reset_read_pos(ringbuffer_t* rb); 26 | void* ringbuffer_buf_ptr(ringbuffer_t* rb); 27 | ringbuffer_t* ringbuffer_clone(ringbuffer_t* rb); 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /src/serialize.h: -------------------------------------------------------------------------------- 1 | #ifndef _SERIALIZE_H_ 2 | #define _SERIALIZE_H_ 3 | 4 | #include "luaT.h" 5 | #include "ringbuffer.h" 6 | 7 | int rb_load(lua_State *L, struct ringbuffer_t *rb); 8 | int rb_save(lua_State *L, int index, struct ringbuffer_t *rb, int oop, int upval); 9 | 10 | #endif 11 | -------------------------------------------------------------------------------- /src/sharedtable.c: -------------------------------------------------------------------------------- 1 | #include "TH.h" 2 | #include "luaT.h" 3 | #include "ringbuffer.h" 4 | #include "serialize.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "error.h" 10 | #ifdef _OPENMP 11 | #include 12 | #endif 13 | #include 14 | #include 15 | 16 | #define BUFFER_SIZE (16*1024) 17 | 18 | typedef struct sharedtable_t { 19 | struct ringbuffer_t* rb; 20 | pthread_mutex_t mutex; 21 | lua_State* L; // the shared table is stored in its 
own lua_State 22 | size_t size_increment; 23 | int ref_count; 24 | } sharedtable_t; 25 | 26 | static int rb_save_with_growth(lua_State *L, int index, struct ringbuffer_t *rb, size_t size) { 27 | while (1) { 28 | ringbuffer_push_write_pos(rb); 29 | int ret = rb_save(L, index, rb, 0, 0); 30 | if (ret == -ENOMEM) { 31 | ringbuffer_pop_write_pos(rb); 32 | ringbuffer_grow_by(rb, size); 33 | } else { 34 | return ret; 35 | } 36 | } 37 | } 38 | 39 | static int copy_entry_to_table(lua_State *L, sharedtable_t *table, int move) { 40 | int top = lua_gettop(L); 41 | int key_pos = top-1; 42 | int val_pos = top; 43 | int ret; 44 | 45 | ret = rb_save_with_growth(L, key_pos, table->rb, table->size_increment); 46 | if (ret) { 47 | lua_pop(L, 2); 48 | return ret; 49 | } 50 | ret = rb_load(table->L, table->rb); 51 | 52 | ret = rb_save_with_growth(L, val_pos, table->rb, table->size_increment); 53 | if (ret) { 54 | lua_pop(L, 2); 55 | return ret; 56 | } 57 | ret = rb_load(table->L, table->rb); 58 | 59 | lua_settable(table->L, 1); 60 | if (move) { 61 | lua_pushvalue(L, key_pos); 62 | lua_pushnil(L); 63 | lua_settable(L, 1); 64 | } 65 | 66 | return 0; 67 | } 68 | 69 | static int init_table(lua_State *L, sharedtable_t *table, int move) { 70 | lua_pushnil(L); 71 | while (lua_next(L, 1) != 0) { 72 | int ret = copy_entry_to_table(L, table, move); 73 | if (ret) return ret; 74 | lua_pop(L, 1); 75 | } 76 | return 0; 77 | } 78 | 79 | typedef int (*ThreadInitFunc) (lua_State *L); 80 | extern ThreadInitFunc _ipc_static_init_thread; 81 | 82 | int sharedtable_create(lua_State *L) { 83 | if (lua_gettop(L) > 0 84 | && lua_type(L, 1) != LUA_TTABLE 85 | && lua_type(L, 1) != LUA_TNIL) 86 | return LUA_HANDLE_ERROR_STR(L, "sharedtable arg #1 expected to be a table or nil"); 87 | 88 | int move = lua_toboolean(L, 2); 89 | const char *requires = luaL_optlstring(L, 3, "", NULL); 90 | size_t size = luaL_optnumber(L, 4, BUFFER_SIZE); 91 | size_t size_increment = luaL_optnumber(L, 5, size); 92 | sharedtable_t *table = (sharedtable_t *)calloc(1, sizeof(sharedtable_t)); 93 | table->L = luaL_newstate(); 94 | if (_ipc_static_init_thread) { 95 | _ipc_static_init_thread(table->L); 96 | } else { 97 | luaL_openlibs(table->L); 98 | } 99 | 100 | if (luaL_loadstring(table->L, "require 'torch'; require 'libipc';")) { 101 | lua_close(table->L); 102 | free(table); 103 | return 0; 104 | } 105 | if (lua_pcall(table->L, 0, 0, 0)) { 106 | fprintf(stderr, "WARN: ipc.sharedtable initialization failed: %s\n", lua_tostring(table->L, -1)); 107 | lua_close(table->L); 108 | free(table); 109 | return 0; 110 | } 111 | 112 | if (requires) { 113 | if (luaL_loadstring(table->L, requires)) { 114 | lua_close(table->L); 115 | free(table); 116 | return 0; 117 | } 118 | if (lua_pcall(table->L, 0, 0, 0)) { 119 | fprintf(stderr, "WARN: ipc.sharedtable initialization failed: %s\n", lua_tostring(table->L, -1)); 120 | lua_close(table->L); 121 | free(table); 122 | return 0; 123 | } 124 | } 125 | 126 | table->rb = ringbuffer_create(size); 127 | lua_newtable(table->L); 128 | table->size_increment = size_increment; 129 | 130 | int ret = pthread_mutex_init(&table->mutex, NULL); 131 | if (ret) { 132 | ringbuffer_destroy(table->rb); 133 | lua_close(table->L); 134 | free(table); 135 | return LUA_HANDLE_ERROR(L, errno); 136 | } 137 | 138 | if (lua_gettop(L) > 0 && !lua_isnil(L, 1)) { 139 | ret = init_table(L, table, move); 140 | if (ret) { 141 | pthread_mutex_destroy(&table->mutex); 142 | ringbuffer_destroy(table->rb); 143 | lua_close(table->L); 144 | free(table); 145 | return 
LUA_HANDLE_ERROR(L, ret); 146 | } 147 | } 148 | 149 | sharedtable_t **ptable = lua_newuserdata(L, sizeof(sharedtable_t *)); 150 | *ptable = table; 151 | table->ref_count = 1; 152 | luaL_getmetatable(L, "ipc.sharedtable"); 153 | lua_setmetatable(L, -2); 154 | return 1; 155 | } 156 | 157 | int sharedtable_retain(lua_State *L) { 158 | lua_touserdata(L, 1); 159 | sharedtable_t **ptable = (sharedtable_t **)lua_touserdata(L, 1); 160 | sharedtable_t *table = *ptable; 161 | THAtomicIncrementRef(&table->ref_count); 162 | return 0; 163 | } 164 | 165 | int sharedtable_gc(lua_State *L) { 166 | lua_touserdata(L, 1); 167 | sharedtable_t **ptable = (sharedtable_t **)lua_touserdata(L, 1); 168 | sharedtable_t *table = *ptable; 169 | if (THAtomicDecrementRef(&table->ref_count)) { 170 | pthread_mutex_destroy(&table->mutex); 171 | ringbuffer_destroy(table->rb); 172 | lua_close(table->L); 173 | free(table); 174 | } 175 | return 0; 176 | } 177 | 178 | int sharedtable_read(lua_State *L) { 179 | sharedtable_t **ptable = (sharedtable_t **)lua_touserdata(L, 1); 180 | sharedtable_t *table = *ptable; 181 | 182 | int ret = pthread_mutex_lock(&table->mutex); 183 | if (ret) { 184 | pthread_mutex_unlock(&table->mutex); 185 | return LUA_HANDLE_ERROR(L, ret); 186 | } 187 | 188 | ret = rb_save_with_growth(L, 2, table->rb, table->size_increment); 189 | if (ret) { 190 | pthread_mutex_unlock(&table->mutex); 191 | return LUA_HANDLE_ERROR(L, ret); 192 | } 193 | ret = rb_load(table->L, table->rb); 194 | lua_gettable(table->L, 1); 195 | 196 | ret = rb_save_with_growth(table->L, 2, table->rb, table->size_increment); 197 | if (ret) { 198 | lua_pop(table->L, 1); 199 | pthread_mutex_unlock(&table->mutex); 200 | return LUA_HANDLE_ERROR(L, ret); 201 | } 202 | ret = rb_load(L, table->rb); 203 | 204 | lua_pop(table->L, 1); 205 | ret = pthread_mutex_unlock(&table->mutex); 206 | if (ret) return LUA_HANDLE_ERROR(L, ret); 207 | 208 | return 1; 209 | } 210 | 211 | int sharedtable_write(lua_State *L) { 212 | sharedtable_t **ptable = (sharedtable_t **)lua_touserdata(L, 1); 213 | sharedtable_t *table = *ptable; 214 | 215 | int ret = pthread_mutex_lock(&table->mutex); 216 | if (ret) { 217 | pthread_mutex_unlock(&table->mutex); 218 | return LUA_HANDLE_ERROR(L, ret); 219 | } 220 | 221 | ret = rb_save_with_growth(L, 2, table->rb, table->size_increment); 222 | if (ret) { 223 | pthread_mutex_unlock(&table->mutex); 224 | return LUA_HANDLE_ERROR(L, ret); 225 | } 226 | ret = rb_load(table->L, table->rb); 227 | 228 | ret = rb_save_with_growth(L, 3, table->rb, table->size_increment); 229 | if (ret) { 230 | lua_pop(table->L, 1); 231 | pthread_mutex_unlock(&table->mutex); 232 | return LUA_HANDLE_ERROR(L, ret); 233 | } 234 | ret = rb_load(table->L, table->rb); 235 | 236 | lua_settable(table->L, 1); 237 | 238 | ret = pthread_mutex_unlock(&table->mutex); 239 | if (ret) return LUA_HANDLE_ERROR(L, ret); 240 | 241 | return 0; 242 | } 243 | 244 | int sharedtable_len(lua_State *L) { 245 | sharedtable_t **ptable = (sharedtable_t **)lua_touserdata(L, 1); 246 | sharedtable_t *table = *ptable; 247 | 248 | int ret = pthread_mutex_lock(&table->mutex); 249 | if (ret) return LUA_HANDLE_ERROR(L, ret); 250 | 251 | size_t counter = 0; 252 | lua_pushnil(table->L); 253 | while (lua_next(table->L, 1) != 0) { 254 | lua_pop(table->L, 1); 255 | counter++; 256 | } 257 | 258 | ret = pthread_mutex_unlock(&table->mutex); 259 | if (ret) return LUA_HANDLE_ERROR(L, ret); 260 | 261 | lua_pushinteger(L, counter); 262 | 263 | return 1; 264 | } 265 | 266 | static int sharedtable_next(lua_State 
*L) { 267 | sharedtable_t **ptable = (sharedtable_t **)lua_touserdata(L, 1); 268 | sharedtable_t *table = *ptable; 269 | 270 | int ret = pthread_mutex_lock(&table->mutex); 271 | if (ret) return LUA_HANDLE_ERROR(L, ret); 272 | 273 | ret = rb_save_with_growth(L, 2, table->rb, table->size_increment); 274 | if (ret) { 275 | pthread_mutex_unlock(&table->mutex); 276 | return LUA_HANDLE_ERROR(L, ret); 277 | } 278 | ret = rb_load(table->L, table->rb); 279 | ret = lua_next(table->L, 1); 280 | if (ret == 0) { 281 | lua_pushnil(L); 282 | ret = pthread_mutex_unlock(&table->mutex); 283 | if (ret) return LUA_HANDLE_ERROR(L, ret); 284 | return 1; 285 | } 286 | 287 | ret = rb_save_with_growth(table->L, 2, table->rb, table->size_increment); 288 | if (ret) { 289 | lua_pop(table->L, 2); 290 | pthread_mutex_unlock(&table->mutex); 291 | return LUA_HANDLE_ERROR(L, ret); 292 | } 293 | ret = rb_load(L, table->rb); 294 | 295 | ret = rb_save_with_growth(table->L, 3, table->rb, table->size_increment); 296 | if (ret) { 297 | lua_pop(table->L, 2); 298 | lua_pop(L, 1); 299 | pthread_mutex_unlock(&table->mutex); 300 | return LUA_HANDLE_ERROR(L, ret); 301 | } 302 | ret = rb_load(L, table->rb); 303 | 304 | lua_pop(table->L, 2); 305 | ret = pthread_mutex_unlock(&table->mutex); 306 | if (ret) return LUA_HANDLE_ERROR(L, ret); 307 | 308 | return 2; 309 | } 310 | 311 | int sharedtable_pairs(lua_State *L) { 312 | lua_pushcfunction(L, sharedtable_next); 313 | lua_pushvalue(L, 1); 314 | lua_pushnil(L); 315 | return 3; 316 | } 317 | 318 | int sharedtable_metatablename(lua_State *L) { 319 | lua_pushstring(L, "ipc.sharedtable"); 320 | return 1; 321 | } 322 | 323 | int sharedtable_size(lua_State *L) { 324 | sharedtable_t **ptable = (sharedtable_t **)lua_touserdata(L, 1); 325 | sharedtable_t *table = *ptable; 326 | 327 | int ret = pthread_mutex_lock(&table->mutex); 328 | if (ret) { 329 | pthread_mutex_unlock(&table->mutex); 330 | return LUA_HANDLE_ERROR(L, ret); 331 | } 332 | 333 | double count1 = lua_gc(table->L, LUA_GCCOUNT, 0); 334 | double count2 = lua_gc(table->L, LUA_GCCOUNTB, 0); 335 | double count = count1 + count2/1024; 336 | lua_pushnumber(L, count); 337 | 338 | ret = pthread_mutex_unlock(&table->mutex); 339 | if (ret) return LUA_HANDLE_ERROR(L, ret); 340 | 341 | return 1; 342 | } 343 | -------------------------------------------------------------------------------- /src/sharedtable.h: -------------------------------------------------------------------------------- 1 | #ifndef _SHAREDTABLE_H_ 2 | #define _SHAREDTABLE_H_ 3 | 4 | #include "luaT.h" 5 | 6 | int sharedtable_create(lua_State *L); 7 | int sharedtable_retain(lua_State *L); 8 | int sharedtable_gc(lua_State *L); 9 | int sharedtable_read(lua_State *L); 10 | int sharedtable_write(lua_State *L); 11 | int sharedtable_len(lua_State *L); 12 | int sharedtable_pairs(lua_State *L); 13 | int sharedtable_metatablename(lua_State *L); 14 | int sharedtable_size(lua_State *L); 15 | 16 | #endif 17 | -------------------------------------------------------------------------------- /src/spawn.c: -------------------------------------------------------------------------------- 1 | #include "luaT.h" 2 | #include "lualib.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include "error.h" 12 | 13 | typedef struct spawn_t { 14 | pid_t pid; 15 | int fd[2][2]; 16 | posix_spawn_file_actions_t file_actions; 17 | posix_spawnattr_t spawnattr; 18 | } spawn_t; 19 | 20 | void spawn_destroy(spawn_t *spawn) { 21 | for (int i = 0; i < 2; 
i++) { 22 | for (int j = 0; j < 2; j++) { 23 | if (spawn->fd[i][j]) { 24 | close(spawn->fd[i][j]); 25 | spawn->fd[i][j] = 0; 26 | } 27 | } 28 | } 29 | posix_spawn_file_actions_destroy(&spawn->file_actions); 30 | posix_spawnattr_destroy(&spawn->spawnattr); 31 | free(spawn); 32 | } 33 | 34 | int spawn_open(lua_State *L) { 35 | if (lua_gettop(L) != 1 || lua_type(L, 1) != LUA_TTABLE) return LUA_HANDLE_ERROR_STR(L, "expected a single table argument"); 36 | 37 | lua_pushstring(L, "file"); 38 | lua_gettable(L, 1); 39 | if (lua_type(L, -1) != LUA_TSTRING) { 40 | return LUA_HANDLE_ERROR_STR(L, "file: expected a string"); 41 | } 42 | const char *file = lua_tostring(L, -1); 43 | lua_pop(L, 1); 44 | 45 | lua_pushstring(L, "args"); 46 | lua_gettable(L, 1); 47 | if (lua_type(L, -1) != LUA_TNIL && lua_type(L, -1) != LUA_TTABLE) { 48 | return LUA_HANDLE_ERROR_STR(L, "args: expected a table, or nil"); 49 | } 50 | size_t n = lua_objlen(L, -1); 51 | char **argv = alloca(sizeof(char *) * (n + 2)); 52 | argv[0] = (char *)file; 53 | for (size_t i = 1; i <= n; i++) { 54 | lua_rawgeti(L, -1, i); 55 | argv[i] = (char *)lua_tostring(L, -1); 56 | lua_pop(L, 1); 57 | } 58 | argv[n + 1] = NULL; 59 | lua_pop(L, 1); 60 | 61 | lua_pushstring(L, "env"); 62 | lua_gettable(L, 1); 63 | n = lua_objlen(L, -1); 64 | extern char **environ; 65 | char **envp = environ; 66 | if (n > 0) { 67 | envp = alloca(sizeof(char *) * (n + 1)); 68 | for (size_t i = 1; i <= n; i++) { 69 | lua_rawgeti(L, -1, i); 70 | envp[i - 1] = (char *)lua_tostring(L, -1); 71 | lua_pop(L, 1); 72 | } 73 | envp[n] = NULL; 74 | } 75 | lua_pop(L, 1); 76 | 77 | spawn_t *spawn = calloc(sizeof(spawn_t), 1); 78 | 79 | int ret = posix_spawn_file_actions_init(&spawn->file_actions); 80 | if (ret) { 81 | spawn_destroy(spawn); 82 | return LUA_HANDLE_ERROR(L, errno); 83 | } 84 | 85 | for (int i = 0; i < 2; i++) { 86 | ret = pipe(spawn->fd[i]); 87 | if (ret) { 88 | spawn_destroy(spawn); 89 | return LUA_HANDLE_ERROR(L, errno); 90 | } 91 | int rw = i == 0 ? 
0 : 1; 92 | ret = posix_spawn_file_actions_adddup2(&spawn->file_actions, spawn->fd[i][rw], i); 93 | if (ret) { 94 | spawn_destroy(spawn); 95 | return LUA_HANDLE_ERROR(L, errno); 96 | } 97 | ret = posix_spawn_file_actions_addclose(&spawn->file_actions, spawn->fd[i][rw]); 98 | if (ret) { 99 | spawn_destroy(spawn); 100 | return LUA_HANDLE_ERROR(L, errno); 101 | } 102 | ret = posix_spawn_file_actions_addclose(&spawn->file_actions, spawn->fd[i][!rw]); 103 | if (ret) { 104 | spawn_destroy(spawn); 105 | return LUA_HANDLE_ERROR(L, errno); 106 | } 107 | } 108 | 109 | ret = posix_spawnattr_init(&spawn->spawnattr); 110 | if (ret) { 111 | spawn_destroy(spawn); 112 | return LUA_HANDLE_ERROR(L, errno); 113 | } 114 | 115 | ret = posix_spawnp(&spawn->pid, file, &spawn->file_actions, &spawn->spawnattr, argv, envp); 116 | if (ret) { 117 | spawn_destroy(spawn); 118 | return LUA_HANDLE_ERROR(L, errno); 119 | } 120 | 121 | ret = close(spawn->fd[0][0]); 122 | if (ret) { 123 | spawn_destroy(spawn); 124 | return LUA_HANDLE_ERROR(L, errno); 125 | } 126 | spawn->fd[0][0] = 0; 127 | ret = close(spawn->fd[1][1]); 128 | if (ret) { 129 | spawn_destroy(spawn); 130 | return LUA_HANDLE_ERROR(L, errno); 131 | } 132 | spawn->fd[1][1] = 0; 133 | 134 | spawn_t **uspawn = lua_newuserdata(L, sizeof(spawn_t *)); 135 | *uspawn = spawn; 136 | luaL_getmetatable(L, "ipc.spawn"); 137 | lua_setmetatable(L, -2); 138 | return 1; 139 | } 140 | 141 | int spawn_wait(lua_State *L) { 142 | spawn_t **uspawn = lua_touserdata(L, 1); 143 | if (!*uspawn) return LUA_HANDLE_ERROR_STR(L, "spawn was already closed"); 144 | spawn_t *spawn = *uspawn; 145 | // optional signal to send to child process 146 | const char *signame = lua_tostring(L, 2); 147 | if (signame) { 148 | int which; 149 | if (strcmp(signame, "KILL") == 0) { 150 | which = SIGKILL; 151 | } else if (strcmp(signame, "TERM") == 0) { 152 | which = SIGTERM; 153 | } else { 154 | return LUA_HANDLE_ERROR_STR(L, "unknown signal"); 155 | } 156 | int ret = kill(spawn->pid, which); 157 | if (ret) return LUA_HANDLE_ERROR(L, errno); 158 | } 159 | // close stdin 160 | if (spawn->fd[0][1]) { 161 | int ret = close(spawn->fd[0][1]); 162 | if (ret) return LUA_HANDLE_ERROR(L, errno); 163 | spawn->fd[0][1] = 0; 164 | } 165 | if (signame) { 166 | // just close stdout, we dont care at this point 167 | close((*uspawn)->fd[1][0]); 168 | (*uspawn)->fd[1][0] = 0; 169 | } else { 170 | // read whatever is left on stdout, we dont want the process to be stalled on us 171 | char buff[1024]; 172 | while (1) { 173 | ssize_t x = read((*uspawn)->fd[1][0], buff, 1024); 174 | if (x < 0) { 175 | return LUA_HANDLE_ERROR(L, errno); 176 | } else if (x == 0) { 177 | break; 178 | } 179 | } 180 | } 181 | // wait for exit 182 | int status; 183 | do { 184 | int ret = waitpid(spawn->pid, &status, WUNTRACED | WCONTINUED); 185 | if (ret < 0) return LUA_HANDLE_ERROR(L, errno); 186 | } while (!WIFEXITED(status) && !WIFSIGNALED(status)); 187 | // clean up 188 | spawn_destroy(spawn); 189 | *uspawn = NULL; 190 | // return the exit code 191 | lua_pushnumber(L, WEXITSTATUS(status)); 192 | return 1; 193 | } 194 | 195 | int spawn_stdin(lua_State *L) { 196 | spawn_t **uspawn = lua_touserdata(L, 1); 197 | if (!*uspawn) return LUA_HANDLE_ERROR_STR(L, "spawn was already closed"); 198 | if (lua_gettop(L) == 1) { 199 | // close stdin 200 | int ret = close((*uspawn)->fd[0][1]); 201 | if (ret) return LUA_HANDLE_ERROR(L, errno); 202 | (*uspawn)->fd[0][1] = 0; 203 | return 0; 204 | } 205 | // write to stdin 206 | size_t str_len; 207 | const char *str = 
lua_tolstring(L, 2, &str_len); 208 | size_t cb = 0; 209 | while (cb < str_len) { 210 | ssize_t x = write((*uspawn)->fd[0][1], str + cb, str_len - cb); 211 | if (x < 0) { 212 | return LUA_HANDLE_ERROR(L, errno); 213 | } else { 214 | cb += x; 215 | } 216 | } 217 | return 0; 218 | } 219 | 220 | int spawn_stdout(lua_State *L) { 221 | spawn_t **uspawn = lua_touserdata(L, 1); 222 | if (!*uspawn) return LUA_HANDLE_ERROR_STR(L, "spawn was already closed"); 223 | int type = lua_type(L, 2); 224 | if (type == LUA_TNUMBER) { 225 | // read some number of bytes from stdout 226 | size_t cb = lua_tonumber(L, 2); 227 | char *buff = calloc(cb + 1, 1); 228 | ssize_t n = read((*uspawn)->fd[1][0], buff, cb); 229 | if (n < 0) { 230 | free(buff); 231 | return LUA_HANDLE_ERROR(L, errno); 232 | } 233 | if (n > 0) { 234 | lua_pushlstring(L, buff, n); 235 | free(buff); 236 | return 1; 237 | } else { 238 | free(buff); 239 | return 0; 240 | } 241 | } 242 | const char *arg = luaL_optstring(L, 2, "*l"); 243 | if (strncmp(arg, "*l", 2) == 0) { 244 | // read stdout until EOL 245 | size_t cb = 0; 246 | size_t max_cb = 1024; 247 | char *buff = realloc(NULL, max_cb); 248 | while (1) { 249 | ssize_t x = read((*uspawn)->fd[1][0], buff + cb, 1); 250 | if (x < 0) { 251 | free(buff); 252 | return LUA_HANDLE_ERROR(L, errno); 253 | } else if (x == 0) { 254 | if (cb > 0) { 255 | buff[cb] = 0; 256 | lua_pushlstring(L, buff, cb); 257 | free(buff); 258 | return 1; 259 | } else { 260 | free(buff); 261 | return 0; 262 | } 263 | } else { 264 | if (buff[cb] == '\n') { 265 | buff[cb] = 0; 266 | lua_pushlstring(L, buff, cb); 267 | free(buff); 268 | return 1; 269 | } 270 | cb++; 271 | if (cb + 1 == max_cb) { 272 | max_cb += 1024; 273 | buff = realloc(buff, max_cb); 274 | } 275 | } 276 | } 277 | } else { 278 | // read stdout until EOF 279 | size_t cb = 0; 280 | size_t max_cb = 1024; 281 | char *buff = realloc(NULL, max_cb); 282 | while (1) { 283 | ssize_t x = read((*uspawn)->fd[1][0], buff + cb, max_cb - cb - 1); 284 | if (x < 0) { 285 | free(buff); 286 | return LUA_HANDLE_ERROR(L, errno); 287 | } else if (x == 0) { 288 | if (cb > 0) { 289 | buff[cb] = 0; 290 | lua_pushlstring(L, buff, cb); 291 | free(buff); 292 | return 1; 293 | } else { 294 | free(buff); 295 | return 0; 296 | } 297 | } else { 298 | cb += x; 299 | if (cb + 1 == max_cb) { 300 | max_cb += 1024; 301 | buff = realloc(buff, max_cb); 302 | } 303 | } 304 | } 305 | } 306 | } 307 | 308 | int spawn_stdout_file_id(lua_State *L) { 309 | spawn_t **uspawn = lua_touserdata(L, 1); 310 | if (!*uspawn) return LUA_HANDLE_ERROR_STR(L, "spawn was already closed"); 311 | lua_pushinteger(L, (*uspawn)->fd[1][0]); 312 | return 1; 313 | } 314 | 315 | int spawn_pid(lua_State *L) { 316 | spawn_t **uspawn = lua_touserdata(L, 1); 317 | if (!*uspawn) return LUA_HANDLE_ERROR_STR(L, "spawn was already closed"); 318 | lua_pushnumber(L, (*uspawn)->pid); 319 | return 1; 320 | } 321 | 322 | int spawn_running(lua_State *L) { 323 | spawn_t **uspawn = lua_touserdata(L, 1); 324 | if (!*uspawn) return LUA_HANDLE_ERROR_STR(L, "spawn was already closed"); 325 | siginfo_t si; 326 | memset(&si, 0, sizeof(si)); 327 | int ret = waitid(P_PID, (*uspawn)->pid, &si, WEXITED | WNOHANG | WNOWAIT); 328 | if (ret) return LUA_HANDLE_ERROR(L, errno); 329 | lua_pushboolean(L, si.si_pid == 0); 330 | return 1; 331 | } 332 | 333 | int spawn_gc(lua_State *L) { 334 | spawn_t **uspawn = lua_touserdata(L, 1); 335 | if (*uspawn) { 336 | fprintf(stderr, "ipc.spawn being garbage collected before wait was called, sending SIGTERM to child 
process"); 337 | int ret = kill((*uspawn)->pid, SIGTERM); 338 | if (ret) return LUA_HANDLE_ERROR(L, errno); 339 | spawn_wait(L); 340 | lua_pop(L, 1); 341 | } 342 | return 0; 343 | } 344 | -------------------------------------------------------------------------------- /src/spawn.h: -------------------------------------------------------------------------------- 1 | #ifndef _SPAWN_H_ 2 | #define _SPAWN_H_ 3 | 4 | #include "luaT.h" 5 | 6 | int spawn_open(lua_State *L); 7 | int spawn_wait(lua_State *L); 8 | int spawn_stdin(lua_State *L); 9 | int spawn_stdout(lua_State *L); 10 | int spawn_stdout_file_id(lua_State *L); 11 | int spawn_pid(lua_State *L); 12 | int spawn_running(lua_State *L); 13 | int spawn_gc(lua_State *L); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /src/workqueue.c: -------------------------------------------------------------------------------- 1 | #include "TH.h" 2 | #include "luaT.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "ringbuffer.h" 9 | #include "serialize.h" 10 | #include "error.h" 11 | 12 | #define DEFAULT_WORKQUEUE_SIZE (16*1024) 13 | 14 | #define TOO_TRICKY (0) 15 | #define WORKQUEUE_VERBOSE (0) 16 | 17 | typedef struct queue_t { 18 | struct ringbuffer_t* rb; 19 | pthread_mutex_t mutex; 20 | pthread_cond_t read_avail_cond; 21 | #if TOO_TRICKY 22 | pthread_cond_t write_avail_cond; 23 | #endif 24 | uint32_t num_items; 25 | } queue_t; 26 | 27 | typedef struct workqueue_t { 28 | struct workqueue_t *next; 29 | struct workqueue_t *prev; 30 | int refcount; 31 | size_t size_increment; 32 | const char *name; 33 | queue_t questions; 34 | queue_t answers; 35 | pthread_t owner_thread; 36 | pthread_mutex_t mutex; 37 | } workqueue_t; 38 | 39 | static pthread_once_t workqueue_once = PTHREAD_ONCE_INIT; 40 | static pthread_mutex_t workqueue_mutex; 41 | static workqueue_t *workqueue_head; 42 | 43 | static void workqueue_one_time_init_inner() { 44 | pthread_mutexattr_t mutex_attr; 45 | pthread_mutexattr_init(&mutex_attr); 46 | pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE); 47 | pthread_mutex_init(&workqueue_mutex, &mutex_attr); 48 | workqueue_head = NULL; 49 | } 50 | 51 | static void workqueue_one_time_init() { 52 | pthread_once(&workqueue_once, workqueue_one_time_init_inner); 53 | } 54 | 55 | static void workqueue_insert(workqueue_t *workqueue) { 56 | if (workqueue->name == NULL) 57 | return; 58 | if (workqueue_head) { 59 | workqueue_head->prev = workqueue; 60 | workqueue->next = workqueue_head; 61 | } 62 | workqueue_head = workqueue; 63 | } 64 | 65 | static void workqueue_remove(workqueue_t *workqueue) { 66 | if (workqueue->name == NULL) 67 | return; 68 | if (workqueue_head == workqueue) { 69 | workqueue_head = workqueue->next; 70 | } 71 | if (workqueue->next) { 72 | workqueue->next->prev = workqueue->prev; 73 | } 74 | if (workqueue->prev) { 75 | workqueue->prev->next = workqueue->next; 76 | } 77 | } 78 | 79 | static workqueue_t *workqueue_find(const char *name) { 80 | if (name == NULL) 81 | return NULL; 82 | workqueue_t *workqueue = workqueue_head; 83 | while (workqueue && strcmp(workqueue->name, name) != 0) { 84 | workqueue = workqueue->next; 85 | } 86 | if (workqueue) { 87 | THAtomicIncrementRef(&workqueue->refcount); 88 | } 89 | return workqueue; 90 | } 91 | 92 | static void workqueue_init_queue(queue_t *queue, size_t size) { 93 | pthread_mutexattr_t mutex_attr; 94 | pthread_mutexattr_init(&mutex_attr); 95 | pthread_mutexattr_settype(&mutex_attr, 
PTHREAD_MUTEX_RECURSIVE); 96 | pthread_mutex_init(&queue->mutex, &mutex_attr); 97 | pthread_cond_init(&queue->read_avail_cond, NULL); 98 | #if TOO_TRICKY 99 | pthread_cond_init(&queue->write_avail_cond, NULL); 100 | #endif 101 | queue->rb = ringbuffer_create(size); 102 | } 103 | 104 | static void workqueue_destroy_queue(queue_t *queue) { 105 | pthread_mutex_destroy(&queue->mutex); 106 | pthread_cond_destroy(&queue->read_avail_cond); 107 | #if TOO_TRICKY 108 | pthread_cond_destroy(&queue->write_avail_cond); 109 | #endif 110 | ringbuffer_destroy(queue->rb); 111 | } 112 | 113 | int workqueue_open(lua_State *L) { 114 | workqueue_one_time_init(); 115 | const char *name = luaL_optlstring(L, 1, NULL, NULL); 116 | size_t size = luaL_optnumber(L, 2, DEFAULT_WORKQUEUE_SIZE); 117 | size_t size_increment = luaL_optnumber(L, 3, size); 118 | pthread_mutex_lock(&workqueue_mutex); 119 | workqueue_t *workqueue = workqueue_find(name); 120 | int creator = 0; 121 | if (!workqueue) { 122 | creator = 1; 123 | workqueue = (workqueue_t *)calloc(1, sizeof(workqueue_t)); 124 | workqueue->refcount = 1; 125 | workqueue->size_increment = size_increment; 126 | if (name == NULL) 127 | workqueue->name = NULL; 128 | else 129 | workqueue->name = strdup(name); 130 | workqueue_init_queue(&workqueue->questions, size); 131 | workqueue_init_queue(&workqueue->answers, size); 132 | workqueue->owner_thread = pthread_self(); 133 | pthread_mutexattr_t mutex_attr; 134 | pthread_mutexattr_init(&mutex_attr); 135 | pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE); 136 | pthread_mutex_init(&workqueue->mutex, &mutex_attr); 137 | workqueue_insert(workqueue); 138 | } 139 | pthread_mutex_unlock(&workqueue_mutex); 140 | workqueue_t** ud = (workqueue_t **)lua_newuserdata(L, sizeof(workqueue_t*)); 141 | *ud = workqueue; 142 | luaL_getmetatable(L, "ipc.workqueue"); 143 | lua_setmetatable(L, -2); 144 | lua_pushinteger(L, creator); 145 | return 2; 146 | } 147 | 148 | int workqueue_queue_read(lua_State *L, queue_t *queue, int doNotBlock) { 149 | pthread_mutex_lock(&queue->mutex); 150 | while (1) { 151 | if (queue->num_items) { 152 | int ret = rb_load(L, queue->rb); 153 | queue->num_items--; 154 | #if TOO_TRICKY 155 | pthread_cond_signal(&queue->write_avail_cond); 156 | #endif 157 | pthread_mutex_unlock(&queue->mutex); 158 | if (ret < 0) return LUA_HANDLE_ERROR(L, ret); 159 | return ret; 160 | } else if (doNotBlock) { 161 | break; 162 | } else { 163 | pthread_cond_wait(&queue->read_avail_cond, &queue->mutex); 164 | } 165 | } 166 | pthread_mutex_unlock(&queue->mutex); 167 | return 0; 168 | } 169 | 170 | int workqueue_read(lua_State *L) { 171 | workqueue_t *workqueue = *(workqueue_t **)lua_touserdata(L, 1); 172 | if (!workqueue) return LUA_HANDLE_ERROR_STR(L, "workqueue is not open"); 173 | int doNotBlock = luaT_optboolean(L, 2, 0); 174 | if (workqueue->owner_thread == pthread_self()) { 175 | return workqueue_queue_read(L, &workqueue->answers, doNotBlock); 176 | } else { 177 | return workqueue_queue_read(L, &workqueue->questions, doNotBlock); 178 | } 179 | } 180 | 181 | static int workqueue_queue_write(lua_State *L, int index, queue_t *queue, size_t size_increment, int upval) { 182 | pthread_mutex_lock(&queue->mutex); 183 | int top = lua_gettop(L); 184 | while (index <= top) { 185 | ringbuffer_push_write_pos(queue->rb); 186 | int ret = rb_save(L, index, queue->rb, 0, upval); 187 | if (ret == -ENOMEM) { 188 | ringbuffer_pop_write_pos(queue->rb); 189 | #if TOO_TRICKY 190 | if (ringbuffer_peek(queue->rb)) { 191 | 
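/* this branch is compiled out while TOO_TRICKY is 0: with it enabled, a full queue that still holds unread items would block here until the reader drains it, instead of growing the ringbuffer below */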
pthread_cond_wait(&queue->write_avail_cond, &queue->mutex); 192 | } else 193 | #endif 194 | { 195 | ringbuffer_grow_by(queue->rb, size_increment); 196 | #if WORKQUEUE_VERBOSE 197 | fprintf(stderr, "INFO: ipc.workqueue grew to %zu bytes\n", queue->rb->cb); 198 | #endif 199 | } 200 | } else if (ret) { 201 | ringbuffer_pop_write_pos(queue->rb); 202 | pthread_mutex_unlock(&queue->mutex); 203 | return LUA_HANDLE_ERROR(L, -ret); 204 | } else { 205 | index++; 206 | queue->num_items++; 207 | } 208 | } 209 | pthread_cond_signal(&queue->read_avail_cond); 210 | pthread_mutex_unlock(&queue->mutex); 211 | return 0; 212 | } 213 | 214 | int workqueue_write(lua_State *L) { 215 | workqueue_t *workqueue = *(workqueue_t **)lua_touserdata(L, 1); 216 | if (!workqueue) return LUA_HANDLE_ERROR_STR(L, "workqueue is not open"); 217 | if (workqueue->owner_thread == pthread_self()) { 218 | return workqueue_queue_write(L, 2, &workqueue->questions, workqueue->size_increment, 0); 219 | } else { 220 | return workqueue_queue_write(L, 2, &workqueue->answers, workqueue->size_increment, 0); 221 | } 222 | } 223 | 224 | int workqueue_writeup(lua_State *L) { 225 | workqueue_t *workqueue = *(workqueue_t **)lua_touserdata(L, 1); 226 | if (!workqueue) return LUA_HANDLE_ERROR_STR(L, "workqueue is not open"); 227 | if (workqueue->owner_thread == pthread_self()) { 228 | return workqueue_queue_write(L, 2, &workqueue->questions, workqueue->size_increment, 1); 229 | } else { 230 | return workqueue_queue_write(L, 2, &workqueue->answers, workqueue->size_increment, 1); 231 | } 232 | } 233 | 234 | int workqueue_drain(lua_State *L) { 235 | workqueue_t *workqueue = *(workqueue_t **)lua_touserdata(L, 1); 236 | if (!workqueue) return LUA_HANDLE_ERROR_STR(L, "workqueue is not open"); 237 | if (workqueue->owner_thread != pthread_self()) return LUA_HANDLE_ERROR_STR(L, "workqueue drain is only available on the owner thread"); 238 | pthread_mutex_lock(&workqueue->questions.mutex); 239 | pthread_mutex_lock(&workqueue->answers.mutex); 240 | uint32_t mark = workqueue->answers.num_items + workqueue->questions.num_items; 241 | pthread_mutex_unlock(&workqueue->questions.mutex); 242 | while (workqueue->answers.num_items < mark) { 243 | pthread_cond_wait(&workqueue->answers.read_avail_cond, &workqueue->answers.mutex); 244 | } 245 | pthread_mutex_unlock(&workqueue->answers.mutex); 246 | return 0; 247 | } 248 | 249 | int workqueue_close(lua_State *L) { 250 | workqueue_t **ud = (workqueue_t **)lua_touserdata(L, 1); 251 | workqueue_t *workqueue = *ud; 252 | if (!workqueue) return LUA_HANDLE_ERROR_STR(L, "workqueue is already closed"); 253 | pthread_mutex_lock(&workqueue_mutex); 254 | if (THAtomicDecrementRef(&workqueue->refcount)) { 255 | workqueue_remove(workqueue); 256 | workqueue_destroy_queue(&workqueue->questions); 257 | workqueue_destroy_queue(&workqueue->answers); 258 | pthread_mutex_destroy(&workqueue->mutex); 259 | free((void *)workqueue->name); 260 | workqueue->name = NULL; 261 | free(workqueue); 262 | } 263 | *ud = NULL; 264 | pthread_mutex_unlock(&workqueue_mutex); 265 | return 0; 266 | } 267 | 268 | int workqueue_gc(lua_State *L) { 269 | pthread_mutex_lock(&workqueue_mutex); 270 | workqueue_t *workqueue = *(workqueue_t **)lua_touserdata(L, 1); 271 | if (workqueue) { 272 | workqueue_close(L); 273 | } 274 | pthread_mutex_unlock(&workqueue_mutex); 275 | return 0; 276 | } 277 | 278 | int workqueue_retain(lua_State *L) { 279 | workqueue_t *workqueue = *(workqueue_t **)lua_touserdata(L, 1); 280 | if (!workqueue) return LUA_HANDLE_ERROR_STR(L, "workqueue 
is not open"); 281 | if (workqueue->name == NULL) { 282 | pthread_mutex_lock(&workqueue_mutex); 283 | THAtomicIncrementRef(&workqueue->refcount); 284 | pthread_mutex_unlock(&workqueue_mutex); 285 | } 286 | return 0; 287 | } 288 | 289 | int workqueue_metatablename(lua_State *L) { 290 | lua_pushstring(L, "ipc.workqueue"); 291 | return 1; 292 | } 293 | -------------------------------------------------------------------------------- /src/workqueue.h: -------------------------------------------------------------------------------- 1 | #ifndef _WORKQUEUE_H_ 2 | #define _WORKQUEUE_H_ 3 | 4 | #include "luaT.h" 5 | 6 | int workqueue_open(lua_State *L); 7 | int workqueue_close(lua_State *L); 8 | int workqueue_read(lua_State *L); 9 | int workqueue_write(lua_State *L); 10 | int workqueue_writeup(lua_State *L); 11 | int workqueue_drain(lua_State *L); 12 | int workqueue_gc(lua_State *L); 13 | int workqueue_retain(lua_State *L); 14 | int workqueue_metatablename(lua_State *L); 15 | 16 | #endif 17 | -------------------------------------------------------------------------------- /test/test.lua: -------------------------------------------------------------------------------- 1 | require 'ipc.test_BackgroundTask' 2 | require 'ipc.test_BackgroundTaskPool' 3 | require 'ipc.test_cliser' 4 | require 'ipc.test_map' 5 | require 'ipc.test_spawn' 6 | require 'ipc.test_Tree' 7 | require 'ipc.test_workqueue' 8 | require 'ipc.test_mutex' 9 | require 'ipc.test_sharedtable' 10 | require 'ipc.test_channel' 11 | -------------------------------------------------------------------------------- /test/test_BackgroundTask.lua: -------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | local sys = require 'sys' 3 | local BackgroundTask = require 'ipc.BackgroundTask' 4 | 5 | test { 6 | testSimple = function() 7 | local task = BackgroundTask(function(t) 8 | local sys = require 'sys' 9 | sys.sleep(t) 10 | return 'done!', 42, true 11 | end, 0.7) 12 | local s,n,b = task.getResult() 13 | assert(s == 'done!') 14 | assert(n == 42) 15 | assert(b == true) 16 | end, 17 | 18 | testPolling = function() 19 | local task = BackgroundTask(function(t) 20 | local sys = require 'sys' 21 | sys.sleep(t) 22 | return 'done!', 42, true 23 | end, 1.3) 24 | local x = 0 25 | while not task.isDone() do 26 | x = x + 1 27 | sys.sleep(0.1) 28 | end 29 | assert(x > 9 and x < 16) 30 | local s,n,b = task.getResult() 31 | assert(s == 'done!') 32 | assert(n == 42) 33 | assert(b == true) 34 | end, 35 | 36 | testNoReturns = function() 37 | local task = BackgroundTask(function(t) 38 | local sys = require 'sys' 39 | sys.sleep(t) 40 | end, 0.4) 41 | local y = task.getResult() 42 | assert(y == nil) 43 | end, 44 | 45 | testError = function() 46 | local task = BackgroundTask(function(t) 47 | local sys = require 'sys' 48 | sys.sleep(t) 49 | error('die') 50 | return 43 51 | end, 0.4) 52 | assert(pcall(function() return task.getResult() end) == false) 53 | end, 54 | 55 | testErrorPolling = function() 56 | local task = BackgroundTask(function(t) 57 | local sys = require 'sys' 58 | sys.sleep(t) 59 | error('die') 60 | return 43 61 | end, 0.4) 62 | assert(not task.isDone()) 63 | sys.sleep(0.6) 64 | assert(task.isDone() == true) 65 | assert(pcall(function() return task.getResult() end) == false) 66 | end, 67 | } 68 | -------------------------------------------------------------------------------- /test/test_BackgroundTaskPool.lua: -------------------------------------------------------------------------------- 1 | local 
test = require 'regress' 2 | local sys = require 'sys' 3 | local BackgroundTaskPool = require 'ipc.BackgroundTaskPool' 4 | 5 | test { 6 | testSimple = function() 7 | local pool = BackgroundTaskPool(20, { closeOnLastTask = true }) 8 | for _ = 1,1000 do 9 | pool.addTask(function(t) 10 | local sys = require 'sys' 11 | sys.sleep(t) 12 | return math.random() 13 | end, math.random(1, 1000) / 1000) 14 | end 15 | for i = 1,1000 do 16 | assert(type(pool.getResult(i)) == 'number') 17 | end 18 | end, 19 | 20 | testPolling = function() 21 | local pool = BackgroundTaskPool(13, { closeOnLastTask = true }) 22 | for _ = 1,42 do 23 | pool.addTask(function(t) 24 | local sys = require 'sys' 25 | sys.sleep(t) 26 | return math.random() 27 | end, math.random(1, 1000) / 1000) 28 | end 29 | local x = 0 30 | while not pool.isDone() do 31 | x = x + 1 32 | sys.sleep(0.1) 33 | end 34 | assert(x > 10 and x < 40, x) 35 | for i = 1,42 do 36 | assert(type(pool.getResult(i)) == 'number') 37 | end 38 | end, 39 | 40 | testError = function() 41 | local pool = BackgroundTaskPool(7, { closeOnLastTask = true }) 42 | for i = 1,42 do 43 | pool.addTask(function(t, i) 44 | local sys = require 'sys' 45 | sys.sleep(t) 46 | if i % 8 == 1 then 47 | error('die') 48 | end 49 | return math.random() 50 | end, math.random(1, 1000) / 1000, i) 51 | end 52 | for i = 1,42 do 53 | if i % 8 == 1 then 54 | assert(pcall(function() return pool.getResult(i) end) == false) 55 | else 56 | assert(type(pool.getResult(i)) == 'number') 57 | end 58 | end 59 | end, 60 | 61 | testErrorPolling = function() 62 | local pool = BackgroundTaskPool(13, { closeOnLastTask = true }) 63 | for i = 1,42 do 64 | pool.addTask(function(t, i) 65 | local sys = require 'sys' 66 | sys.sleep(t) 67 | if i % 27 == 1 then 68 | error('die') 69 | end 70 | return math.random() 71 | end, math.random(1, 1000) / 1000, i) 72 | end 73 | local x = 0 74 | while not pool.isDone() do 75 | x = x + 1 76 | sys.sleep(0.1) 77 | end 78 | assert(x > 10 and x < 40, x) 79 | for i = 1,42 do 80 | if i % 27 == 1 then 81 | assert(pcall(function() return pool.getResult(i) end) == false) 82 | else 83 | assert(type(pool.getResult(i)) == 'number') 84 | end 85 | end 86 | end, 87 | } 88 | -------------------------------------------------------------------------------- /test/test_Tree.lua: -------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | local ipc = require 'libipc' 3 | local Tree = require 'ipc.Tree' 4 | 5 | local function testAllReduce(njobs, base, makeValue, reduce) 6 | local server, port = ipc.server('127.0.0.1') 7 | local m = ipc.map(njobs - 1, function(njobs, base, port, makeValue, reduce, mapid) 8 | local ipc = require 'libipc' 9 | local Tree = require 'ipc.Tree' 10 | local client = ipc.client('127.0.0.1', port) 11 | local jobid = mapid + 1 12 | local tree = Tree(jobid, njobs, base, nil, client, '127.0.0.1') 13 | local value = makeValue(jobid) 14 | local value = tree.allReduce(value, reduce) 15 | return value 16 | end, njobs, base, port, makeValue, reduce) 17 | server:clients(njobs - 1, function(client) end) 18 | local tree = Tree(1, njobs, base, server, nil, '127.0.0.1', port) 19 | local value = makeValue(1) 20 | local final = tree.allReduce(value, reduce) 21 | local ret = { m:join() } 22 | table.insert(ret, 1, final) 23 | return ret 24 | end 25 | 26 | test { 27 | testTreeNumbersBase2 = function() 28 | local ret = testAllReduce(8, 2, 29 | function(jobid) return jobid end, 30 | function(a, b) return a + b end) 31 | test.mustBeTrue(#ret 
== 8, 'expected 8 results, not '..#ret) 32 |       for _,rv in ipairs(ret) do 33 |          test.mustBeTrue(rv == 36, 'expected final value of 36, not '..rv) 34 |       end 35 |    end, 36 | 37 |    testTreeNumbersArrayBase2 = function() 38 |       local ret = testAllReduce(4, 2, 39 |          function(jobid) return { jobid, 2 * jobid } end, 40 |          function(a, b) return a + b end) 41 |       test.mustBeTrue(#ret == 4, 'expected 4 results, not '..#ret) 42 |       for _,rv in ipairs(ret) do 43 |          test.mustBeTrue(rv[1] == 10, 'expected final value of 10, not '..rv[1]) 44 |          test.mustBeTrue(rv[2] == 20, 'expected final value of 20, not '..rv[2]) 45 |       end 46 |    end, 47 | 48 |    testTreeTensorsBase2 = function() 49 |       local ret = testAllReduce(2, 2, 50 |          function(jobid) 51 |             return torch.Tensor(10):fill(jobid) 52 |          end, 53 |          function(a, b) 54 |             a:add(b) 55 |             return a 56 |          end) 57 |       test.mustBeTrue(#ret == 2, 'expected 2 results, not '..#ret) 58 |       for _,rv in ipairs(ret) do 59 |          test.mustBeTrue(rv:sum() == 30, 'expected final value of 30, not '..rv:sum()) 60 |       end 61 |    end, 62 | 63 |    testTreeTensorsBase4 = function() 64 |       local ret = testAllReduce(8, 4, 65 |          function(jobid) 66 |             return torch.Tensor(10):fill(jobid) 67 |          end, 68 |          function(a, b) return a:add(b) end) 69 |       test.mustBeTrue(#ret == 8, 'expected 8 results, not '..#ret) 70 |       for _,rv in ipairs(ret) do 71 |          test.mustBeTrue(rv:sum() == 360, 'expected final value of 360, not '..rv:sum()) 72 |       end 73 |    end, 74 | 75 |    testTreeTensorsBase8 = function() 76 |       local ret = testAllReduce(8, 8, 77 |          function(jobid) 78 |             return torch.Tensor(10):fill(jobid) 79 |          end, 80 |          function(a, b) return a + b end) 81 |       test.mustBeTrue(#ret == 8, 'expected 8 results, not '..#ret) 82 |       for _,rv in ipairs(ret) do 83 |          test.mustBeTrue(rv:sum() == 360, 'expected final value of 360, not '..rv:sum()) 84 |       end 85 |    end, 86 | 87 |    testUnevenNumberOfSteps = function() 88 |       local function expected(n, ni) 89 |          local c = 0 90 |          for i = 1,ni do 91 |             for j = i,n do 92 |                c = c + j 93 |             end 94 |          end 95 |          return c 96 |       end 97 |       local function reduce(a, b) return a + b end 98 |       local function zero() return 0 end 99 |       local function loop(njobs, jobid, tree, reduce, zero) 100 |          local value = 0 101 |          for i = 1,jobid do 102 |             value = value + tree.allReduce(jobid, reduce) 103 |          end 104 |          value = value + tree.allReduce(nil, reduce, zero) 105 |          return value 106 |       end 107 |       local njobs = 4 108 |       local base = 2 109 |       local server, port = ipc.server('127.0.0.1') 110 |       local m = ipc.map(njobs - 1, function(njobs, base, port, reduce, zero, loop, mapid) 111 |          local ipc = require 'libipc' 112 |          local Tree = require 'ipc.Tree' 113 |          local client = ipc.client('127.0.0.1', port) 114 |          local jobid = mapid + 1 115 |          local tree = Tree(jobid, njobs, base, nil, client, '127.0.0.1') 116 |          return loop(njobs, jobid, tree, reduce, zero) 117 |       end, njobs, base, port, reduce, zero, loop) 118 |       server:clients(njobs - 1, function(client) end) 119 |       local tree = Tree(1, njobs, base, server, nil, '127.0.0.1') 120 |       local final = loop(njobs, 1, tree, reduce, zero) 121 |       local ret = { m:join() } 122 |       table.insert(ret, 1, final) 123 |       test.mustBeTrue(#ret == njobs, 'expected '..njobs..' 
results, not '..#ret) 124 | for i,rv in ipairs(ret) do 125 | local e = expected(njobs, i) 126 | test.mustBeTrue(rv == e, 'expected final value of '..e..', not '..rv) 127 | end 128 | end, 129 | 130 | testTreeMultipleTensors = function() 131 | local ret = testAllReduce(8, 2, 132 | function(jobid) 133 | return { torch.Tensor(10):fill(jobid), torch.Tensor(10):fill(jobid - 1) } 134 | end, 135 | function(a, b) 136 | a:add(b) 137 | return a 138 | end) 139 | test.mustBeTrue(#ret == 8, 'expected 8 results, not '..#ret) 140 | for _,rv in ipairs(ret) do 141 | test.mustBeTrue(rv[1]:sum() == 360, 'expected final value of 360, not '..rv[1]:sum()) 142 | test.mustBeTrue(rv[2]:sum() == 280, 'expected final value of 280, not '..rv[2]:sum()) 143 | end 144 | end, 145 | 146 | testScatter = function() 147 | local njobs = 4 148 | local base = 2 149 | local server, port = ipc.server('127.0.0.1') 150 | local m = ipc.map(njobs - 1, function(njobs, base, port, mapid) 151 | local ipc = require 'libipc' 152 | local Tree = require 'ipc.Tree' 153 | local client = ipc.client('127.0.0.1', port) 154 | local jobid = mapid + 1 155 | local tree = Tree(jobid, njobs, base, nil, client, '127.0.0.1') 156 | return tree.scatter(jobid) 157 | end, njobs, base, port) 158 | server:clients(njobs - 1, function(client) end) 159 | local tree = Tree(1, njobs, base, server, nil, '127.0.0.1') 160 | local final = tree.scatter(1) 161 | local ret = { m:join() } 162 | table.insert(ret, 1, final) 163 | test.mustBeTrue(#ret == njobs, 'expected '..njobs..' results, not '..#ret) 164 | for _,rv in ipairs(ret) do 165 | test.mustBeTrue(rv == 1, 'expected final value of 1, not '..rv) 166 | end 167 | end, 168 | } 169 | -------------------------------------------------------------------------------- /test/test_cliser.lua: -------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | pcall(require, 'cutorch') 3 | local ipc = require 'libipc' 4 | 5 | local function testCSN(numClients, test, callbackS, callbackC, extra) 6 | local server,port = ipc.server() 7 | local clients = ipc.map(numClients, function(port, callbackC, extra) 8 | local sys = require 'sys' 9 | local ipc = require 'libipc' 10 | local client = ipc.client(port) 11 | callbackC(client, extra) 12 | assert(client:recv() == "bye") 13 | client:close() 14 | return true 15 | end, port, callbackC, extra) 16 | callbackS(server) 17 | server:broadcast("bye") 18 | local ret = clients:join() 19 | server:close() 20 | local passed = type(ret) == 'boolean' and ret == true 21 | local msg = (type(ret) == 'string' and ret) or 'client failed with an unknown error' 22 | assert(passed, msg) 23 | end 24 | 25 | local function testCS(test, callbackS, callbackC, extra) 26 | return testCSN(1, test, callbackS, callbackC, extra) 27 | end 28 | 29 | local function testT(t0, t1, eq) 30 | local t2 = t0 31 | if t2:type() == 'torch.CudaTensor' then 32 | t2 = t2:float() 33 | end 34 | testCS(test, 35 | function(server) 36 | server:clients(1, function(client) 37 | assert(torch.all(torch.eq(t0, t1)) == false, "should not match before recv") 38 | client:recv(t1) 39 | local feq = eq or function() 40 | assert(torch.all(torch.eq(t0, t1)), "should match after recv") 41 | end 42 | feq(t0, t1) 43 | end) 44 | end, 45 | function(client, t2) 46 | client:send(t2) 47 | end, t2) 48 | end 49 | 50 | local function testTF(t0, t1) 51 | testCS(test, 52 | function(server) 53 | server:clients(1, function(client) 54 | local ok = pcall(function() client:recv(t1) end) 55 | assert(ok == false, "recv 
should have failed") 56 | end) 57 | end, 58 | function(client, t0) 59 | client:send(t0) 60 | end, t0) 61 | end 62 | 63 | test { 64 | testListenAndConnect = function() 65 | testCS(test, 66 | function(server) 67 | server:clients(1, function(client) 68 | local tag = client:recv() 69 | client:tag(tag) 70 | assert(client:tag() == "hi") 71 | local id = client:recv() 72 | client:id(id) 73 | assert(client:id() == 42) 74 | end) 75 | end, 76 | function(client) 77 | client:send("hi") 78 | client:send(42) 79 | end) 80 | end, 81 | 82 | testSlowStart = function() 83 | local m = ipc.map(1, function() 84 | local ipc = require 'libipc' 85 | return ipc.client('127.0.0.1', 8080) 86 | end) 87 | sys.sleep(2) 88 | local server = ipc.server('127.0.0.1', 8080) 89 | local client = m:join() 90 | server:clients(1, function(client) end) 91 | client:close() 92 | server:close() 93 | end, 94 | 95 | testPingPong = function() 96 | testCS(test, 97 | function(server) 98 | server:clients(1, function(client) 99 | local m0 = client:recv() 100 | assert(m0 == "ping", "expected a ping") 101 | client:send("pong") 102 | end) 103 | end, 104 | function(client) 105 | client:send("ping") 106 | local m1 = client:recv() 107 | assert(m1 == "pong", "expected a pong, saw: "..m1) 108 | end) 109 | end, 110 | 111 | testPingPongAsync = function() 112 | testCS(test, 113 | function(server) 114 | server:clients(1, function(client) 115 | local m0 = client:recv() 116 | assert(m0 == "ping", "expected a ping") 117 | sys.sleep(1) 118 | client:send("pong") 119 | end) 120 | end, 121 | function(client) 122 | client:send("ping") 123 | local x = 0 124 | while 1 do 125 | local m1 = client:recvAsync() 126 | if m1 ~= nil then 127 | assert(m1 == "pong", "expected a pong, saw: "..m1) 128 | assert(x > 0, "expected some delay") 129 | break 130 | end 131 | x = x + 1 132 | end 133 | end) 134 | end, 135 | 136 | testBroadcast = function() 137 | testCSN(10, test, 138 | function(server) 139 | server:clients(10, function(client) 140 | local tag = client:recv() 141 | client:tag(tag) 142 | end) 143 | server:broadcast({ x = "2" }, "2") 144 | server:broadcast({ x = "3" }, "3") 145 | server:broadcast({ x = "1" }, "1") 146 | server:broadcast("bye") 147 | end, 148 | function(client) 149 | local tag = tostring(math.random(2)) 150 | client:send(tag) 151 | local msg1 = client:recv() 152 | assert(msg1.x == tag) 153 | local msg2 = client:recv() 154 | assert(msg2 == "bye") 155 | end) 156 | end, 157 | 158 | testRecvAny = function() 159 | testCSN(10, test, 160 | function(server) 161 | server:clients(10, function(client) end) 162 | for i = 1,10 do 163 | local msg = server:recvAny() 164 | assert(msg == "hi") 165 | end 166 | server:broadcast("bye") 167 | end, 168 | function(client) 169 | client:send("hi") 170 | local msg = client:recv() 171 | assert(msg == "bye") 172 | end) 173 | end, 174 | 175 | testStoragePingPong = function() 176 | testCS(test, 177 | function(server) 178 | server:clients(1, function(client) 179 | local s0 = torch.ByteStorage(4) 180 | client:recv(s0) 181 | assert(s0:string() == "ping", "expected a ping") 182 | client:send(torch.ByteStorage():string("pong")) 183 | end) 184 | end, 185 | function(client) 186 | local s0 = torch.ByteStorage():string("ping") 187 | client:send(s0) 188 | local s1 = torch.ByteStorage(4) 189 | client:recv(s1) 190 | assert(s1:string() == "pong", "expected a pong") 191 | end) 192 | end, 193 | 194 | testCUDAStoragePingPong = function() 195 | if cutorch then 196 | local t0 = torch.randn(3, 3):float() 197 | local t2 = t0:cuda() 198 | local t1 = 
torch.randn(3, 3):cuda() 199 | testCS(test, 200 | function(server) 201 | server:clients(1, function(client) 202 | assert(torch.all(torch.eq(t2, t1)) == false, "should not match before recv") 203 | client:recv(t1:storage()) 204 | assert(torch.all(torch.eq(t2, t1)), "should match after recv") 205 | end) 206 | end, 207 | function(client, t0) 208 | client:send(t0:storage()) 209 | end, t0) 210 | end 211 | end, 212 | 213 | testStorageSizeMismatch = function() 214 | testCS(test, 215 | function(server) 216 | server:clients(1, function(client) 217 | local s0 = torch.ByteStorage(3) 218 | local ok = pcall(function() client:recv(s0) end) 219 | assert(ok == false, "should fail with storage size mismatch") 220 | end) 221 | end, 222 | function(client) 223 | local s0 = torch.ByteStorage():string("ping") 224 | client:send(s0) 225 | end) 226 | end, 227 | 228 | testTensor = function() 229 | testT(torch.randn(5, 6, 7), torch.randn(5, 6, 7)) 230 | end, 231 | 232 | testNoncontiguousTensorEasy = function() 233 | testT(torch.randn(5, 6):sub(2,4, 2,5), torch.randn(5, 6):sub(2,4, 2,5)) 234 | end, 235 | 236 | testNoncontiguousTensorMedium = function() 237 | testT(torch.randn(5, 6, 7):sub(2,4, 2,5, 2,6), torch.randn(5, 6, 7):sub(2,4, 2,5, 2,6)) 238 | end, 239 | 240 | testNoncontiguousTensorHard = function() 241 | testT(torch.randn(5, 6, 7, 8):sub(2,4), torch.randn(5, 6, 7, 8):sub(2,4)) 242 | end, 243 | 244 | testCUDATensor = function() 245 | if cutorch then 246 | testT(torch.randn(3, 4, 5):cuda(), torch.randn(3, 4, 5):cuda()) 247 | end 248 | end, 249 | 250 | testNoncontiguousCUDATensor = function() 251 | if cutorch then 252 | testT(torch.randn(8, 7, 6):cuda():sub(2,7), torch.randn(8, 7, 6):cuda():sub(2,7)) 253 | end 254 | end, 255 | 256 | testTensorZeroSized = function() 257 | local t0 = torch.randn(0) 258 | testCS(test, 259 | function(server) 260 | server:clients(1, function(client) 261 | local t1 = torch.randn(0) 262 | client:recv(t1) 263 | assert(t0:nDimension() == t1:nDimension(), "should match after recv") 264 | end) 265 | end, 266 | function(client, t0) 267 | client:send(t0) 268 | end, t0) 269 | end, 270 | 271 | testTensorNumDimensionsMismatch = function() 272 | testTF(torch.randn(3, 4), torch.randn(3, 5, 6)) 273 | end, 274 | 275 | testTensorDimensionSizeMismatch = function() 276 | testTF(torch.randn(3, 4, 5), torch.randn(3, 5, 5)) 277 | end, 278 | 279 | testSerialize = function() 280 | local server,port = ipc.server() 281 | local t = ipc.map(1, function(port) 282 | local ipc = require 'libipc' 283 | return ipc.client(port) 284 | end, port) 285 | server:clients(1, function(client) end) 286 | local client = t:join() 287 | client:send('hi') 288 | local msg = server:recvAny() 289 | test.mustBeTrue(msg == 'hi', 'expected "hi", saw: '..tostring(msg)) 290 | client:close() 291 | server:close() 292 | end, 293 | 294 | testNetStats = function() 295 | local server,port = ipc.server() 296 | local t = ipc.map(1, function(port) 297 | local ipc = require 'libipc' 298 | return ipc.client(port) 299 | end, port) 300 | server:clients(1, function(client) end) 301 | local client = t:join() 302 | client:send('hi') 303 | server:recvAny() 304 | test.mustBeTrue(type(client:netStats()) == 'table', 'expected a table') 305 | test.mustBeTrue(type(server:netStats()) == 'table', 'expected a table') 306 | client:close() 307 | server:close() 308 | end, 309 | } 310 | -------------------------------------------------------------------------------- /test/test_flock.lua: 
-------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | local ipc = require 'libipc' 3 | 4 | test { 5 | testSimple = function() 6 | local fn = os.tmpname() 7 | local flock = ipc.flock(fn) 8 | assert(ipc.flock(fn, true) == nil) 9 | flock:write("test") 10 | flock:close() 11 | flock = nil 12 | collectgarbage() 13 | local flock = ipc.flock(fn) 14 | assert(flock:read() == "test") 15 | end, 16 | 17 | testDeadProcess = function() 18 | local fn = os.tmpname() 19 | os.remove(fn) 20 | local pid = ipc.fork() 21 | if pid == 0 then 22 | local flock = ipc.flock(fn) 23 | sys.sleep(0.1) 24 | flock:close() 25 | os.exit(0) 26 | else 27 | local i = 0 28 | while ipc.flock(fn, true) == nil do 29 | i = i + 1 30 | if pid then 31 | ipc.waitpid(pid) 32 | pid = nil 33 | end 34 | end 35 | assert(i > 0) 36 | end 37 | end, 38 | } 39 | -------------------------------------------------------------------------------- /test/test_map.lua: -------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | local ipc = require 'libipc' 3 | 4 | test { 5 | testSingle = function() 6 | local x = ipc.map(1, function(y) return y end, 42):join() 7 | test.mustBeTrue(x == 42, 'expected 42, saw '..x) 8 | end, 9 | 10 | testMultiple = function() 11 | local x,y,z = ipc.map(1, function(a, b, c) return a,b,c end, "hi", 42, {k=2}):join() 12 | test.mustBeTrue(x == 'hi', 'expected "hi", saw '..x) 13 | test.mustBeTrue(y == 42, 'expected 42, saw '..y) 14 | test.mustBeTrue(z.k == 2, 'expected 2') 15 | end, 16 | 17 | testSingleWithN = function() 18 | local x = {ipc.map(3, function(y) return y end, 42):join()} 19 | test.mustBeTrue(#x == 3, 'expected 3, saw '..#x) 20 | test.mustBeTrue(x[1] == 42, 'expected 42, saw '..x[1]) 21 | test.mustBeTrue(x[2] == 42, 'expected 42, saw '..x[2]) 22 | test.mustBeTrue(x[3] == 42, 'expected 42, saw '..x[3]) 23 | end, 24 | 25 | testMultipleWithN = function() 26 | local x1,y1,z1,x2,y2,z2 = ipc.map(2, function(a, b, c) return a,b,c end, "hi", 42, {k=2}):join() 27 | test.mustBeTrue(x1 == 'hi', 'expected "hi", saw '..x1) 28 | test.mustBeTrue(y1 == 42, 'expected 42, saw '..y1) 29 | test.mustBeTrue(z1.k == 2, 'expected 2') 30 | test.mustBeTrue(x2 == 'hi', 'expected "hi", saw '..x2) 31 | test.mustBeTrue(y2 == 42, 'expected 42, saw '..y2) 32 | test.mustBeTrue(z2.k == 2, 'expected 2') 33 | end, 34 | 35 | testNil = function() 36 | local x, y = ipc.map(1, function(x, y) return nil, y end, nil, 42):join() 37 | test.mustBeTrue(x == nil, 'expected nil, saw '..(x or 'nil')) 38 | test.mustBeTrue(y == 42, 'expected 42, saw '..y) 39 | end, 40 | 41 | testErrors = function() 42 | local ok, msg = pcall(function() ipc.map(1, function() error('boom') end):join() end) 43 | test.mustBeTrue(ok == false, 'expected the join to fail') 44 | test.mustBeTrue(type(msg) == 'string', 'expected the error message to be a string') 45 | end, 46 | 47 | testMoreErrors = function() 48 | local ok, msg = pcall(function() ipc.map(2, function(idx) 49 | if idx == 1 then 50 | error('boom') 51 | else 52 | return 42 53 | end 54 | end):join() end) 55 | test.mustBeTrue(ok == false, 'expected the join to fail') 56 | test.mustBeTrue(type(msg) == 'string', 'expected the error message to be a string') 57 | end, 58 | 59 | testUpvalueError = function() 60 | Global4523463456345 = 32 -- set a global 61 | local function envIsSet() 62 | return Global4523463456345 == nil and 56 63 | end 64 | local name, value = debug.getupvalue(envIsSet, 1) 65 | local res = 
ipc.map(1, envIsSet):join() 66 | test.mustBeTrue(res == 56) 67 | 68 | local function envIsSet2() 69 | Global4523463456345 = 46 70 | local function envIsSet() 71 | return Global4523463456345 72 | end 73 | return envIsSet() 74 | end 75 | local res = ipc.map(1, envIsSet2):join() 76 | test.mustBeTrue(res == 46) 77 | 78 | Global4523463456345 = nil 79 | 80 | local i,j=1,2 81 | local function closure() 82 | return i+j+1 83 | end 84 | local ok, msg = pcall(function() ipc.map(1, closure):join() end) 85 | test.mustBeTrue(ok == false, 'expected the map to fail') 86 | test.mustBeTrue(type(msg) == 'string', 'expected the error message to be a string') 87 | end, 88 | 89 | testCheckErrors = function() 90 | local m = ipc.map(2, function(idx) 91 | local sys = require 'sys' 92 | if idx == 1 then 93 | error('boom') 94 | else 95 | sys.sleep(2) 96 | return 42 97 | end 98 | end) 99 | sys.sleep(0.2) 100 | local ok,msg = pcall(function() return m:checkErrors() end) 101 | test.mustBeTrue(ok == false, 'expected the checkErrors to fail') 102 | test.mustBeTrue(type(msg) == 'string', 'expected the error message to be a string') 103 | local ok,msg = pcall(function() return m:join() end) 104 | test.mustBeTrue(ok == true, 'expected the join to pass') 105 | test.mustBeTrue(type(msg) == 'number' and msg == 42, 'expected the result to be a number 42') 106 | end, 107 | 108 | testLuaCheckStack = function() 109 | -- OSX has low file ulimits that cause the require system to die 110 | local n = sys.uname() == 'macos' and 50 or 1000 111 | local ret = { ipc.map(n, function() return 1 end):join() } 112 | test.mustBeTrue(#ret == n, 'expected '..n..' elements, saw '..#ret) 113 | end, 114 | 115 | testLastArg = function() 116 | local ret = { ipc.map(4, function(s, id) return id end, "hi"):join() } 117 | test.mustBeTrue(#ret == 4, 'expected 4 elements, saw '..#ret) 118 | test.mustBeTrue(ret[1] == 1, 'expected 1 at 1') 119 | test.mustBeTrue(ret[2] == 2, 'expected 2 at 2') 120 | test.mustBeTrue(ret[3] == 3, 'expected 3 at 3') 121 | test.mustBeTrue(ret[4] == 4, 'expected 4 at 4') 122 | end, 123 | 124 | testBigArguments = function() 125 | local n = 1000 126 | local data = { } 127 | for i = 1,n do 128 | data[i] = i 129 | end 130 | ipc.map(3, function(data, mapid) 131 | return data[mapid] 132 | end, data):join() 133 | end, 134 | 135 | testBigReturns = function() 136 | ipc.map(3, function() 137 | local n = 1000 138 | local data = { } 139 | for i = 1,n do 140 | data[i] = i 141 | end 142 | return data 143 | end):join() 144 | end, 145 | 146 | testFunctionMapId = function() 147 | Global4523463456345 = 1 148 | Globalsadg234523 = 2 149 | 150 | local function run(opt, mapid) 151 | assert(torch.type(opt) == 'table') 152 | assert(torch.type(mapid) == 'number') 153 | -- do something with globals 154 | return (Global4523463456345 or 3) + (Globalsadg234523 or 5) 155 | end 156 | 157 | local largetable = {} 158 | for i=1,1000 do 159 | largetable["asdfasdfa"..i] = i 160 | end 161 | local mutex = ipc.mutex() 162 | ipc.map(3, run, { 163 | asts = largetable, 164 | verbose = 3232, 165 | mode = 'string', 166 | mutex = mutex, 167 | barrier = true, 168 | partitions = torch.Tensor(10), 169 | reportEvery = 200, 170 | }):join() 171 | 172 | Global4523463456345 = nil 173 | Globalsadg234523 = nil 174 | end, 175 | } 176 | -------------------------------------------------------------------------------- /test/test_marshal.lua: -------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | local ipc = require 'libipc' 
3 | 4 | test { 5 |    testWriteRead = function() 6 |       local m = ipc.marshal(23) 7 |       test.mustBeTrue(torch.type(m) == 'userdata', "expecting userdata, got "..torch.type(m)) 8 |       local res = m:read() 9 |       test.mustBeTrue(res == 23, 'expected 23, saw '..res) 10 |       local res2 = m:read() 11 |       test.mustBeTrue(res2 == 23, 'expected 23, saw '..res2) 12 |       local res3 = m:read() 13 |       test.mustBeTrue(res3 == 23, 'expected 23, saw '..res3) 14 |    end, 15 | 16 |    testMap = function() 17 |       local m = ipc.marshal(23) 18 |       local m, val, m2, val2 = ipc.map(2, function(m) 19 |          local val = m:read() 20 |          assert(val == 23, "Expecting 23") 21 |          return m, val 22 |       end, m):join() 23 | 24 |       local res = m:read() 25 |       test.mustBeTrue(res == 23, 'expected 23, saw '..res) 26 |       local res = m:read() 27 |       test.mustBeTrue(res == 23, 'expected 23, saw '..res) 28 |    end, 29 | 30 |    testWorkqueue = function() 31 |       local m = ipc.marshal(43) 32 |       local q = ipc.workqueue("marshal") 33 |       local map = ipc.map(2, function() 34 |          local ipc = require 'libipc' 35 |          local q = ipc.workqueue("marshal") 36 |          while true do 37 |             local m = q:read() 38 |             if m == nil then 39 |                break 40 |             end 41 |             local res = m:read() 42 |             assert(res == 43, "Expecting 43") 43 |             q:write(m) 44 |          end 45 |       end) 46 | 47 |       for i=1,5 do 48 |          q:write(m) 49 |       end 50 |       for i=1,2 do 51 |          q:write(nil) 52 |       end 53 | 54 |       map:join() 55 | 56 |       for i=1,5 do 57 |          local m2 = q:read() 58 |          local res2 = m2:read() 59 |          test.mustBeTrue(res2 == 43, 'expected 43, saw '..res2) 60 |       end 61 |    end, 62 | 63 |    testTable = function() 64 |       local obj = {1,2,3,v=4,g=5,t=function() return 6 end} 65 |       local m = ipc.marshal(obj) 66 | 67 |       local obj2 = m:read() 68 |       test.mustBeTrue(torch.type(obj2) == 'table', 'expected table, saw '..torch.type(obj2)) 69 |       test.mustBeTrue(obj2[1] == 1 and obj2[2] == 2 and obj2[3] == 3 and obj2.v == 4 and obj2.g == 5 and obj2.t() == 6, 'error in obj') 70 | 71 |       local obj2 = m:read() 72 |       test.mustBeTrue(torch.type(obj2) == 'table', 'expected table, saw '..torch.type(obj2)) 73 |       test.mustBeTrue(obj2[1] == 1 and obj2[2] == 2 and obj2[3] == 3 and obj2.v == 4 and obj2.g == 5 and obj2.t() == 6, 'error in obj') 74 | 75 |       local obj2 = m:read() 76 |       test.mustBeTrue(torch.type(obj2) == 'table', 'expected table, saw '..torch.type(obj2)) 77 |       test.mustBeTrue(obj2[1] == 1 and obj2[2] == 2 and obj2[3] == 3 and obj2.v == 4 and obj2.g == 5 and obj2.t() == 6, 'error in obj') 78 |    end, 79 | 80 |    testClosure = function() 81 |       local upval = 6 82 |       local obj = {1,2,3,v=4,g=5,t=function() return upval end} 83 |       local success, m = pcall(function() return ipc.marshal(obj) end) 84 |       test.mustBeTrue(not success, "Expecting upval marshalling error") 85 | 86 |       local m = ipc.marshal(obj, true) 87 | 88 |       local obj2 = m:read() 89 |       test.mustBeTrue(torch.type(obj2) == 'table', 'expected table, saw '..torch.type(obj2)) 90 |       test.mustBeTrue(obj2[1] == 1 and obj2[2] == 2 and obj2[3] == 3 and obj2.v == 4 and obj2.g == 5 and obj2.t() == 6, 'error in obj') 91 | 92 |       local obj2 = m:read() 93 |       test.mustBeTrue(torch.type(obj2) == 'table', 'expected table, saw '..torch.type(obj2)) 94 |       test.mustBeTrue(obj2[1] == 1 and obj2[2] == 2 and obj2[3] == 3 and obj2.v == 4 and obj2.g == 5 and obj2.t() == 6, 'error in obj') 95 | 96 |       local obj2 = m:read() 97 |       test.mustBeTrue(torch.type(obj2) == 'table', 'expected table, saw '..torch.type(obj2)) 98 |       test.mustBeTrue(obj2[1] == 1 and obj2[2] == 2 and obj2[3] == 3 and obj2.v == 4 and obj2.g == 5 and obj2.t() == 6, 'error in obj') 99 |    end, 100 | 101 |    testSize = function() 102 |       local m = 
ipc.marshal(23, 10, 10) 103 |       test.mustBeTrue(torch.type(m) == 'userdata', "expecting userdata, got "..torch.type(m)) 104 |       local res = m:read() 105 |       test.mustBeTrue(res == 23, 'expected 23, saw '..res) 106 |       local res2 = m:read() 107 |       test.mustBeTrue(res2 == 23, 'expected 23, saw '..res2) 108 |       local res3 = m:read() 109 |       test.mustBeTrue(res3 == 23, 'expected 23, saw '..res3) 110 |    end, 111 | } 112 | 113 | 114 | -------------------------------------------------------------------------------- /test/test_mutex.lua: -------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | local ipc = require 'libipc' 3 | 4 | test { 5 |    testLockingAndBarrier = function() 6 |       local beforeWriteMutex = ipc.mutex() 7 |       local afterWriteMutex = ipc.mutex() 8 |       local shared = torch.FloatTensor(10000) 9 |       shared:fill(0) 10 | 11 |       local m = ipc.map(3, function(beforeWriteMutex, afterWriteMutex, shared, mapid) 12 |          local ipc = require 'libipc' 13 |          local sys = require 'sys' 14 | 15 |          assert(shared[1] == 0) 16 |          beforeWriteMutex:barrier(4) 17 | 18 |          afterWriteMutex:barrier(4) 19 |          assert(shared[1] ~= 0) 20 | 21 |          afterWriteMutex:lock() 22 |          for i = 1,shared:size(1) do 23 |             shared[i] = mapid 24 |          end 25 |          afterWriteMutex:unlock() 26 |       end, beforeWriteMutex, afterWriteMutex, shared) 27 | 28 |       -- `beforeWriteMutex:barrier(4)` guarantees `assert(shared[1] == 0)` to succeed: 29 |       -- the assignment `shared[1] = 1000` won't happen until all 3 threads finish the above assert. 30 |       beforeWriteMutex:barrier(4) 31 | 32 |       shared[1] = 1000 33 | 34 |       -- afterWriteMutex:barrier(4) guarantees `assert(shared[1] ~= 0)` to succeed: 35 |       -- the assignment `shared[1] = 1000` is guaranteed to happen before the above asserts. 36 |       afterWriteMutex:barrier(4) 37 | 38 |       m:join() 39 |       local first = shared[1] 40 |       for i = 2,shared:size(1) do 41 |          assert(shared[i] == first) 42 |       end 43 |    end, 44 | } 45 | -------------------------------------------------------------------------------- /test/test_sharedtable.lua: -------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | local ipc = require 'libipc' 3 | 4 | test { 5 |    testCount = function() 6 |       local t = {2,3,4} 7 |       local t2 = ipc.sharedtable(t) 8 |       assert(#t == #t2) 9 |    end, 10 | 11 |    testMove = function() 12 |       local t = {2,3,4} 13 |       local t2 = ipc.sharedtable(t, true) 14 |       assert(#t == 0) 15 |       assert(#t2 == 3) 16 |       for i=1,3 do 17 |          assert(i+1 == t2[i]) 18 |       end 19 |    end, 20 | 21 |    testEmpty = function() 22 |       local t = ipc.sharedtable() 23 |       assert(#t == 0) 24 |       local t = ipc.sharedtable(nil) 25 |       assert(#t == 0) 26 |    end, 27 | 28 |    testRead = function() 29 |       local t = {2,3,4} 30 |       local t2 = ipc.sharedtable(t) 31 | 32 |       for i = 1,#t do 33 |          assert(t[i] == t2[i]) 34 |       end 35 |    end, 36 | 37 |    testPairs = function() 38 |       local t = {2,3,4} 39 |       local t2 = ipc.sharedtable(t) 40 | 41 |       if _VERSION == 'Lua 5.1' then 42 |          -- pairs not supported for userdata in 5.1 43 |          return 44 |       end 45 |       for k, v in pairs(t2) do 46 |          assert(t[k] == v) 47 |       end 48 |    end, 49 | 50 |    testWrite = function() 51 |       local t = {2,3,4} 52 |       local t2 = ipc.sharedtable(t) 53 | 54 |       for i = 1,#t do 55 |          t2[i] = t2[i]+1 56 |          assert(t[i]+1 == t2[i]) 57 |       end 58 |    end, 59 | 60 |    testWriteOnOther = function() 61 |       local t = ipc.sharedtable({0}) 62 |       ipc.map(1, function(tab) 63 |          for i=1,100 do 64 |             tab[1] = tab[1]+1 65 |          end 66 |       end, t):join() 67 |       assert(t[1] == 100) 68 | 69 |       local t = ipc.sharedtable() 70 | 
ipc.map(100, function(tab, i) 71 | tab[i] = i 72 | end, t):join() 73 | assert(#t == 100) 74 | for i=1,100 do 75 | assert(t[i] == i) 76 | end 77 | end, 78 | 79 | testMultipleWrites = function() 80 | local t = ipc.sharedtable() 81 | local m = ipc.mutex() 82 | ipc.map(10, function(tab, mutex) 83 | for i=1,100 do 84 | mutex:lock() 85 | tab[i] = (tab[i] or 0) + 1 86 | mutex:unlock() 87 | end 88 | end, t, m):join() 89 | assert(#t == 100) 90 | for i=1,100 do 91 | assert(t[i] == 10) 92 | end 93 | end, 94 | 95 | testExternalType = function() 96 | local t = ipc.sharedtable() 97 | t.mutex = ipc.mutex() 98 | ipc.map(10, function(tab) 99 | for i=1,100 do 100 | tab.mutex:lock() 101 | tab[i] = (tab[i] or torch.LongTensor({0})) + 1 102 | tab.mutex:unlock() 103 | end 104 | end, t):join() 105 | for i=1,100 do 106 | assert(t[i][1] == 10) 107 | end 108 | end, 109 | 110 | testTableIncrement = function() 111 | local t = ipc.sharedtable() 112 | local t2 = ipc.sharedtable() 113 | local m = ipc.mutex() 114 | ipc.map(10, function(tab, tab2, mutex) 115 | local ipc = require 'libipc' 116 | for i=1,100 do 117 | mutex:lock() 118 | tab[i] = (tab[i] or ipc.sharedtable({0})) 119 | tab[i][1] = tab[i][1]+1 120 | tab2[i] = (tab2[i] or ipc.sharedtable({0})) 121 | tab2[i][1] = tab2[i][1]+1 122 | mutex:unlock() 123 | end 124 | end, t, t2, m):join() 125 | assert(#t == 100) 126 | assert(#t2 == 100) 127 | for i=1,100 do 128 | assert(t[i][1] == 10) 129 | assert(t2[i][1] == 10) 130 | end 131 | end, 132 | 133 | testTableExpand = function() 134 | local t = ipc.sharedtable() 135 | local t2 = ipc.sharedtable() 136 | local m = ipc.mutex() 137 | ipc.map(10, function(tab, tab2, mutex) 138 | local ipc = require 'libipc' 139 | for i=1,100 do 140 | mutex:lock() 141 | tab[i] = (tab[i] or ipc.sharedtable()) 142 | local l = #tab[i]+1 143 | tab[i][l] = l 144 | tab2[i] = (tab2[i] or ipc.sharedtable()) 145 | local l = #tab2[i]+1 146 | tab2[i][l] = l 147 | mutex:unlock() 148 | end 149 | end, t, t2, m):join() 150 | assert(#t == 100) 151 | assert(#t2 == 100) 152 | for i=1,100 do 153 | assert(#t[i] == 10) 154 | assert(#t2[i] == 10) 155 | for j=1,10 do 156 | assert(t[i][j] == j) 157 | assert(t2[i][j] == j) 158 | end 159 | end 160 | end, 161 | 162 | testSize = function() 163 | local t = ipc.sharedtable() 164 | local size1 = ipc.sharedtable_size(t) 165 | local t2 = {} 166 | for i=1,1000 do 167 | t2[i] = i 168 | end 169 | for i=1,1000 do 170 | t[i] = t2 171 | end 172 | local size2 = ipc.sharedtable_size(t) 173 | assert(size2 > size1) 174 | end, 175 | } 176 | -------------------------------------------------------------------------------- /test/test_spawn.lua: -------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | local ipc = require 'libipc' 3 | 4 | test { 5 | testStdoutAll = function() 6 | local p = ipc.spawn({ 7 | file = 'echo', 8 | args = { 9 | 'what', 10 | 'up', 11 | 'dawg', 12 | }, 13 | }) 14 | assert(type(p:stdoutFileId()) == 'number') 15 | assert(p:stdout('*all') == 'what up dawg\n') 16 | assert(p:stdout('*all') == nil) 17 | assert(p:wait() == 0) 18 | end, 19 | 20 | testStdoutLine = function() 21 | local p = ipc.spawn({ 22 | file = 'echo', 23 | args = { 24 | 'what', 25 | 'up', 26 | 'dawg', 27 | }, 28 | }) 29 | assert(p:stdout('*line') == 'what up dawg') 30 | assert(p:stdout('*line') == nil) 31 | assert(p:wait() == 0) 32 | end, 33 | 34 | testStdoutNumber = function() 35 | local p = ipc.spawn({ 36 | file = 'echo', 37 | args = { 38 | 'what', 39 | 'up', 40 | 'dawg', 41 | }, 42 | }) 43 | 
assert(p:stdout(256) == 'what up dawg\n') 44 | assert(p:stdout(256) == nil) 45 | assert(p:wait() == 0) 46 | end, 47 | 48 | testStdin = function() 49 | local p = ipc.spawn({ 50 | file = 'tee', 51 | }) 52 | p:stdin('a\n') 53 | p:stdin('b\n') 54 | p:stdin('c\n') 55 | p:stdin('d') 56 | p:stdin() 57 | assert(p:stdout('*line') == 'a') 58 | assert(p:stdout('*line') == 'b') 59 | assert(p:stdout('*line') == 'c') 60 | assert(p:stdout('*line') == 'd') 61 | assert(p:stdout() == nil) 62 | assert(p:wait() == 0) 63 | end, 64 | 65 | testEnv = function() 66 | local p = ipc.spawn({ 67 | file = 'printenv', 68 | args = { 69 | 'SOME_VAR', 70 | }, 71 | env = { 72 | 'SOME_VAR=42', 73 | }, 74 | }) 75 | assert(p:stdout('*all') == '42\n') 76 | assert(p:wait() == 0) 77 | end, 78 | 79 | testSignalKill = function() 80 | local p = ipc.spawn({ 81 | file = 'sleep', 82 | args = { 83 | 100, 84 | } 85 | }) 86 | assert(p:wait("KILL") == 0) 87 | end, 88 | 89 | testSignalTerm = function() 90 | local p = ipc.spawn({ 91 | file = 'sleep', 92 | args = { 93 | 100, 94 | } 95 | }) 96 | assert(p:wait("TERM") == 0) 97 | end, 98 | 99 | testExitCode = function() 100 | local p = ipc.spawn({ 101 | file = 'bash', 102 | args = { 103 | '-c', 104 | 'ls /this/will/never/exist/so/crazy 2> /dev/null', 105 | } 106 | }) 107 | assert(p:wait() ~= 0) 108 | end, 109 | 110 | testRunning = function() 111 | local p = ipc.spawn({ 112 | file = 'sleep', 113 | args = { 114 | 1, 115 | }, 116 | }) 117 | local i = 0 118 | while p:running() do 119 | i = i + 1 120 | sys.sleep(0.2) 121 | end 122 | assert(i > 3 and i < 7) -- fuzzy 123 | assert(p:wait() == 0) 124 | end, 125 | 126 | testCollectGarbage = function() 127 | ipc.spawn({ 128 | file = 'echo', 129 | args = { 130 | 'good', 131 | }, 132 | }):wait() 133 | collectgarbage() 134 | ipc.spawn({ 135 | file = 'echo', 136 | args = { 137 | 'bad', 138 | }, 139 | }) 140 | collectgarbage() 141 | end, 142 | 143 | testErrors = function() 144 | local ok = pcall(function() 145 | local p = ipc.spawn({ 146 | file = 'sleep', 147 | args = { 148 | 100, 149 | }, 150 | }) 151 | error("die before waiting") 152 | p:wait() 153 | end) 154 | assert(ok == false) 155 | collectgarbage() 156 | end, 157 | 158 | testMisuse = function() 159 | local ok, msg 160 | 161 | ok, msg = pcall(function() 162 | local p = ipc.spawn({ 163 | file = 'echo', 164 | args = 'what', -- should be a table 165 | }) 166 | end) 167 | assert(ok == false) 168 | assert(string.find(msg, "expected a table")) 169 | 170 | ok, msg = pcall(function() 171 | local p = ipc.spawn({ 172 | file = function() return 'echo' end, -- should be a string 173 | }) 174 | end) 175 | assert(ok == false) 176 | assert(string.find(msg, "expected a string")) 177 | end, 178 | } 179 | -------------------------------------------------------------------------------- /test/test_workqueue.lua: -------------------------------------------------------------------------------- 1 | local test = require 'regress' 2 | local ipc = require 'libipc' 3 | 4 | local name = 'test' 5 | local q = ipc.workqueue(name) 6 | 7 | local echo = ipc.map(1, function(name) 8 | local ipc = require 'libipc' 9 | local lib = {} 10 | local class = torch.class("lib.class", lib) 11 | local q = ipc.workqueue(name) 12 | while true do 13 | local msg = q:read() 14 | if torch.type(msg) == 'table' and msg.closure then 15 | q:writeup(msg[1]) 16 | else 17 | if msg == nil then 18 | break 19 | elseif torch.type(msg) == 'lib.class' then 20 | msg.reply = true 21 | elseif torch.isTensor(msg) then 22 | if msg:size(1) == 1 then 23 | msg:fill(13) 24 | end 
25 | elseif type(msg) == 'table' and torch.typename(msg.f) then 26 | msg.f:fill(42) 27 | end 28 | q:write(msg) 29 | end 30 | end 31 | q:close() 32 | end, "test") 33 | 34 | local function tableEq(t0, t1) 35 | for k,v in pairs(t0) do 36 | if type(v) == 'table' then 37 | local e,msg = tableEq(v, t1[k]) 38 | if not e then 39 | return e,msg 40 | end 41 | elseif t1[k] ~= v then 42 | return false, "for key '"..k.."' "..v.." ~= "..t1[k] 43 | end 44 | end 45 | for k,_ in pairs(t1) do 46 | if t0[k] == nil then 47 | return false, "extra right hand side key '"..k.."'" 48 | end 49 | end 50 | return true, "" 51 | end 52 | 53 | test { 54 | testBooleans = function() 55 | local n = true 56 | q:write(n) 57 | local n1 = q:read() 58 | test.mustBeTrue(n == n1, 'Boolean serialization failed '..tostring(n)..' ~= '..tostring(n1)) 59 | local n = false 60 | q:write(n) 61 | local n1 = q:read() 62 | test.mustBeTrue(n == n1, 'Boolean serialization failed '..tostring(n)..' ~= '..tostring(n1)) 63 | end, 64 | 65 | testNumbers = function() 66 | local n = 42.13 67 | q:write(n) 68 | local n1 = q:read() 69 | test.mustBeTrue(n == n1, 'Number serialization failed '..n..' ~= '..n1) 70 | end, 71 | 72 | testStrings = function() 73 | local n = "hey man what's up with that code?" 74 | q:write(n) 75 | local n1 = q:read() 76 | test.mustBeTrue(n == n1, 'String serialization failed '..n..' ~= '..n1) 77 | end, 78 | 79 | testArrays = function() 80 | local n = { 1, 2, 73.86, 'hello', true, false, 'good bye', 42.13 } 81 | q:write(n) 82 | local n1 = q:read() 83 | local e, msg = tableEq(n, n1) 84 | test.mustBeTrue(e, 'Table serialization failed '..msg) 85 | end, 86 | 87 | testTables = function() 88 | local n = { 89 | k0 = true, 90 | k1 = 23.45, 91 | k2 = "hey", 92 | k3 = { 1, 2, "yo" }, 93 | k4 = { a = 1 }, 94 | k5 = { a = { b = 2 } }, 95 | } 96 | q:write(n) 97 | local n1 = q:read() 98 | local e, msg = tableEq(n, n1) 99 | test.mustBeTrue(e, 'Table serialization failed '..msg) 100 | end, 101 | 102 | testMetaTables = function() 103 | -- local module class 104 | local lib = {} 105 | local class = torch.class('lib.class', lib) 106 | local cmd = lib.class() 107 | q:write(cmd) 108 | local cmd1 = q:read() 109 | test.mustBeTrue(torch.typename(cmd) == torch.typename(cmd1), "local Metatable serialization failed") 110 | test.mustBeTrue(cmd1.reply, "local Metatable table serialize fail") 111 | -- global module class 112 | local cmd = torch.CmdLine() 113 | cmd.id = 1234 114 | cmd.cmd = torch.CmdLine() 115 | q:write(cmd) 116 | local cmd1 = q:read() 117 | test.mustBeTrue(torch.typename(cmd) == torch.typename(cmd1), "global Metatable serialization failed") 118 | test.mustBeTrue(torch.typename(cmd.cmd) == torch.typename(cmd1.cmd), "global Metatable nested serialization failed") 119 | test.mustBeTrue(cmd.id == 1234, "global Metatable table serialize fail") 120 | local e, msg = tableEq(cmd, cmd1) 121 | test.mustBeTrue(e, 'Table serialization failed '..msg) 122 | end, 123 | 124 | testFunctions = function() 125 | local f = function(a, b, c) return math.sqrt((a * a) + (b * b) + (c * c)) end 126 | q:write(f) 127 | local f1 = q:read() 128 | local n = f(1, 2, 3) 129 | local n1 = f1(1, 2, 3) 130 | test.mustBeTrue(n == n1, 'Function serialization failed '..n..' 
~= '..n1) 131 | end, 132 | 133 | testClosures = function() 134 | if f3rtwertwert534 ~= nil then 135 | return 136 | end 137 | -- global function with an unlikely name 138 | f3rtwertwert534 = function() return 534 end 139 | local bias1, bias2, bias3 = 0, 1, 2 140 | local f0 = function() return f3rtwertwert534() + bias3 end 141 | local f = function(a, b, c) return f0() + bias2 + bias1 + math.sqrt((a * a) + (b * b) + (c * c)) end 142 | q:writeup({f,closure=true}) -- writeup 143 | local f1 = q:read() 144 | local n = f(1, 2, 3) 145 | local n1 = f1(1, 2, 3) 146 | test.mustBeTrue(n == n1, 'Function serialization failed '..n..' ~= '..n1) 147 | f3rtwertwert534 = nil 148 | end, 149 | 150 | testTensors = function() 151 | local f = torch.randn(10) 152 | q:write(f) 153 | local f1 = q:read() 154 | for i = 1,10 do 155 | test.mustBeTrue(f[i] == f1[i], 'Tensor serialization failed '..f[i]..' ~= '..f1[i]) 156 | end 157 | end, 158 | 159 | testTensorsTwoWay = function() 160 | local f = torch.FloatTensor(1) 161 | f:fill(0) 162 | q:write(f) 163 | local f1 = q:read() 164 | test.mustBeTrue(f1[1] == 13, 'Tensor serialization failed '..f1[1]..' ~= 13') 165 | end, 166 | 167 | testTensorsInTable = function() 168 | local t = { 169 | f = torch.FloatTensor(1) 170 | } 171 | t.f:fill(0) 172 | q:write(t) 173 | local t1 = q:read() 174 | test.mustBeTrue(t1.f[1] == 42, 'Tensor serialization failed '..t1.f[1]..' ~= 42') 175 | end, 176 | 177 | testDrain = function() 178 | for i = 1,13 do 179 | q:write(i) 180 | end 181 | q:drain() 182 | for i = 1,13 do 183 | local r = q:read(true) 184 | test.mustBeTrue(r == i, 'Expected '..r..' to be '..i..' after drain') 185 | end 186 | local f = q:read(true) 187 | test.mustBeTrue(f == nil, 'Expected to read nil after draining') 188 | end, 189 | 190 | testMultiWrite = function() 191 | local a = { } 192 | for i = 1,13 do 193 | a[i] = i 194 | end 195 | q:write((unpack or table.unpack)(a)) 196 | for i = 1,13 do 197 | test.mustBeTrue(i == q:read()) 198 | end 199 | end, 200 | 201 | testDualStalls = function() 202 | local sq = ipc.workqueue('ds', 16) 203 | local m = ipc.map(1, function() 204 | local ipc = require 'libipc' 205 | local sq = ipc.workqueue('ds') 206 | local count = 0 207 | while true do 208 | local n = sq:read() 209 | if n == nil then 210 | break 211 | end 212 | sq:write(n) 213 | count = count + n 214 | end 215 | return count 216 | end) 217 | local expected = 0 218 | for i = 1,100 do 219 | sq:write(i) 220 | expected = expected + i 221 | end 222 | sq:write(nil) 223 | local final = m:join() 224 | assert(final == expected) 225 | end, 226 | 227 | testAnon = function() 228 | local q = ipc.workqueue() 229 | local m = ipc.mutex() 230 | local w = ipc.map(1, function(q, m) 231 | local data = q:read() 232 | q:write(data+1) 233 | q:read() 234 | m:barrier(2) 235 | q:write(true) 236 | end, q, m) 237 | 238 | q:write(0) 239 | assert(q:read() == 1) 240 | q:write(false) 241 | q:close() 242 | q = nil 243 | m:barrier(2) 244 | w:join() 245 | end, 246 | 247 | testCheckCreator = function() 248 | local q, creator = ipc.workqueue('test name') 249 | assert(creator == 1) 250 | local w = ipc.map(1, function() 251 | local ipc = require 'libipc' 252 | local q, creator = ipc.workqueue('test name') 253 | q:write(creator) 254 | end) 255 | assert(q:read() == 0) 256 | w:join() 257 | end, 258 | } 259 | 260 | q:write(nil) 261 | echo:join() 262 | --------------------------------------------------------------------------------
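
The test files above double as usage documentation for the library. As a quick orientation, the following standalone sketch condenses the producer/consumer pattern exercised by test/test_workqueue.lua; it is not part of the repository and assumes the ipc rock is installed and run under a Torch Lua interpreter. An owner thread creates a named workqueue, a mapped worker thread opens the same queue by name and echoes messages back, and writing nil tells the worker to stop.

local ipc = require 'libipc'

-- Owner thread creates the named queue first.
local q = ipc.workqueue('demo')

-- One worker thread opens the same queue by name and echoes every
-- message back to the owner until it reads nil.
local worker = ipc.map(1, function(name)
   local ipc = require 'libipc'
   local q = ipc.workqueue(name)
   while true do
      local msg = q:read()
      if msg == nil then break end
      q:write(msg)
   end
   q:close()
end, 'demo')

q:write(42)            -- owner -> worker
assert(q:read() == 42) -- worker -> owner
q:write(nil)           -- signal the worker to stop
worker:join()
q:close()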