├── doc ├── images │ ├── csv1.png │ └── csv2.png ├── install.md └── tutorial.md ├── README.md ├── CMakeLists.txt ├── init.lua ├── rocks └── hypero-scm-1.rockspec ├── utils.lua ├── schema.sql ├── scripts └── export.lua ├── LICENSE ├── Sampler.lua ├── Connect.lua ├── Postgres.lua ├── Experiment.lua ├── examples └── neuralnetwork.lua ├── test.lua └── Battery.lua /doc/images/csv1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Element-Research/hypero/HEAD/doc/images/csv1.png -------------------------------------------------------------------------------- /doc/images/csv2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Element-Research/hypero/HEAD/doc/images/csv2.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hypero 2 | 3 | Simple distributed hyper-optimization library for torch7 : 4 | 5 | * [Installation](doc/install.md) 6 | * [Tutorial](doc/tutorial.md) 7 | * [Example](examples/neuralnetwork.lua) 8 | 9 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) 3 | CMAKE_POLICY(VERSION 2.6) 4 | IF(LUAROCKS_PREFIX) 5 | MESSAGE(STATUS "Installing Torch through Luarocks") 6 | STRING(REGEX REPLACE "(.*)lib/luarocks/rocks.*" "\\1" CMAKE_INSTALL_PREFIX "${LUAROCKS_PREFIX}") 7 | MESSAGE(STATUS "Prefix inferred from Luarocks: ${CMAKE_INSTALL_PREFIX}") 8 | ENDIF() 9 | FIND_PACKAGE(Torch REQUIRED) 10 | 11 | SET(src) 12 | FILE(GLOB luasrc *.lua) 13 | 14 | ADD_TORCH_PACKAGE(hypero "${src}" "${luasrc}") 15 | -------------------------------------------------------------------------------- /init.lua: 
-------------------------------------------------------------------------------- 1 | require 'string' 2 | _ = require 'moses' 3 | require 'xlua' 4 | require 'fs' 5 | require 'os' 6 | require 'sys' 7 | require 'lfs' 8 | require 'torchx' 9 | require 'json' 10 | 11 | hypero = {} 12 | 13 | torch.include('hypero', 'utils.lua') 14 | torch.include('hypero', 'Postgres.lua') 15 | torch.include('hypero', 'Connect.lua') 16 | torch.include('hypero', 'Battery.lua') 17 | torch.include('hypero', 'Experiment.lua') 18 | torch.include('hypero', 'Sampler.lua') 19 | torch.include('hypero', 'test.lua') 20 | 21 | return hypero 22 | -------------------------------------------------------------------------------- /rocks/hypero-scm-1.rockspec: -------------------------------------------------------------------------------- 1 | package = "hypero" 2 | version = "scm-1" 3 | 4 | source = { 5 | url = "git://github.com/Element-Research/hypero", 6 | tag = "master" 7 | } 8 | 9 | description = { 10 | summary = "A hyper-optimization library for torch7", 11 | detailed = [[ 12 | A simple asynchronous distributed hyper-parameter optimization library for torch7. 13 | It performs random-search of a parameter distribution. 14 | All params, meta-data and results are stored in a PostgreSQL database. 15 | These can be queried and updated using scripts or client APIs. 16 | ]], 17 | homepage = "https://github.com/Element-Research/hypero/blob/master/README.md" 18 | } 19 | 20 | dependencies = { 21 | "torch >= 7.0", 22 | "moses >= 1.3.1", 23 | "fs >= 0.3", 24 | "xlua >= 1.0", 25 | "luafilesystem >= 1.6.2", 26 | "sys >= 1.1", 27 | "torchx >= 1.0", 28 | "luajson", 29 | "luasql-postgres" 30 | } 31 | 32 | build = { 33 | type = "command", 34 | build_command = [[ 35 | cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." 
-- http://lua-users.org/wiki/SplitJoin
--- Splits the string into its non-empty fields.
-- @param sep set of delimiter characters (default ":"). It is spliced verbatim
--   into the character class "[^sep]", so pattern-magic characters ("%", "]",
--   "^", ...) are interpreted by the pattern engine rather than taken literally.
-- @return array holding every maximal run of non-delimiter characters, in order.
function string:split(sep)
   local fields = {}
   local pattern = string.format("([^%s]+)", sep or ":")
   for field in self:gmatch(pattern) do
      fields[#fields + 1] = field
   end
   return fields
end

-- http://stackoverflow.com/questions/2705793/how-to-get-number-of-entries-in-a-lua-table
--- Counts every key of T (array part and hash part), unlike the # operator.
-- @param T table to count
-- @return number of key/value pairs visited by pairs(T)
function table.length(T)
   local count = 0
   for _ in pairs(T) do
      count = count + 1
   end
   return count
end

-- http://nocurve.com/simple-csv-read-and-write-using-lua/
--- Writes rows to a CSV file, JSON-encoding every cell.
-- @param path   destination file (opened in "w" mode, so it is overwritten)
-- @param header either a table of column names, which is written as the first
--   row and whose length fixes the number of columns, or a number giving the
--   number of columns (no header row is written in that case)
-- @param data   array of rows; each row is indexed from 1 to the column count
-- @param sep    column separator (default ',')
-- Cells holding Lua tables are wrapped in double quotes. The double quotes
-- inside their JSON text are doubled, so that a JSON object such as {"a":1}
-- stays a single, well-formed CSV field.
function hypero.writecsv(path, header, data, sep)
   sep = sep or ','
   local nCol = header

   -- validate before touching the file system so that a bad call neither
   -- truncates an existing file nor leaks an open handle
   if torch.type(header) == 'table' then
      nCol = #header
      data = _.clone(data) -- do not modify the caller's table
      table.insert(data, 1, header)
   else
      assert(torch.type(nCol) == 'number')
   end

   local file = assert(io.open(path, "w"))
   -- run the writes protected so the handle is closed even if encoding fails
   local ok, err = pcall(function()
      for i = 1, #data do
         local row = data[i]
         for j = 1, nCol do
            if j > 1 then
               file:write(sep)
            end
            local val = row[j]
            local jsonVal = json.encode.encode(val)
            if torch.type(val) == 'table' then
               -- parentheses drop gsub's second return value (the match count)
               jsonVal = '"' .. (jsonVal:gsub('"', '""')) .. '"'
            end
            file:write(jsonVal)
         end
         file:write('\n')
      end
   end)
   file:close()
   if not ok then
      error(err, 0) -- level 0 : re-raise the original message unchanged
   end
end
-- Command-line script : exports the experiments of one battery to a file.
-- It connects to the database, builds the battery/version and asks the
-- battery for its export table, which is then written as CSV.
require 'hypero'
require 'paths'

--[[command line arguments]]--

cmd = torch.CmdLine()
cmd:text()
cmd:text('Export database battery data')
cmd:text('Options:')
cmd:option('--schema', 'hyper', 'SQL schema in databse')
cmd:option('--batteryName', '', "name of battery of experiments to be exported")
cmd:option('--versionDesc', '', 'desc of version to be exported')
-- NOTE(review): the default '' is truthy in Lua; confirm that
-- Battery:exportTable does not test opt.minVer for plain truthiness.
cmd:option('--minVer', '', '--versionDesc specifies the minimum version to be exported')
cmd:option('--paramNames', '*', "comma separated list of hyper-param columns to retrieve")
cmd:option('--metaNames', '*', "comma separated list of meta-data columns to retrieve")
cmd:option('--resultNames', '*', "comma separated list of result columns to retrieve")
cmd:option('--format', 'csv', "export format : csv")
cmd:option('--savePath', '', 'for csv format, defaults to [schema].csv')
cmd:option('--orderBy', 'hexId', "order by this result column")
cmd:option('--desc', false, 'order is descending')
cmd:text()
opt = cmd:parse(arg or {})
opt.asc = not opt.desc
assert(opt.batteryName ~= '')

-- use the schema given on the command line; the bare global `schema` was
-- never defined, so --schema used to be silently ignored
conn = hypero.connect{schema=opt.schema}
bat = conn:battery(opt.batteryName, opt.versionDesc, true, true)

local data, header = bat:exportTable(opt)

if opt.format == 'csv' or opt.format == 'CSV' then
   -- an empty --savePath falls back to "<schema>.csv"
   if opt.savePath == '' then
      opt.savePath = opt.schema..'.csv'
   end
   hypero.writecsv(opt.savePath, header, data)
else
   error("Unrecognized export format : "..opt.format)
end
------------------------------------------------------------------------
--[[ Sampler ]]--
-- hyper parameter sampling distributions
------------------------------------------------------------------------
local Sampler = torch.class("hypero.Sampler")

-- Sample from a categorical distribution.
-- probs : table of (unnormalized) probabilities, handed to torch.multinomial
-- vals  : optional table of values, parallel to probs
-- Returns vals[idx] for the sampled index idx, or idx itself when vals is
-- omitted or holds no entry at idx. A value of `false` is returned as-is
-- (the previous `vals and vals[idx] or idx` form replaced it by the index).
function Sampler:categorical(probs, vals)
   assert(torch.type(probs) == 'table', "Expecting table of probabilites, got :"..tostring(probs))

   local idx = torch.multinomial(torch.Tensor(probs), 1)[1]
   if not vals then
      return idx
   end

   local val = vals[idx]
   if val == nil then
      -- no value listed for this index : fall back to the index
      return idx
   end
   return val
end

-- Sample from a normal distribution with the given mean and
-- standard deviation (both numbers).
function Sampler:normal(mean, std)
   assert(torch.type(mean) == 'number')
   assert(torch.type(std) == 'number')

   return torch.normal(mean, std)
end

-- Sample from a uniform distribution bounded by two numbers.
function Sampler:uniform(minval, maxval)
   assert(torch.type(minval) == 'number')
   assert(torch.type(maxval) == 'number')

   return torch.uniform(minval, maxval)
end

-- Returns a value drawn according to exp(uniform(low, high))
-- so that the logarithm of the return value is uniformly distributed.
-- When optimizing, this variable is constrained to the interval [exp(low), exp(high)].
function Sampler:logUniform(minval, maxval)
   assert(torch.type(minval) == 'number')
   assert(torch.type(maxval) == 'number')

   return torch.exp(torch.uniform(minval, maxval))
end

-- Sample from a uniform integer distribution.
-- Draws with math.random(minval, maxval), i.e. from Lua's generator
-- rather than torch's.
function Sampler:randint(minval, maxval)
   assert(torch.type(minval) == 'number')
   assert(torch.type(maxval) == 'number')

   return math.random(minval, maxval)
end
TIMESTAMP DEFAULT now(), 32 | PRIMARY KEY (bat_id), 33 | UNIQUE (bat_name) 34 | ); 35 | 36 | CREATE TABLE IF NOT EXISTS $schema$.version ( 37 | ver_id BIGSERIAL, 38 | bat_id INT8, 39 | ver_desc VARCHAR(255), 40 | ver_time TIMESTAMP DEFAULT now(), 41 | PRIMARY KEY (ver_id), 42 | FOREIGN KEY (bat_id) REFERENCES $schema$.battery (bat_id), 43 | UNIQUE (bat_id, ver_desc) 44 | ); 45 | 46 | CREATE TABLE IF NOT EXISTS $schema$.experiment ( 47 | hex_id BIGSERIAL, 48 | bat_id INT8, 49 | ver_id INT8, 50 | hex_time TIMESTAMP DEFAULT now(), 51 | FOREIGN KEY (bat_id) REFERENCES $schema$.battery(bat_id), 52 | FOREIGN KEY (ver_id) REFERENCES $schema$.version(ver_id), 53 | PRIMARY KEY (hex_id) 54 | ); 55 | 56 | CREATE TABLE IF NOT EXISTS $schema$.param ( 57 | hex_id INT8, 58 | hex_param JSON, 59 | PRIMARY KEY (hex_id), 60 | FOREIGN KEY (hex_id) REFERENCES $schema$.experiment (hex_id) 61 | ); 62 | 63 | CREATE TABLE IF NOT EXISTS $schema$.meta ( 64 | hex_id INT8, 65 | hex_meta JSON, 66 | PRIMARY KEY (hex_id), 67 | FOREIGN KEY (hex_id) REFERENCES $schema$.experiment (hex_id) 68 | ); 69 | 70 | CREATE TABLE IF NOT EXISTS $schema$.result ( 71 | hex_id INT8, 72 | hex_result JSON, 73 | PRIMARY KEY (hex_id), 74 | FOREIGN KEY (hex_id) REFERENCES $schema$.experiment (hex_id) 75 | ); 76 | ]],"%$schema%$", self.schema)) 77 | end 78 | 79 | --[[ Decorator methods ]]-- 80 | 81 | function Connect:executeMany(...) 82 | return self.dbconn:executeMany(...) 83 | end 84 | 85 | function Connect:execute(...) 86 | return self.dbconn:execute(...) 87 | end 88 | 89 | function Connect:fetch(...) 90 | return self.dbconn:fetch(...) 91 | end 92 | 93 | function Connect:fetchOne(...) 94 | return self.dbconn:fetchOne(...) 95 | end 96 | 97 | function Connect:close(...) 98 | return self.dbconn:close(...) 
------------------------------------------------------------------------
--[[ Postgres ]]--
-- Simplified PostgreSQL database connection handler.
-- Uses the ~/.pgpass file for authentication
-- see (http://wiki.postgresql.org/wiki/Pgpass)
------------------------------------------------------------------------
local Postgres = torch.class("hypero.Postgres")

-- Opens a LuaSQL connection for the connection string and applies the
-- autocommit flag. Shared by the constructor and by deserialization so
-- that both yield an identically configured connection.
local function openConnection(connStr, autocommit)
   local driver = require('luasql.postgres'):postgres()
   local conn = assert(driver:connect(connStr))
   if autocommit ~= nil then
      conn:setautocommit(autocommit)
   end
   return conn
end

-- config may hold database, user and host; when none of them is given the
-- whole connection string is read from the environment variable named by
-- config.env (default HYPER_PG_CONN). Raises if the variable is not set or
-- if the connection cannot be opened.
function Postgres:__init(config)
   config = config or {}
   local args, database, user, host, env, autocommit
      = xlua.unpack(
      {config},
      'Postgres', 'Default is to get the connection string from an ' ..
      'environment variable. For security reasons, and to allow ' ..
      'for its serialization, no password is accepted. The password ' ..
      'should be set in the ~/.pgpass file.',
      {arg='database', type='string', help='name of postgres database'},
      {arg='user', type='string', help='username used to connect to db'},
      {arg='host', type='string', help='hostname/IP address of database'},
      {arg='env', type='string', default='HYPER_PG_CONN'},
      {arg='autocommit', type='boolean', default=true}
   )
   if not (database or user or host) then
      self.connStr = os.getenv(env)
      -- name the variable that was actually looked up
      assert(self.connStr, "Environment variable "..env.." not set")
   else
      self.connStr = ""
      if database then self.connStr = self.connStr.."dbname="..database.." " end
      if user then self.connStr = self.connStr.."user="..user.." " end
      if host then self.connStr = self.connStr.."host="..host.." " end
   end
   self.conn = openConnection(self.connStr, autocommit)
   self.autocommit = autocommit
end

-- Executes command once per entry of param_list.
-- Returns two arrays : the results and the error messages.
-- NOTE(review): both arrays are filled with table.insert, so a failed
-- command (nil result) appears to add nothing to results and a successful
-- one nothing to errs; their indices then would not line up with
-- param_list. Confirm before relying on positions.
function Postgres:executeMany(command, param_list)
   local results, errs = {}, {}
   for _, params in pairs(param_list) do
      local result, err = self:execute(command, params)
      table.insert(results, result)
      table.insert(errs, err)
   end
   return results, errs
end

-- Executes one command. When params is given (a single value or a table
-- of values) the command is first expanded with string.format(command, ...).
-- The values are NOT escaped : callers must quote/escape them.
-- Returns whatever the LuaSQL connection returns : result, err.
function Postgres:execute(command, params)
   if params then
      if torch.type(params) ~= 'table' then
         params = {params}
      end
      command = string.format(command, unpack(params))
   end
   return self.conn:execute(command)
end

--mode : 'n' returns rows as array, 'a' returns them as key-value
-- Returns rows, coltypes, colnames on success and false, err on failure.
function Postgres:fetch(command, params, mode)
   mode = mode or 'n'
   local cur, err = self:execute(command, params)
   if not cur then
      return false, err
   end

   local coltypes = cur:getcoltypes()
   local colnames = cur:getcolnames()
   local rows = {}
   local row = cur:fetch({}, mode)
   while row do
      table.insert(rows, row)
      row = cur:fetch({}, mode) -- fresh table per row
   end
   cur:close()
   return rows, coltypes, colnames
end

-- Like fetch, but only reads the first row.
-- Returns row, coltypes, colnames on success (row is nil when the query
-- produced no row) and false, err on failure.
function Postgres:fetchOne(command, params, mode)
   mode = mode or 'n'
   local cur, err = self:execute(command, params)
   if not cur then
      return false, err
   end

   local coltypes = cur:getcoltypes()
   -- was stored in `colname` while the undefined `colnames` was returned
   local colnames = cur:getcolnames()
   local row = cur:fetch({}, mode)
   cur:close()
   return row, coltypes, colnames
end

function Postgres:close()
   self.conn:close()
end

-- These two methods allow for (de)serialization of Postgres objects:
-- only the connection string and the autocommit flag are stored; the
-- connection itself is re-opened when the object is read back.
function Postgres:write(file)
   file:writeObject(self.connStr)
   file:writeObject(self.autocommit)
end

function Postgres:read(file, version)
   self.connStr = file:readObject()
   self.autocommit = file:readObject()
   -- re-apply the stored autocommit flag, as the constructor does
   self.conn = openConnection(self.connStr, self.autocommit)
end
45 | #------------------------------------------------------------------------------ 46 | # CONNECTIONS AND AUTHENTICATION 47 | #------------------------------------------------------------------------------ 48 | 49 | # - Connection Settings - 50 | 51 | listen_addresses = '*' 52 | ... 53 | $ service postgresql restart 54 | $ exit 55 | ``` 56 | 57 | These changes basically allow any host supplying the correct credentials (username and password) to 58 | connect to the database which listens on port 5432 of all IP addresses of the server. 59 | If you want to make the system more secure (i.e. strict), 60 | you can consult the postgreSQL documentation for each of those files. 61 | 62 | To test out the changes, you can ssh to a different host and try to login from there. 63 | Supposing we setup our postgresql server on host `192.168.1.3` and that we ssh to `192.168.1.2` : 64 | 65 | ```bash 66 | $ ssh username@192.168.1.2 67 | $ sudo apt-get install postgresql-client 68 | $ psql -U hypero -W -h 192.168.1.3 hypero 69 | Password for user hypero: 70 | hypero=> \q 71 | ``` 72 | 73 | ## Client(s) 74 | 75 | At this point, every time we login, we need to supply a password. 76 | However, postgresql provides a simple facility for storing passwords on disk. 77 | We need only store a connection string in a `.pgpass` file located at the home directory: 78 | 79 | ```bash 80 | $ vim ~/.pgpass 81 | 192.168.1.3:5432:*:hypero:mysecretpassword 82 | $ chmod og-rwx ~/.pgpass 83 | ``` 84 | 85 | The `chmod` command is to keep other users from viewing your connection string. 86 | So now we can login to the database without requiring any password : 87 | 88 | ```bash 89 | $ psql -U hypero -h 192.168.1.3 hypero 90 | hypero=> \q 91 | ``` 92 | 93 | You should create and secure such a `.pgpass` file for each client host 94 | that will need to connect to the hypero database server. 95 | It will make your code that much more secure. 
------------------------------------------------------------------------
--[[ Experiment ]]--
-- One entry of the hyper-optimization log. Its hyper-parameters,
-- meta-data and results are Lua tables stored as JSON in the
-- param, meta and result tables of the connection's schema.
------------------------------------------------------------------------
local Xp = torch.class("hypero.Experiment")

-- Makes str safe to place between single quotes in an SQL statement by
-- doubling every single quote (parentheses drop gsub's match count).
local function sqlQuote(str)
   return (string.gsub(str, "'", "''"))
end

-- Stores luaTable as JSON in <schema>.<tbl>.<col> for experiment xp.
-- An INSERT is tried first; when it fails and update is true, an UPDATE
-- is tried instead. Raises an error prefixed by method when the statement
-- that was expected to succeed fails.
local function store(xp, tbl, col, luaTable, update, method)
   assert(torch.type(luaTable) == 'table')
   local jsonVal = sqlQuote(json.encode.encode(luaTable))
   local schema = xp.conn.schema

   local cur, err = xp.conn:execute(
      "INSERT INTO %s."..tbl.." (hex_id, "..col..") VALUES (%s, '%s')",
      {schema, xp.id, jsonVal}
   )
   if cur then
      return
   end
   if not update then
      error(method.." INSERT err :\n"..tostring(err))
   end

   -- handle insert conflict
   cur, err = xp.conn:execute(
      "UPDATE %s."..tbl.." SET "..col.." = '%s' WHERE hex_id = %s",
      {schema, jsonVal, xp.id}
   )
   if not cur then
      error(method.." UPDATE err :\n"..tostring(err))
   end
end

-- Loads and decodes the JSON stored in <schema>.<tbl>.<col> for xp.
-- Returns the decoded table; nil when nothing is stored for this
-- experiment; nil, err when the query itself failed.
local function load(xp, tbl, col)
   local row, err = xp.conn:fetchOne(
      "SELECT "..col.." FROM %s."..tbl.." WHERE hex_id = %s",
      {xp.conn.schema, xp.id}
   )
   if row then
      return json.decode.decode(row[1])
   elseif row == false then
      -- fetchOne signals a failed query with false, err
      return nil, err
   end
   return nil
end

-- conn  : a hypero.Connect
-- hexId : either a hypero.Battery, in which case a new experiment row is
--         inserted for that battery/version, or the (numeric) id of an
--         existing experiment, which must be present in the database.
function Xp:__init(conn, hexId)
   self.conn = conn
   assert(torch.isTypeOf(conn, "hypero.Connect"))

   if torch.isTypeOf(hexId, "hypero.Battery") then
      local bat = hexId
      local batId, verId = bat.id, bat:version()
      assert(torch.type(batId) == 'number' or torch.type(batId) == 'string')
      assert(torch.type(verId) == 'number' or torch.type(verId) == 'string')
      -- both ids are spliced into SQL below, so they must be numeric
      -- (the former assert(pcall(...)) could never fail)
      assert(tonumber(batId) and tonumber(verId), "Expecting numeric battery and version ids")
      -- get a new experiment id
      local row, err = self.conn:fetchOne([[
      INSERT INTO %s.experiment (bat_id, ver_id)
      VALUES (%s, %s) RETURNING hex_id
      ]], {self.conn.schema, batId, verId})
      if not row then
         error("Experiment error :\n"..tostring(err))
      end
      self.id = tonumber(row[1])
   else
      assert(torch.type(hexId) == 'number' or torch.type(hexId) == 'string')
      self.id = tonumber(hexId)
      -- refuse non-numeric strings instead of splicing them into the query
      assert(self.id, "Expecting numeric experiment id, got : "..tostring(hexId))
      local row = self.conn:fetchOne([[
      SELECT * FROM %s.experiment WHERE hex_id = %s
      ]], {self.conn.schema, self.id})
      assert(row, "Non existent experiment id : "..hexId)
   end
end

-- hyper-param get/set
-- hp : table of hyper-parameters; update : overwrite an existing entry
function Xp:setParam(hp, update)
   store(self, 'param', 'hex_param', hp, update, "Experiment:setParam")
end

function Xp:getParam()
   return load(self, 'param', 'hex_param')
end

-- meta-data get/set
-- Unlike hyper-params, metadata should not influence the results of the experiment.
function Xp:setMeta(md, update)
   store(self, 'meta', 'hex_meta', md, update, "Experiment:setMeta")
end

function Xp:getMeta()
   return load(self, 'meta', 'hex_meta')
end

-- result get/set
function Xp:setResult(res, update)
   store(self, 'result', 'hex_result', res, update, "Experiment:setResult")
end

function Xp:getResult()
   return load(self, 'result', 'hex_result')
end
50 | * `result` : experimental results like the learning curves or the accuracy (train, valid, test), etc. 51 | * `meta` : meta-data like the hostname from which the experiment was run, the path to the saved model, etc. 52 | 53 | These database tables can be filled with Lua tables. 54 | For example, given the following Lua tables : 55 | 56 | ```lua 57 | hp = {startLR = 0.01, momentum = 0.09, lrDecay = 'linear', minLR = 0.0001, satEpoch = 300} 58 | md = {hostname = 'hermes', dataset = 'mnist'} 59 | res = {trainAcc = 0.998, validAcc = 0.876, testAcc = 0.862} 60 | ``` 61 | 62 | Using the `hex` experiment, we can update the respective database tables as follows: 63 | 64 | ```lua 65 | hex:setParam(hp) 66 | hex:setMeta(md) 67 | hex:setResult(res) 68 | ``` 69 | 70 | Internally, the Lua tables are serialized to a 71 | [JSON](https://en.wikipedia.org/wiki/JSON) string 72 | and stored in a single column of the database. 73 | This keeps the database schema pretty simple. 74 | The only constraint is that the Lua tables be convertible 75 | to JSON, so only primitive types like `nil`, `string`, 76 | `table` and `number` can be nested within the table. 77 | 78 | ## Sampler 79 | 80 | In the above example, we didn't really sample anything. 81 | That is because the database (i.e. centralized persistent storage) aspect of the 82 | library was separated from the hyper-parameter sampling. 83 | For sampling, we can basically use whatever we want, 84 | but hypero provides a `Sampler` object with different sampling distribution methods. 85 | It doesn't use anything fancy like a Gaussian Process or anything like that. 86 | But if you do a good job of bounding and choosing your distributions, 87 | you still end up with a really effective *random search*. 
88 | 89 | Example : 90 | 91 | ```lua 92 | hs = hypero.Sampler() 93 | hp = {} 94 | hp.preprocess = hs:categorical({0.8,0.1,0.1}, {'', 'lcn', 'std'}) 95 | hp.startLR = hs:logUniform(math.log(0.1), math.log(0.00001)) 96 | hp.minLR = math.min(hs:logUniform(math.log(0.1), math.log(0.0001))*hp.startLR, 0.000001) 97 | hp.satEpoch = hs:normal(300, 200) 98 | hp.hiddenDepth = hs:randint(1, 7) 99 | ``` 100 | 101 | Why did we create a `Sampler` class for this? 102 | Well we never know, maybe someday, we will have a Sampler 103 | subclass that will use a Gaussian Process 104 | or something to optimize the sampling of hyper-parameters. 105 | 106 | Again, if we want to store the hyper-parameters in the database, it's as easy as : 107 | 108 | ```lua 109 | hex:setParam(hp) 110 | ``` 111 | 112 | ## Training Script 113 | 114 | If we have a bunch of GPUs or CPUs lying around, we can create 115 | a training script that loops over different experiments. 116 | Each experiment can be logged into the database using `hypero`. 117 | For a complete example of how this is done, please consult this 118 | [example training script](../examples/neuralnetwork.lua). 119 | The main part of the script that concerns hypero is this : 120 | 121 | ```lua 122 | ...
123 | -- loop over experiments 124 | for i=1,hopt.maxHex do 125 | collectgarbage() 126 | local hex = bat:experiment() 127 | local opt = _.clone(hopt) 128 | 129 | -- hyper-parameters 130 | local hp = {} 131 | hp.preprocess = ntbl(opt.preprocess) or hs:categorical(opt.preprocess, {'', 'lcn', 'std'}) 132 | hp.startLR = ntbl(opt.startLR) or hs:logUniform(math.log(opt.startLR[1]), math.log(opt.startLR[2])) 133 | hp.minLR = (ntbl(opt.minLR) or hs:logUniform(math.log(opt.minLR[1]), math.log(opt.minLR[2])))*hp.startLR 134 | hp.satEpoch = ntbl(opt.satEpoch) or hs:normal(unpack(opt.satEpoch)) 135 | hp.momentum = ntbl(opt.momentum) or hs:categorical(opt.momentum, {0,0.9,0.95}) 136 | hp.maxOutNorm = ntbl(opt.maxOutNorm) or hs:categorical(opt.maxOutNorm, {0,1,2,4}) 137 | hp.hiddenDepth = ntbl(opt.hiddenDepth) or hs:randint(unpack(opt.hiddenDepth)) 138 | hp.hiddenSize = ntbl(opt.hiddenSize) or math.round(hs:logUniform(math.log(opt.hiddenSize[1]), math.log(opt.hiddenSize[2]))) 139 | hp.batchSize = ntbl(opt.batchSize) or hs:categorical(opt.batchSize, {16,32,64}) 140 | hp.extra = ntbl(opt.extra) or hs:categorical(opt.extra, {'none','dropout','batchnorm'}) 141 | 142 | for k,v in pairs(hp) do opt[k] = v end 143 | 144 | if not opt.silent then 145 | table.print(opt) 146 | end 147 | 148 | -- build dp experiment 149 | local xp, ds, hlog = buildExperiment(opt) 150 | 151 | -- more hyper-parameters 152 | hp.seed = xp:randomSeed() 153 | hex:setParam(hp) 154 | 155 | -- meta-data 156 | local md = {} 157 | md.name = xp:name() 158 | md.hostname = os.hostname() 159 | md.dataset = torch.type(ds) 160 | 161 | if not opt.silent then 162 | table.print(md) 163 | end 164 | 165 | md.modelstr = tostring(xp:model()) 166 | hex:setMeta(md) 167 | 168 | -- run the experiment 169 | local success, err = pcall(function() xp:run(ds) end ) 170 | 171 | -- results 172 | if success then 173 | res = {} 174 | res.trainCurve = hlog:getResultByEpoch('optimizer:feedback:confusion:accuracy') 175 | res.validCurve = 
hlog:getResultByEpoch('validator:feedback:confusion:accuracy') 176 | res.testCurve = hlog:getResultByEpoch('tester:feedback:confusion:accuracy') 177 | res.trainAcc = hlog:getResultAtMinima('optimizer:feedback:confusion:accuracy') 178 | res.validAcc = hlog:getResultAtMinima('validator:feedback:confusion:accuracy') 179 | res.testAcc = hlog:getResultAtMinima('tester:feedback:confusion:accuracy') 180 | res.lrs = opt.lrs 181 | res.minimaEpoch = hlog.minimaEpoch 182 | hex:setResult(res) 183 | 184 | if not opt.silent then 185 | table.print(res) 186 | end 187 | else 188 | print(err) 189 | end 190 | end 191 | ``` 192 | 193 | So basically, for each experiment, sample hyper-parameters, 194 | build and run the experiment, and save the hyper-parameters, 195 | meta-data and results to the database. 196 | If we have multiple GPUs/CPUs, we can launch an instance 197 | of the script for each available GPU/CPU, sit back, relax and 198 | wait for the results to be logged into the database. 199 | That is assuming your script is bug-free. 200 | When a bug in the code is uncovered (as it inevitably will be), 201 | we can just fix it and update the version of the battery before re-running our scripts. 202 | 203 | ## Query 204 | 205 | Assuming our training script(s) has been running for a couple of experiments, 206 | we need a way to query the results from the database. 207 | We can use the [export script](../scripts/export.lua) to export our results 208 | to CSV format. Assuming, our battery is called `Neural Network - Mnist` and 209 | we only care about versions `Neural Network v1` and above, 210 | we can use the following command to retrieve our results: 211 | 212 | ```bash 213 | th scripts/export.lua --batteryName 'Neural Network - Mnist' --versionDesc 'Neural Network v1' 214 | ``` 215 | 216 | The resulting `hyper.csv` file will contain all the experiments (one per row), 217 | rows ordered by experiment id (`hexId` column). 
218 | The columns are organized by hyper-parameters, followed by results and finally the meta-data columns : 219 | 220 | ![](images/csv1.png) 221 | 222 | That is a lot of data. We can filter the data by specifying 223 | the columns we would like to include in the CSV using the 224 | `--[param,meta,result]Names` cmd-line arguments. 225 | While we are at it, we might want to order rows in 226 | descending order of the `validAcc` column : 227 | 228 | ```bash 229 | th scripts/export.lua --batteryName 'Neural Network - Mnist' --versionDesc 'Neural Network v1' --metaNames 'hostname' --resultNames 'trainAcc,validAcc,testAcc' --orderBy 'validAcc' --desc 230 | ``` 231 | 232 | The resulting `hyper.csv` file looks much better don't you think? 233 | 234 | ![](images/csv2.png) 235 | -------------------------------------------------------------------------------- /examples/neuralnetwork.lua: -------------------------------------------------------------------------------- 1 | require 'dp' 2 | require 'hypero' 3 | 4 | --[[command line arguments]]-- 5 | 6 | cmd = torch.CmdLine() 7 | cmd:text() 8 | cmd:text('MNIST dataset Image Classification using MLP Training') 9 | cmd:text('Example:') 10 | cmd:text('$> th neuralnetwork.lua --batchSize 128 --momentum 0.5') 11 | cmd:text('Options:') 12 | cmd:option('--batteryName', 'hypero neural network example', "name of battery of experiments to be run") 13 | cmd:option('--maxHex', 100, 'maximum number of hyper-experiments to train (from this script)') 14 | cmd:option('--preprocess', "{16,2,1}", "preprocessor (or distribution thereof)") 15 | cmd:option('--startLR', '{0.001,1}', 'learning rate at t=0 (log-uniform {log(min), log(max)})') 16 | cmd:option('--minLR', '{0.001,1}', 'minimum LR = minLR*startLR (log-uniform {log(min), log(max)})') 17 | cmd:option('--satEpoch', '{300, 150}', 'epoch at which linear decayed LR will reach minLR*startLR (normal {mean, std})') 18 | cmd:option('--maxOutNorm', '{1,3,4,2}', 'max norm each layers output neuron 
weights (categorical)') 19 | cmd:option('--momentum', '{4,4,2}', 'momentum (categorical)') 20 | cmd:option('--hiddenDepth', '{0,7}', 'number of hidden layers (randint {min, max})') 21 | cmd:option('--hiddenSize', '{128,1024}', 'number of hidden units per layer (log-uniform {log(min), log(max)})') 22 | cmd:option('--batchSize', '{1,4,1}', 'number of examples per batch (categorical)') 23 | cmd:option('--extra', '{1,1,1}', 'apply nothing, dropout or batchNorm (categorical)') 24 | cmd:option('--cuda', false, 'use CUDA') 25 | cmd:option('--useDevice', 1, 'sets the device (GPU) to use') 26 | cmd:option('--maxEpoch', 500, 'maximum number of epochs to run') 27 | cmd:option('--maxTries', 50, 'maximum number of epochs to try to find a better local minima for early-stopping') 28 | cmd:option('--progress', false, 'display progress bar') 29 | cmd:option('--silent', false, 'dont print anything to stdout') 30 | cmd:text() 31 | hopt = cmd:parse(arg or {}) 32 | hopt.preprocess = dp.returnString(hopt.preprocess) 33 | hopt.startLR = dp.returnString(hopt.startLR) 34 | hopt.minLR = dp.returnString(hopt.minLR) 35 | hopt.satEpoch = dp.returnString(hopt.satEpoch) 36 | hopt.maxOutNorm = dp.returnString(hopt.maxOutNorm) 37 | hopt.momentum = dp.returnString(hopt.momentum) 38 | hopt.hiddenDepth = dp.returnString(hopt.hiddenDepth) 39 | hopt.hiddenSize = dp.returnString(hopt.hiddenSize) 40 | hopt.batchSize = dp.returnString(hopt.batchSize) 41 | hopt.extra = dp.returnString(hopt.extra) 42 | 43 | hopt.versionDesc = "Neural Network v1" 44 | 45 | --[[ dp ]]-- 46 |
-- Builds the dp experiment described by the (already sampled) options.
-- opt : table of options; reads preprocess, hiddenDepth, hiddenSize, extra,
--       startLR, minLR, satEpoch, momentum, maxOutNorm, batchSize, accUpdate,
--       progress, maxTries, maxEpoch, cuda, useDevice and silent.
--       Writes opt.learningRate, opt.decayFactor and opt.lrs (the learning
--       rate logged at each epoch).
-- Returns the dp.Experiment, the dataset and the dp.HyperLog observer.
function buildExperiment(opt)
   --[[preprocessing]]--

   local input_preprocess = {}
   if opt.preprocess == 'std' then
      table.insert(input_preprocess, dp.Standardize())
   elseif opt.preprocess == 'lcn' then
      table.insert(input_preprocess, dp.GCN())
      table.insert(input_preprocess, dp.LeCunLCN{progress=true})
   elseif opt.preprocess ~= '' then
      error("unknown preprocess : "..opt.preprocess)
   end

   --[[data]]--

   -- the preprocessed dataset is cached on disk, one file per preprocess
   local ds = torch.checkpoint(
      paths.concat(dp.DATA_DIR,"checkpoint","mnist_"..opt.preprocess..".t7"),
      function()
         return dp.Mnist{input_preprocess = input_preprocess}
      end)


   --[[Model]]--

   local model = nn.Sequential()
   model:add(nn.Convert(ds:ioShapes(), 'bf')) -- to batchSize x nFeature (also type converts)

   -- hidden layers
   -- (local : the original leaked inputSize into the global namespace)
   local inputSize = ds:featureSize()
   -- The sampler emits 'batchnorm' while the option text spells it
   -- 'batchNorm'; compare case-insensitively so either spelling works.
   local extra = string.lower(tostring(opt.extra))

   for i=1,opt.hiddenDepth do
      model:add(nn.Linear(inputSize, opt.hiddenSize)) -- parameters
      if extra == 'batchnorm' then
         model:add(nn.BatchNormalization(opt.hiddenSize))
      end
      model:add(nn.Tanh())
      if extra == 'dropout' then
         model:add(nn.Dropout())
      end
      inputSize = opt.hiddenSize
   end

   -- output layer
   model:add(nn.Linear(inputSize, #(ds:classes())))
   model:add(nn.LogSoftMax())


   --[[Propagators]]--

   -- linear decay
   opt.learningRate = opt.startLR
   -- satEpoch is sampled from a normal distribution and may be <= 0, which
   -- would make the decay factor positive (or infinite). Decay over at
   -- least one epoch instead.
   opt.decayFactor = (opt.minLR - opt.learningRate)/math.max(opt.satEpoch, 1)
   opt.lrs = {}

   local train = dp.Optimizer{
      acc_update = opt.accUpdate,
      loss = nn.ModuleCriterion(nn.ClassNLLCriterion(), nil, nn.Convert()),
      epoch_callback = function(model, report) -- called every epoch
         -- learning rate decay
         if report.epoch > 0 then
            opt.lrs[report.epoch] = opt.learningRate
            opt.learningRate = opt.learningRate + opt.decayFactor
            opt.learningRate = math.max(opt.minLR, opt.learningRate)
            if not opt.silent then
               print("learningRate", opt.learningRate)
            end
         end
      end,
      callback = function(model, report) -- called for every batch
         if opt.accUpdate then
            model:accUpdateGradParameters(model.dpnn_input, model.output, opt.learningRate)
         else
            model:updateGradParameters(opt.momentum) -- affects gradParams
            model:updateParameters(opt.learningRate) -- affects params
         end
         model:maxParamNorm(opt.maxOutNorm) -- affects params
         model:zeroGradParameters() -- affects gradParams
      end,
      feedback = dp.Confusion(),
      sampler = dp.ShuffleSampler{batch_size = opt.batchSize},
      progress = opt.progress
   }
   local valid = dp.Evaluator{
      feedback = dp.Confusion(),
      sampler = dp.Sampler{batch_size = opt.batchSize}
   }
   local test = dp.Evaluator{
      feedback = dp.Confusion(),
      sampler = dp.Sampler{batch_size = opt.batchSize}
   }

   --[[Experiment]]--
   -- this will be used by hypero
   local hlog = dp.HyperLog()

   local xp = dp.Experiment{
      model = model,
      optimizer = train,
      validator = valid,
      tester = test,
      observer = {
         hlog,
         dp.EarlyStopper{
            error_report = {'validator','feedback','confusion','accuracy'},
            maximize = true,
            max_epochs = opt.maxTries
         }
      },
      random_seed = os.time(),
      max_epoch = opt.maxEpoch
   }

   --[[GPU or CPU]]--

   if opt.cuda then
      require 'cutorch'
      require 'cunn'
      cutorch.setDevice(opt.useDevice)
      xp:cuda()
   end

   xp:verbose(not opt.silent)
   if not opt.silent then
      print"Model :"
      print(model)
   end

   return xp, ds, hlog
end
177 | 178 | --[[hypero]]-- 179 | 180 | conn = hypero.connect() 181 | bat = conn:battery(hopt.batteryName, hopt.versionDesc) 182 | hs = hypero.Sampler() 183 | 184 | -- this allows the hyper-param sampler to be bypassed via cmd-line 185 | function ntbl(param) 186 | return torch.type(param) ~= 'table' and param 187 | end 188 | 189 | 190 | -- loop over experiments 191 | for i=1,hopt.maxHex do 192 | collectgarbage() 193 | local hex = bat:experiment() 194 | local opt = _.clone(hopt) 195 | 196 | -- hyper-parameters 197 | local hp = {} 198 | hp.preprocess = ntbl(opt.preprocess) or hs:categorical(opt.preprocess, {'', 'lcn', 'std'}) 199 | hp.startLR =
ntbl(opt.startLR) or hs:logUniform(math.log(opt.startLR[1]), math.log(opt.startLR[2])) 200 | hp.minLR = (ntbl(opt.minLR) or hs:logUniform(math.log(opt.minLR[1]), math.log(opt.minLR[2])))*hp.startLR 201 | hp.satEpoch = ntbl(opt.satEpoch) or hs:normal(unpack(opt.satEpoch)) 202 | hp.momentum = ntbl(opt.momentum) or hs:categorical(opt.momentum, {0,0.9,0.95}) 203 | hp.maxOutNorm = ntbl(opt.maxOutNorm) or hs:categorical(opt.maxOutNorm, {0,1,2,4}) 204 | hp.hiddenDepth = ntbl(opt.hiddenDepth) or hs:randint(unpack(opt.hiddenDepth)) 205 | hp.hiddenSize = ntbl(opt.hiddenSize) or math.round(hs:logUniform(math.log(opt.hiddenSize[1]), math.log(opt.hiddenSize[2]))) 206 | hp.batchSize = ntbl(opt.batchSize) or hs:categorical(opt.batchSize, {16,32,64}) 207 | hp.extra = ntbl(opt.extra) or hs:categorical(opt.extra, {'none','dropout','batchnorm'}) 208 | 209 | for k,v in pairs(hp) do opt[k] = v end 210 | 211 | if not opt.silent then 212 | table.print(opt) 213 | end 214 | 215 | -- build dp experiment 216 | local xp, ds, hlog = buildExperiment(opt) 217 | 218 | -- more hyper-parameters 219 | hp.seed = xp:randomSeed() 220 | hex:setParam(hp) 221 | 222 | -- meta-data 223 | local md = {} 224 | md.name = xp:name() 225 | md.hostname = os.hostname() 226 | md.dataset = torch.type(ds) 227 | 228 | if not opt.silent then 229 | table.print(md) 230 | end 231 | 232 | md.modelstr = tostring(xp:model()) 233 | hex:setMeta(md) 234 | 235 | -- run the experiment 236 | local success, err = pcall(function() xp:run(ds) end ) 237 | 238 | -- results 239 | if success then 240 | res = {} 241 | res.trainCurve = hlog:getResultByEpoch('optimizer:feedback:confusion:accuracy') 242 | res.validCurve = hlog:getResultByEpoch('validator:feedback:confusion:accuracy') 243 | res.testCurve = hlog:getResultByEpoch('tester:feedback:confusion:accuracy') 244 | res.trainAcc = hlog:getResultAtMinima('optimizer:feedback:confusion:accuracy') 245 | res.validAcc = hlog:getResultAtMinima('validator:feedback:confusion:accuracy') 246 | 
res.testAcc = hlog:getResultAtMinima('tester:feedback:confusion:accuracy') 247 | res.lrs = opt.lrs 248 | res.minimaEpoch = hlog.minimaEpoch 249 | hex:setResult(res) 250 | 251 | if not opt.silent then 252 | table.print(res) 253 | end 254 | else 255 | print(err) 256 | end 257 | end 258 | -------------------------------------------------------------------------------- /test.lua: --------------------------------------------------------------------------------
-- make sure you setup a .pgpass file and HYPER_PG_CONN env variable
-- (see README.md#install for instructions on how to setup postgreSQL)

local mytester
local testSchema = 'hypero_test'
local htest = {}

-- Exercises the raw Postgres wrapper : execute, fetchOne, fetch,
-- executeMany, and that a connection survives torch serialization.
function htest.Postgres()
   local dbconn = hypero.Postgres()
   local res, err = dbconn:execute([[
   DROP SCHEMA IF EXISTS %s CASCADE;
   CREATE SCHEMA %s;
   ]], {testSchema, testSchema})
   -- check the query result (the original asserted the constant testSchema,
   -- which is always truthy)
   mytester:assert(res, "Postgres schema err")

   res, err = dbconn:execute([[
   CREATE TABLE %s.test5464 ( n INT4, v FLOAT4, s TEXT );
   ]], {testSchema})
   mytester:assert(res, "Postgres create table err")
   res, err = dbconn:fetchOne("SELECT * FROM %s.test5464", testSchema)
   mytester:assert(_.isEmpty(res), "Postgres select empty table err")

   local param_list = {
      {5, 4.1, 'asdfasdf'},
      {6, 3.5, 'asdfashhd'},
      {6, 3.7, 'asdfashhd2'}
   }
   res, err = dbconn:executeMany(
      string.format([[
      INSERT INTO %s.test5464 VALUES (%s, %s, '%s');]],
      testSchema, '%s', '%s', '%s'), param_list)
   mytester:assertTableEq(res, {1,1,1}, "Postgres insert many err")

   res, err = dbconn:fetch([[
   SELECT * FROM %s.test5464 WHERE n = 6
   ]], {testSchema})
   mytester:assert(torch.type(res) == 'table')
   mytester:assert(#res == 2, "Postgres select serialize err")
   mytester:assert(#res[1] == 3, "Postgres missing columns err")

   -- test serialization/deserialization of postgres object
   local dbconn_str = torch.serialize(dbconn)
   dbconn = torch.deserialize(dbconn_str)

   -- the table already exists, so creating it again must fail
   res, err = dbconn:execute([[
   CREATE TABLE %s.test5464 ( n INT4, v FLOAT4, s TEXT );
   ]], {testSchema})
   mytester:assert(res == nil and err, "Postgres serialize table exist err")
   res, err = dbconn:fetchOne("SELECT * FROM %s.test5464", testSchema)
   mytester:assert(torch.type(res) == 'table', "Postgres serialize select table err")
   mytester:assert(#res == 3, "Postgres serialize select table err")

   res, err = dbconn:executeMany(
      string.format([[
      INSERT INTO %s.test5464 VALUES (%s, %s, '%s');]],
      testSchema, '%s', '%s', '%s'), param_list)
   mytester:assertTableEq(res, {1,1,1}, "Postgres serialize insert many err")

   res, err = dbconn:fetch([[
   SELECT * FROM %s.test5464 WHERE n = 6
   ]], {testSchema})
   mytester:assert(torch.type(res) == 'table')
   mytester:assert(#res == 4, "Postgres select serialize err")
   mytester:assert(#res[1] == 3, "Postgres missing columns err")

   dbconn:close()
end

-- hypero.connect must (re)create every table of the schema.
function htest.Connect()
   local dbconn = hypero.Postgres()
   local res, err = dbconn:execute("DROP SCHEMA IF EXISTS %s CASCADE", testSchema)
   mytester:assert(res, "DROP SCHEMA error")
   res, err = dbconn:execute("SELECT * FROM %s.battery", testSchema)
   mytester:assert(not res, "Connect DROP SCHEMA not dropped err")
   dbconn:close()
   local conn = hypero.connect{schema=testSchema}
   res, err = conn:execute("SELECT * FROM %s.battery", testSchema)
   mytester:assert(res, "Connect battery TABLE err")
   res, err = conn:execute("SELECT * FROM %s.version", testSchema)
   mytester:assert(res, "Connect version TABLE err")
   res, err = conn:execute("SELECT * FROM %s.experiment", testSchema)
   mytester:assert(res, "Connect experiment TABLE err")
   res, err = conn:execute("SELECT * FROM %s.param", testSchema)
   mytester:assert(res, "Connect param TABLE err")
   res, err = conn:execute("SELECT * FROM %s.meta", testSchema)
   mytester:assert(res, "Connect metadata TABLE err")
   res, err = conn:execute("SELECT * FROM %s.result", testSchema)
   mytester:assert(res, "Connect result TABLE err")
   conn:close()
end

-- Battery creation, versioning, and retrieval of experiments and of their
-- param / meta / result columns.
function htest.Battery()
   local verbose = false
   local dbconn = hypero.Postgres()
   local res, err = dbconn:execute("DROP SCHEMA IF EXISTS %s CASCADE", testSchema)
   mytester:assert(res, "DROP SCHEMA error")
   local conn = hypero.connect{schema=testSchema,dbconn=dbconn}
   local batName = "Test 22"
   local bat = conn:battery(batName, nil, verbose)
   mytester:assert(bat.id == 1, "Battery id err")
   mytester:assert(bat.verId == 1, "Battery verId err")
   mytester:assert(bat.verDesc == "Initial battery version")
   local verDesc = "Version 33"
   local verId2, verDesc2 = bat:version(verDesc)
   mytester:assert(verId2 == 2, "Battery version() id err")
   mytester:assert(verDesc2 == verDesc, "Battery version() desc err")
   local bat = conn:battery(batName, verDesc, verbose)
   mytester:assert(bat.id == 1, "Battery id err")
   mytester:assert(bat.verId == 2, "Battery verId err")
   res, err = conn:fetchOne("SELECT COUNT(*) FROM %s.battery", testSchema)
   mytester:assert(res, err)
   mytester:assert(res[1] == '1')
   res, err = conn:fetchOne("SELECT COUNT(*) FROM %s.version", testSchema)
   mytester:assert(res, err)
   mytester:assert(res[1] == '2')

   local verIds = bat:fetchVersions()
   mytester:assert(#verIds == 2)
   mytester:assertTableEq(verIds, {1, 2})
   bat:version("Version 34")
   local verIds = bat:fetchVersions("Version 33")
   mytester:assert(#verIds == 2)
   mytester:assertTableEq(verIds, {2, 3})

   local verId = bat:getVerId("Version 33")
   mytester:assert(verId == 2)

   for j, verDesc in ipairs{'Version 35', 'Version 36'} do
      bat:version(verDesc)

      -- create some dummy experiments
      for i=1,10 do
         local hex = bat:experiment()
         hex:setParam{lr=0.0001,mom=0.9, i=i, v=verDesc}
         hex:setMeta{hostname='bobby', screen=3463, i=i, v=verDesc, d=3}
         hex:setResult{valid_acc = 0.0001, i=i, v=verDesc}
      end

      local res, err = conn:fetch([[
      SELECT hex_id FROM %s.experiment
      WHERE (bat_id, ver_id) = (%s, %s)
      ]], {testSchema, bat.id, bat.verId})
      mytester:assert(res, err)
      mytester:assert(#res == 10)
   end

   local verIds = bat:fetchVersions("Version 33")
   local hexIds = bat:fetchExperiments(verIds)
   mytester:assert(#hexIds == 20)
   mytester:assert(torch.type(hexIds[1]) == 'number')

   local param, colNames = bat:getParam(hexIds, {'i', 'v'})
   mytester:assert(table.length(param) == 20)
   mytester:assert(#param[1] == 2 and #param[20] == 2)
   mytester:assertTableEq(colNames, {'i', 'v'})

   local param, colNames = bat:getParam(hexIds, '*')
   mytester:assert(table.length(param) == 20)
   mytester:assert(#param[1] == 4 and #param[20] == 4)
   mytester:assert(#colNames == 4)

   local param, colNames = bat:getMeta(hexIds, {'i', 'v'})
   mytester:assert(table.length(param) == 20)
   mytester:assert(#param[1] == 2 and #param[20] == 2)
   mytester:assertTableEq(colNames, {'i', 'v'})

   local param, colNames = bat:getMeta(hexIds, '*')
   mytester:assert(table.length(param) == 20)
   mytester:assert(#param[1] == 5 and #param[20] == 5)
   mytester:assert(#colNames == 5)

   local param, colNames = bat:getResult(hexIds, {'i', 'v'})
   mytester:assert(table.length(param) == 20)
   mytester:assert(#param[1] == 2 and #param[20] == 2)
   mytester:assertTableEq(colNames, {'i', 'v'})

   local param, colNames = bat:getResult(hexIds, '*')
   mytester:assert(table.length(param) == 20)
   mytester:assert(#param[1] == 3 and #param[20] == 3)
   mytester:assert(#colNames == 3)

   local tbl, colNames = bat:exportTable()
   mytester:assert(table.length(tbl) == 10)
   mytester:assert(table.length(tbl[1]) == 13)

   -- exported rows must be ordered by increasing experiment id
   local prevHexId = -1
   local sorted = true
   for i,row in ipairs(tbl) do
      sorted = sorted and row[1] > prevHexId
      prevHexId = row[1]
   end
   mytester:assert(sorted)

   conn:close()
end

-- Experiment creation and the set/get round-trip of hyper-params,
-- meta-data and results, including the refusal to overwrite by default.
function htest.Experiment()
   local dbconn = hypero.Postgres()
   local res, err = dbconn:execute("DROP SCHEMA IF EXISTS %s CASCADE", testSchema)
   mytester:assert(res, "DROP SCHEMA error")
   local conn = hypero.connect{schema=testSchema,dbconn=dbconn}
   local batName = "Test 23"
   -- no version description (the original passed an undefined global
   -- verDesc, i.e. nil)
   local bat = conn:battery(batName, nil, false)
   local hex = bat:experiment()
   mytester:assert(hex.id == 1)
   res, err = conn:fetchOne("SELECT COUNT(*) FROM %s.experiment", testSchema)
   mytester:assert(res, err)
   mytester:assert(res[1] == '1')
   local hex = hypero.Experiment(conn, 1)
   mytester:assert(hex.id == 1)
   res, err = conn:fetchOne("SELECT COUNT(*) FROM %s.experiment", testSchema)
   mytester:assert(res, err)
   mytester:assert(res[1] == '1')
   -- NOTE(review): conn is not passed here, so this presumably fails on the
   -- constructor's argument check rather than on the unknown id 2. Confirm
   -- against Experiment.lua whether hypero.Experiment(conn, 2) was intended.
   local success = pcall(function() return hypero.Experiment(2) end)
   mytester:assert(not success)

   -- hyperParam
   local hp = {lr=0.0001,mom=0.9}
   hex:setParam(hp)
   local hp2 = hex:getParam()
   mytester:assert(hp2.lr == 0.0001)
   mytester:assert(hp2.mom == 0.9)
   hp.lr = 0.01
   mytester:assert(not pcall(function() return hex:setParam(hp) end))
   hex:setParam(hp, true)
   local hp2 = hex:getParam()
   mytester:assert(hp2.lr == 0.01)

   -- metaData
   local md = {hostname='bobby', screen=3463}
   hex:setMeta(md)
   local md2 = hex:getMeta()
   mytester:assert(md2.hostname == 'bobby')
   mytester:assert(md2.screen == 3463)
   md.hostname = 'sonny'
   mytester:assert(not pcall(function() return hex:setMeta(md) end))
   hex:setMeta(md, true)
   local md2 = hex:getMeta()
   mytester:assert(md2.hostname == 'sonny')

   -- result
   local res = {valid_acc = 0.0001, test_acc = 0.02}
   hex:setResult(res)
   local res2 = hex:getResult()
   mytester:assert(res2.valid_acc == 0.0001)
   mytester:assert(res2.test_acc == 0.02)
   res.valid_acc = 0.01
   -- pass the modified table (the original passed the undefined name `es`,
   -- so the overwrite protection was never actually exercised)
   mytester:assert(not pcall(function() return hex:setResult(res) end))
   hex:setResult(res, true)
   local res2 = hex:getResult()
   mytester:assert(res2.valid_acc == 0.01)

   conn:close()
end

-- Sanity checks on the type and range of each sampling distribution.
function htest.Sampler()
   local hs = hypero.Sampler()
   local val = hs:categorical({0.001, 0.0001, 0.0001, 10000}, {1,2,3,4})
   mytester:assert(val == 4, "Sampler err")
   local val = hs:normal(0, 1)
   mytester:assert(torch.type(val) == 'number')
   local val = hs:uniform(0, 1)
   mytester:assert(val >= 0 and val <= 1)
   local val = hs:logUniform(0, 1)
   mytester:assert(val >= math.exp(0) and val <= math.exp(1))
   local val = hs:randint(1,100)
   mytester:assert(math.floor(val) == val)
   mytester:assert(val >= 1 and val <= 100)
end

-- Runs the hypero test-suite.
-- tests : passed through to torch.Tester:run to select which tests to run.
-- Returns the torch.Tester instance.
function hypero.test(tests)
   math.randomseed(os.time())
   mytester = torch.Tester()
   mytester:add(htest)
   mytester:run(tests)
   return mytester
end
-------------------------------------------------------------------------------- /Battery.lua: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------ 2 | --[[ Battery ]]-- 3 | -- A battery of experiments which can have multiple versions.
------------------------------------------------------------------------
local Battery = torch.class("hypero.Battery")

-- Looks up the battery called `name`, creating it when it doesn't exist.
-- conn    : a hypero.Connect instance
-- name    : non-empty string identifying the battery
-- verbose : print a message when a battery is created (default true)
-- strict  : when true, raise an error instead of creating a missing battery
-- Sets self.id to the battery's bat_id.
-- NOTE(review): name is interpolated between single quotes in the queries
-- below; a name containing an apostrophe presumably breaks them unless the
-- connection escapes its parameters (Connect.lua not visible here).
function Battery:__init(conn, name, verbose, strict)
   assert(torch.type(name) == 'string')
   assert(name ~= '')
   assert(torch.isTypeOf(conn, "hypero.Connect"))
   self.conn = conn
   self.name = name
   self.verbose = (verbose == nil) and true or verbose

   local selectQuery = [[
   SELECT bat_id FROM %s.battery WHERE bat_name = '%s';
   ]]

   -- check if the battery already exists
   local row, err = self.conn:fetchOne(selectQuery, {self.conn.schema, self.name})

   if not row or _.isEmpty(row) then
      if strict then
         error"Battery doesn't exist (create it with strict=false)"
      end
      if self.verbose then
         print("Creating new battery : "..name)
      end
      row, err = self.conn:fetchOne([[
      INSERT INTO %s.battery (bat_name) VALUES ('%s') RETURNING bat_id;
      ]], {self.conn.schema, self.name})

      if not row or _.isEmpty(row) then
         -- this can happen when multiple clients try to INSERT
         -- the same battery simultaneously : read back the winner's row.
         -- (assign the outer row : a shadowing local here left self.id
         -- computed from the failed INSERT)
         row, err = self.conn:fetchOne(selectQuery, {self.conn.schema, self.name})
         if not row or _.isEmpty(row) then
            error("Battery init error : \n"..tostring(err))
         end
      end
   end

   self.id = tonumber(row[1])
end

-- Version requires a description (like a commit message).
-- A battery can have multiple versions.
-- Each code change could have its own battery version.
-- desc   : version description; nil or '' selects the most recent version
--          of the battery (creating "Initial battery version" if none).
-- strict : when true, raise an error instead of creating a missing version.
-- Returns the version id and self.verDesc (the description is only set
-- when it was passed in or when a version was created here).
-- NOTE(review): descriptions are interpolated between single quotes in the
-- queries; confirm that the connection escapes apostrophes.
function Battery:version(desc, strict)
   if torch.type(desc) == 'string' and desc == '' then
      desc = nil
   end

   local row, err
   local create = false

   if desc then
      -- identify version using description desc :
      assert(torch.type(desc) == 'string', "expecting battery version description string")
      self.verDesc = desc

      -- check if the version already exists
      row, err = self.conn:fetchOne([[
      SELECT ver_id FROM %s.version
      WHERE (bat_id, ver_desc) = (%s, '%s');
      ]], {self.conn.schema, self.id, self.verDesc})

      create = not row or _.isEmpty(row)
      if create and strict then
         error"Battery version doesn't exist (create it with strict=false)"
      end
   elseif not self.verId then
      -- try to obtain the most recent version :
      row, err = self.conn:fetchOne([[
      SELECT MAX(ver_id) FROM %s.version WHERE bat_id = %s;
      ]], {self.conn.schema, self.id})

      create = not row or _.isEmpty(row)
      if create then
         if strict then
            error"Battery version not initialized (create it with strict=false)"
         end
         self.verDesc = self.verDesc or "Initial battery version"
      end
   else
      -- no description given and a version is already selected
      return self.verId, self.verDesc
   end

   if create then
      if self.verbose then
         print("Creating new battery version : "..self.verDesc)
      end
      row, err = self.conn:fetchOne([[
      INSERT INTO %s.version (bat_id, ver_desc)
      VALUES (%s, '%s') RETURNING ver_id;
      ]], {self.conn.schema, self.id, self.verDesc})

      if not row or _.isEmpty(row) then
         -- this can happen when multiple clients try to INSERT
         -- the same version simultaneously : read back the winner's row.
         -- Restrict to this battery, as other batteries may use the same
         -- description.
         row, err = self.conn:fetchOne([[
         SELECT ver_id FROM %s.version
         WHERE (bat_id, ver_desc) = (%s, '%s');
         ]], {self.conn.schema, self.id, self.verDesc})
         if not row or _.isEmpty(row) then
            error("Battery version error : \n"..tostring(err))
         end
      end
   end

   self.verId = tonumber(row[1])
   return self.verId, self.verDesc
end

-- factory method for experiments of this battery
-- (requires both the battery id and a selected version)
function Battery:experiment()
   assert(self.id and self.verId,
      "Battery:experiment requires a battery id and version id (call version() first)")
   return hypero.Experiment(self.conn, self)
end


-- fetch all version ids from db
-- or new versions if minVerDesc is specified
-- (i.e. versions of this battery whose id is >= that of minVerDesc).
-- Returns a list of version ids in ascending order.
-- Raises an error when the query fails or minVerDesc is unknown.
function Battery:fetchVersions(minVerDesc)
   local rows, err
   if minVerDesc then
      assert(torch.type(minVerDesc) == 'string', "expecting battery version description string")
      -- check if the version already exists
      local row
      row, err = self.conn:fetchOne([[
      SELECT ver_id FROM %s.version
      WHERE (bat_id, ver_desc) = (%s, '%s');
      ]], {self.conn.schema, self.id, minVerDesc})

      if not row or _.isEmpty(row) then
         local msg = "Could not find battery version : "..minVerDesc
         if self.verbose then
            print(msg)
            if err then print(err) end
         end
         -- keep a meaningful message for the error raised below
         err = err or msg
      else
         rows, err = self.conn:fetch([[
         SELECT ver_id FROM %s.version
         WHERE bat_id = %s AND ver_id >= %s ORDER BY ver_id ASC;
         ]], {self.conn.schema, self.id, row[1]})
      end
   else
      rows, err = self.conn:fetch([[
      SELECT ver_id FROM %s.version
      WHERE bat_id = %s ORDER BY ver_id ASC;
      ]], {self.conn.schema, self.id})
   end

   if not rows then
      error("Battery:fetchVersions error : \n"..tostring(err))
   end

   local verIds = {}
   for i,row in ipairs(rows) do
      verIds[i] = tonumber(row[1])
   end
   return verIds
end

-- fetch all experiment ids for version(s) from db
-- verId : a version id or a list of version ids (default self.verId)
-- Returns a list of experiment ids (empty when verId is an empty list).
function Battery:fetchExperiments(verId)
   verId = verId or self.verId
   assert(verId, "Battery:fetchExperiments requires a version id")
   verId = torch.type(verId) ~= 'table' and {verId} or verId

   if #verId == 0 then
      -- "IN ()" is not valid SQL; no versions means no experiments
      return {}
   end

   local rows, err = self.conn:fetch([[
   SELECT hex_id FROM %s.experiment
   WHERE bat_id = %s AND ver_id IN (%s);
   ]], {self.conn.schema, self.id, table.concat(verId, ', ')})

   if not rows then
      error("Battery:fetchExperiments error :\n"..tostring(err))
   end

   local hexIds = {}
   for i,row in ipairs(rows) do
      hexIds[i] = tonumber(row[1])
   end
   return hexIds
end

-- get version id of version having description verDesc
-- Raises an error when the query fails; when no row comes back the result
-- of tonumber on a missing column (nil) is returned.
function Battery:getVerId(verDesc)
   -- (the original asserted torch.type(verDesc == 'string'), always truthy)
   assert(torch.type(verDesc) == 'string', "expecting battery version description string")
   local row, err = self.conn:fetchOne([[
   SELECT ver_id FROM %s.version
   WHERE (bat_id, ver_desc) = (%s, '%s')
   ]], {self.conn.schema, self.id, verDesc})

   if not row then
      error("Battery:getVerId err :\n"..tostring(err))
   end

   return tonumber(row[1])
end

-- Converts rows of {hex_id, json string} into a table mapping each
-- experiment id to a list of values ordered like colNames.
-- colNames : list of names, a comma-separated string, or '*' / {} / nil
--            to collect every key encountered (in order of discovery).
-- Returns the table and the list of column names actually used.
local function db2tbl(rows, colNames)
   colNames = colNames or {}
   local all = torch.type(colNames) == 'string' and colNames == '*'
   colNames = torch.type(colNames) == 'table' and colNames or string.split(colNames, ',')
   all = all or _.isEmpty(colNames)
   if all then
      colNames = {}
   end
   local colDict = {}

   local tbl = {}
   for i=1,#rows do
      local hexId, jsonVals = unpack(rows[i])
      local vals = json.decode.decode(jsonVals)
      local row = {}
      if all then
         -- register keys not seen in previous rows as new columns
         for k,v in pairs(vals) do
            if not colDict[k] then
               colDict[k] = true
               table.insert(colNames, k)
            end
         end
      end
      for j,name in ipairs(colNames) do
         row[j] = vals[name]
      end
      tbl[tonumber(hexId)] = row
   end
   return tbl, colNames
end

-- Shared implementation of getParam, getMeta and getResult : reads the
-- hex_<kind> JSON column of table <kind> for experiments hexIds.
local function fetchColumns(self, kind, method, hexIds, names)
   hexIds = torch.type(hexIds) == 'table' and hexIds or {hexIds}
   if #hexIds == 0 then
      -- "IN ()" is not valid SQL; nothing to fetch
      return db2tbl({}, names)
   end

   local rows, err = self.conn:fetch(
      "SELECT hex_id, hex_"..kind.." FROM %s."..kind.." WHERE hex_id IN (%s)",
      {self.conn.schema, table.concat(hexIds,', ')})

   if not rows then
      error("Battery:"..method.." err : "..tostring(err))
   end

   return db2tbl(rows, names)
end

-- get hyper-params of experiments hexIds
-- The output is a table of tables (rows)
-- Each row is ordered by names (column names)
-- names : list, comma-separated string, or nil / '*' for all columns
function Battery:getParam(hexIds, names)
   return fetchColumns(self, 'param', 'getParam', hexIds, names)
end

-- get meta-data of experiments hexIds
-- (same conventions as getParam)
function Battery:getMeta(hexIds, names)
   return fetchColumns(self, 'meta', 'getMeta', hexIds, names)
end

-- get result of experiments hexIds
-- (same conventions as getParam)
function Battery:getResult(hexIds, names)
   return fetchColumns(self, 'result', 'getResult', hexIds, names)
end

-- export hyper-param, result and meta-data as a table of rows
-- where each row is an experiment (a list of values).
-- config is a table of key-value arguments (see the xlua.unpack call).
-- Returns the list of rows and the list of column names. Column 1 is
-- always 'hexId', followed by the hyper-param, result and meta-data
-- columns. A row can contain nil holes for values an experiment lacks.
function Battery:exportTable(config)
   config = config or {}
   assert(type(config) == 'table', "Constructor requires key-value arguments")
   local args, verDesc, minVer, paramNames, metaNames, resultNames,
      orderBy, asc = xlua.unpack(
      {config},
      'Battery:exportTable',
      'exports the battery of experiments as a lua table',
      {arg='verDesc', type='string', default=self.verDesc,
       help='description of version to be exported'},
      {arg='minVer', type='boolean', default=false,
       help='versionDesc specifies the minimum version to be exported'},
      {arg='paramNames', type='string | table', default='*',
       help='comma separated list of hyper-param columns to retrieve'},
      {arg='metaNames', type='string | table', default='*',
       help='comma separated list of meta-data columns to retrieve'},
      {arg='resultNames', type='string | table', default='*',
       help='comma separated list of result columns to retrieve'},
      {arg='orderBy', type='string', default='hexId',
       help='order by this result column'},
      {arg='asc', type='boolean', default=true,
       help='row ordering is ascending. False is descending.'}
   )

   -- select versions
   local verIds
   if verDesc == '' or verDesc == '*' or verDesc == nil then
      verIds = self:fetchVersions()
   elseif minVer then
      verIds = self:fetchVersions(verDesc)
   else
      verIds = {self:getVerId(verDesc)}
   end
   assert(#verIds > 0, "no versions found")

   -- select experiments
   local hexIds = self:fetchExperiments(verIds)
   assert(#hexIds > 0, "no experiments found")

   -- select hyper-param, meta-data and result
   local hp, hpNames = self:getParam(hexIds, paramNames)
   local res, resNames = self:getResult(hexIds, resultNames)
   local md, mdNames = self:getMeta(hexIds, metaNames)

   -- join tables using hexId
   local tbl = {}
   local names = {hpNames, resNames, mdNames}
   local sources = {hp, res, md}

   for idx, hexId in ipairs(hexIds) do
      -- the hexId column is written first so that later values never
      -- need to be shifted (shifting is unreliable on rows with holes)
      local row = {hexId}
      local offset = 1
      local hasValue = false
      -- explicit numeric loop : an experiment may lack hyper-params,
      -- results or meta-data, and the remaining sources must still be
      -- visited with the right column offset
      for s = 1, #names do
         local subtbl = sources[s][hexId]
         if subtbl then
            for k, v in pairs(subtbl) do
               row[k + offset] = v
               hasValue = true
            end
         end
         offset = offset + #names[s]
      end
      -- experiments without any value are not exported
      if hasValue then
         table.insert(tbl, row)
      end
   end

   local colNames = _.flatten(names)
   table.insert(colNames, 1, 'hexId')

   -- orderBy
   if orderBy and orderBy ~= '' then
      local colIdx = _.find(colNames, orderBy)
      assert(colIdx, "unknown orderBy column name")
      _.sort(tbl, function(rowA, rowB)
         local valA, valB = rowA[colIdx], rowB[colIdx]
         -- rows lacking a value for the column are ordered last
         if valA == nil then
            return false
         elseif valB == nil then
            return true
         end
         -- values of incomparable types make the comparison raise an
         -- error, in which case rowA is not ordered before rowB
         local success, rtn = pcall(function()
            if asc then
               return valA < valB
            end
            return valA > valB
         end)
         return success and rtn or false
      end)
   end

   return tbl, colNames
end
--------------------------------------------------------------------------------