├── .gitignore
├── install.sh
├── src
│   ├── _env.lua
│   ├── loaders
│   │   ├── celeba.lua
│   │   ├── hdr.lua
│   │   ├── mnist.lua
│   │   ├── pix2pix.lua
│   │   ├── cifar.lua
│   │   ├── loader.lua
│   │   └── places.lua
│   ├── core
│   │   ├── trainlog.lua
│   │   ├── donkey.lua
│   │   ├── optimizer.lua
│   │   ├── dispatcher.lua
│   │   ├── model.lua
│   │   ├── data.lua
│   │   ├── settings.lua
│   │   └── trainer.lua
│   ├── init.lua
│   ├── models
│   │   ├── lenet5.lua
│   │   ├── squeezenet.lua
│   │   ├── unet.lua
│   │   ├── vgg.lua
│   │   ├── colornet.lua
│   │   ├── dcgan.lua
│   │   └── alexnet.lua
│   └── util
│       ├── plot.lua
│       ├── slurm.lua
│       ├── logger.lua
│       ├── color.lua
│       └── helper.lua
├── CMakeLists.txt
├── dlt-1.0-1.rockspec
├── doc
│   ├── slurm.md
│   ├── model.md
│   ├── misc.md
│   ├── dispatcher.md
│   ├── data.md
│   ├── loader.md
│   └── trainer.md
├── LICENSE
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
build/
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
#!/bin/bash

luarocks make dlt-1.0-1.rockspec
--------------------------------------------------------------------------------
/src/_env.lua:
--------------------------------------------------------------------------------
-- https://github.com/torch/torch7/issues/525

local dlt = {}
return dlt
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
CMAKE_POLICY(VERSION 2.6)
FIND_PACKAGE(Torch REQUIRED)

SET(src)
FILE(GLOB_RECURSE luasrc src/*.lua)
ADD_TORCH_PACKAGE(dlt "${src}" "${luasrc}" "Deep Learning Toolbox")
--------------------------------------------------------------------------------
/src/loaders/celeba.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

local P,parent = torch.class('dlt.CelebA','dlt.Loader',dlt)

-- Data: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
-- Settings
--   s.path
--   s.shuffle (defaults to true)
--   s.assignPoint = function(point,iBatchMember,img)
--   s.type [byte]
function P:__init(s)
    self.name = 'CelebA'
    parent.__init(self,s)
    self.nAll = 202599
    self.path = s.path
    self.type = s.type or 'byte'
    self.sets.testing = nil
    self.sets.validation = nil
end

function P:dataPoint(index,setName)
    local file = paths.concat(self.path,string.format('%06d.jpg',index))
    return image.load(file,nil,self.type)
end

function P:initInstance(setName)
    self.set = self.set or {}
    self.set[setName] = { nPoints = self.nAll }
end
--------------------------------------------------------------------------------
/src/core/trainlog.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

local T,parent = torch.class('Trainlog',dlt)

-- Appends one loss value per line to <savePath>/<name>.log
function T:__init(name,savePath)
    -- Configure arguments
    if name == nil then dlt.log:error('Name not provided for trainlog.') end
    savePath = savePath or paths.concat()
    savePath = dlt.help.checkHomePath(savePath)
    dlt.help.checkMakeDir(savePath)
    -- Formats
    self.numFormat = '%1.4e'
    self.percFormat = '%.1f'
    self.delim = ',\t'

    -- Make file
    self.filename = paths.concat(savePath,name .. '.log')
    local exists = paths.filep(self.filename)
    self.file = io.open(self.filename,'a')
    if not exists then self.file:write('loss\n') end
end

function T:log(loss)
    self.file:write(string.format(self.numFormat,loss) .. '\n')
    self.file:flush()
end
--------------------------------------------------------------------------------
/src/loaders/hdr.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

local H,parent = torch.class('dlt.HDR','dlt.Loader',dlt)

-- s.recursive
function H:__init(s)
    self.name = 'HDR'
    parent.__init(self,s)
    self.recursive = s.recursive
    self.path = s.path
    self.sets.validation = nil
    self.sets.testing = nil
end


function H:dataPoint(index,setName)
    setName = setName or self.currentSet
    return hdrimage.load(self.set[setName].list[index])
end

function H:initInstance(setName)
    self.set = self.set or {}
    setName = setName or self.currentSet
    local list = dlt.help.getFiles(self.path,
                                   {hdr = true,exr = true},
                                   self.recursive)
    self.set[setName] = {
        list = list,
        nPoints = #list
    }
end
--------------------------------------------------------------------------------
/dlt-1.0-1.rockspec:
--------------------------------------------------------------------------------
package = "dlt"
version = "1.0-1"

source = {
   url = "git://github.com/dmarnerides/dlt.git",
   tag = "master"
}

description = {
   summary = "Deep Learning Toolbox for Torch",
   detailed = [[
      This package provides a set of tools to easily create/run/replicate
      deep learning experiments using Torch.
   ]],
   homepage = "https://github.com/dmarnerides/dlt",
   license = "BSD"
}

dependencies = {
   "torch >= 7.0",
   "paths",
   "class",
   "optim",
   "threads"
}

build = {
   type = "command",
   build_command = [[
      cmake -E make_directory build;
      cd build;
      cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)";
      $(MAKE)
   ]],
   install_command = "cd build && $(MAKE) install"
}
--------------------------------------------------------------------------------
/doc/slurm.md:
--------------------------------------------------------------------------------
# Slurm

## Usage
Create scripts for the [slurm scheduler](https://slurm.schedmd.com/).

## Example
File: slurm.lua
```lua
local dlt = require('dlt')
-- Create object (remember, settings are automatically the ones from arg)
local slurm = dlt.Slurm()
local file,script = slurm:createScript('~') -- This will create the script in the home directory
print('Slurm script location: ' .. file)
print('Contents:')
print(script)
```
File: precommands.sh
```bash
# These are the precommands
module load cudnn
```

Possible run:
```bash
## To run 'th something.lua' after running the contents of precommands.sh, requested time 12:05:20
## Request one node with 8 tasks, on fat partition (but not fat2 because it's faulty)
th slurm.lua -sJobname myjob -sTh /path/to/something.lua -sPrecommands /path/to/precommands.sh -sTime 12:05:20 -sPartition fat -sExclude fat2 -sTasks 8 -sNodes 1
```
--------------------------------------------------------------------------------
/src/init.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')
require('torch')
require('paths')
require('class')
require('optim')
require('nn')
threads = require('threads')
dlt.models = {}
dlt.have = {}
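-- Optional dependencies are probed with pcall; availability is recorded
-- in dlt.have so the rest of the toolbox can check before using them.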
for i,mod in pairs({'cutorch','cunn','cudnn','image',
                    'csvigo','hdrimage','gnuplot'}) do
    dlt.have[mod] = pcall(require,mod)
end
if dlt.have.hdrimage then hdrimage = require('hdrimage') end

local modules = {
    'logger',
    'helper',
    'slurm',
    'color',
    'plot',
    'settings',
    'colornet', 'lenet5', 'alexnet',
    'squeezenet', 'unet', 'vgg', 'dcgan',
    'model',
    'donkey',
    'data',
    'loader',
    'pix2pix',
    'places',
    'celeba',
    'cifar',
    'mnist',
    'hdr',
    'optimizer',
    'trainer',
    'trainlog',
    'dispatcher'
}

for i,mod in ipairs(modules) do require('dlt.' .. mod) end
-- Make dlt log
dlt.log = dlt.Logger()
return dlt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 2-Clause License

Copyright (c) 2017, Demetris Marnerides
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/src/core/donkey.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

local D,parent = torch.class('dlt.Donkey',dlt)
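-- Data-loading worker: owns a CPU-side batch buffer which the given loader
-- fills on request. getBatch returns the filled batch together with the
-- time taken to load it; locking and periodic garbage collection are
-- optional (useful when running inside threads).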
function D:__init(loader,pointSize,batchSize,useLocks, garbageCollect,
                  tensorType,set)
    self.loader = loader
    if loader == nil then
        dlt.log:error('No loader provided for donkey.')
    end
    if pointSize == nil then
        dlt.log:error('No pointSize provided for donkey.')
    end
    self.useLocks = useLocks or false
    self.batchSize = batchSize or 1
    self.garbageCollect = garbageCollect or 0
    self.garbageCounter = 0
    tensorType = tensorType or 'float'
    -- Create batch in main memory
    self.batch = dlt.help.createBatch(self.batchSize, pointSize,
                                      tensorType, 'cpu')
    -- Initialize Data
    self.loader:init(set)
    -- Create Timer
    self.timer = torch.Timer()
end

function D:getBatch(iPoint)
    self.timer:reset()
    if self.garbageCollect > 0 then self:collectgarbage() end
    if self.useLocks then mutex:lock() end
    self.loader:assignBatch(self.batch,iPoint,self.batchSize)
    if self.useLocks then mutex:unlock() end
    return self.batch, self.timer:time().real
end

function D:collectgarbage()
    if self.garbageCounter >= self.garbageCollect then
        self.garbageCounter = 0
        collectgarbage()
    else self.garbageCounter = self.garbageCounter + 1 end
end
--------------------------------------------------------------------------------
/doc/model.md:
--------------------------------------------------------------------------------
# Model

## Usage
```lua
net = dlt.Model(create [, name, save])
```
* `create` is a function that returns a network OR the path (string) to a torch serialized network.
* `name` string, defaults to "model".
* `save` boolean, defaults to true. Whether the model will be saved to disk.

## Example

File model.lua :
```lua
local dlt = require('dlt')
-- model creation function
local function myCustomModelFunction()
    return nn.Sequential()
            :add(nn.Linear(10,10))
            :add(nn.Sigmoid())
end

-- Create dlt.Model instance
local net = dlt.Model(myCustomModelFunction,'customModel')

print(net)

-- Some standard functions are provided
net:evaluate()
net:training()
net:zeroGradParameters()

-- Make a test input mini batch of size 4 (useGPU can be useful)
local input = net.useGPU and torch.Tensor(4,10):cuda() or torch.Tensor(4,10)
local gradOutput = input:clone()
local output = net:forward(input)
local gradInput = net:backward(input,gradOutput)

-- Save to file (If net is on the GPU, it is first brought to RAM for saving, then moved back to the GPU automatically)
net:save('customModel.t7')
local output = net:forward(input)
local gradInput = net:backward(input,gradOutput)

-- Load saved model
local savedModel = dlt.Model('customModel.t7','savedModel')
print(savedModel)
print(net) -- net is completely different from savedModel

-- Load model without needing dlt (this will be in RAM)
local noDLTmodel = torch.load('customModel.t7')
print(noDLTmodel)

```

Possible runs:
```bash
# Run on CPU
th model.lua -nGPU 0
# Run on GPU no. 2
th model.lua -nGPU 1 -defGPU 2
# Run on multiple GPUs (DataParallelTable) using cudnn
th model.lua -nGPU 2 -useCudnn true
```
--------------------------------------------------------------------------------
/src/models/lenet5.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

-- Adapted from the torch 60 minute blitz
-- https://github.com/soumith/cvpr2015/blob/master/Deep%20Learning%20with%20Torch.ipynb

function dlt.models.lenet5(w,h,inChannels,nClasses)
    inChannels = inChannels or 1
    nClasses = nClasses or 10
    w = w or 32
    h = h or 32

    local currentW,currentH = w,h
    local net = nn.Sequential()
    -- inChannels input image channels,
    -- 6 output channels, 5x5 convolution kernel
    net:add(nn.SpatialConvolution(inChannels, 6, 5, 5))
    currentW,currentH = dlt.help.SpatialConvolutionSize(currentW,currentH,5,5)
    -- non-linearity
    net:add(nn.ReLU(true))
    -- A max-pooling operation that looks at 2x2 windows and finds the max.
    net:add(nn.SpatialMaxPooling(2,2,2,2))
    currentW,currentH = dlt.help.SpatialMaxPoolingSize(currentW,currentH,
                                                       2,2,2,2)

    net:add(nn.SpatialConvolution(6, 16, 5, 5))
    currentW,currentH = dlt.help.SpatialConvolutionSize(currentW,currentH,5,5)

    -- non-linearity
    net:add(nn.ReLU(true))
    net:add(nn.SpatialMaxPooling(2,2,2,2))
    currentW,currentH = dlt.help.SpatialMaxPoolingSize(currentW,currentH,
                                                       2,2,2,2)

    -- reshapes from a 3D tensor of 16xWxH into a 1D tensor of 16*W*H
    net:add(nn.View(16*currentW*currentH))
    -- fully connected layer (matrix multiplication between input and weights)
    net:add(nn.Linear(16*currentW*currentH, 120))
    -- non-linearity
    net:add(nn.ReLU(true))
    net:add(nn.Linear(120, 84))
    -- non-linearity
    net:add(nn.ReLU(true))
    -- nClasses is the number of outputs of the network (here, 10 digits)
    net:add(nn.Linear(84, nClasses))
    net:add(nn.Sigmoid())
    return net
end
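-- Usage sketch (the defaults give a LeNet-5 for 32x32 single-channel inputs):
--   local net = dlt.models.lenet5(32,32,1,10)
--   local scores = net:forward(torch.Tensor(4,1,32,32)) -- 4x10 scores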
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Learning Toolbox

Easily create and run deep
learning experiments using [Torch](http://torch.ch/) with minimal code.

Initially inspired by
[ImageNet multi-GPU](https://github.com/soumith/imagenet-multiGPU.torch).

Similar frameworks:

* [dp](https://github.com/nicholas-leonard/dp)
* [torchnet](https://github.com/torchnet/torchnet)

## Supports

* Multi-GPU implementation with automatic saving/loading of
  [models](doc/model.md).
* [Data](doc/data.md) iterator and [loaders](doc/loader.md) with
  multi-threading support.
* Multiple types of [training](doc/trainer.md) (simple, GAN, WGAN, BEGAN),
  with automatic checkpointing and logging of training loss.
* Easy [experiment creation](doc/trainer.md) and
  [dispatching](doc/dispatcher.md). Experiments are transferable across
  machines (e.g. can start training on a GPU machine and finish on
  a non-GPU machine).
* [Slurm](doc/slurm.md) scheduler support for usage on HPC facilities.
* Settings parsing, optimizer, logging, colorspaces (XYZ, IPT, LMS, Lαβ).
  More info [here](doc/misc.md).
* Data loader interfaces for *MNIST*, *CIFAR*, *CelebA*, *Places*, *pix2pix*.
* Implementations of some (standard) models, including *LeNet5*, *VGG*,
  *AlexNet*, *Squeezenet*, *Colornet*, *UNET*.

## Installation

Make sure you have [Torch](http://torch.ch/) installed.

To install use:
```bash
git clone https://github.com/dmarnerides/dlt.git
cd dlt
./install.sh
```

## Warning / Disclaimer

I created this toolbox for my PhD, mostly to learn Lua, understand Torch in
depth, and have a consistent workflow across multiple machines and HPC
facilities.

Only tested on Ubuntu and CentOS.

**If you use this package you will probably encounter bugs.
If so please let me know!**

**Use at your own risk.**

If you use this code please cite the repo.

## Contact

dmarnerides@gmail.com
--------------------------------------------------------------------------------
/src/loaders/mnist.lua:
--------------------------------------------------------------------------------

local dlt = require('dlt._env')
local M,parent = torch.class('dlt.Mnist','dlt.Loader',dlt)

-- Loads images and labels for training and validation sets in memory
-- (byte [0-255]).
--   s.path
--   s.assignPoint = function(point,iBatchMember,img,cls)
--   s.shuffle (defaults to true)
--   s.transform is a function of the whole dataset
--     (since it's small and loaded once in memory)
--   s.download [true] downloads from
--     https://s3.amazonaws.com/torch7/data/mnist.t7.tgz
function M:__init(s)
    self.name = 'MNIST'
    parent.__init(self,s)
    local train = paths.concat(s.path, 'train_32x32.t7')
    local val = paths.concat(s.path, 'test_32x32.t7')
    self.path = {training = train, validation = val}
    if s.download == nil then s.download = true end
    if s.download then
        if not paths.filep(self.path.training)
            or not paths.filep(self.path.validation) then
            os.execute('wget https://s3.amazonaws.com/torch7/data/mnist.t7.tgz'
                        .. ' --directory-prefix=' .. s.path)
            os.execute('cd ' .. s.path .. '\n tar -xvzf mnist.t7.tgz \n '
                        .. ' mv mnist.t7/* . \n rm -r mnist.t7*' )
        end
    end
    if not paths.filep(self.path.training) then
        dlt.log:error('Could not find ' .. self.path.training)
    end
    if not paths.filep(self.path.validation) then
        dlt.log:error('Could not find ' .. self.path.validation)
    end
    -- Transformation
    self.transform = s.transform or function(imagesTensor)
        return imagesTensor
    end
    -- Internals
    self.sets.testing = nil
end

function M:initInstance(setName)
    local f = torch.load(self.path[setName], 'ascii')
    self.set = self.set or {}
    self.set[setName] = {
        images = self.transform(f.data),
        labels = f.labels,
        nPoints = f.labels:size(1)
    }
end

function M:dataPoint(index,setName)
    return self.set[setName].images[index],
           self.set[setName].labels[index]
end
--------------------------------------------------------------------------------
/doc/misc.md:
--------------------------------------------------------------------------------
## Settings

```lua
[s] dlt.parse([out,extra])
```
* Parses arguments (arg) and returns a table of settings for dlt.
* `out` is a table to place all settings in (useful for objects to pass self).
* `extra` is a table of extra settings that need parsing.

For the full list of available settings run:
```bash
th -e "require('dlt').parse()"
```

### Example
File: settings.lua
```lua
local dlt = require('dlt')

local myTable = { notAnArgSetting = 'yes' }

local s = dlt.parse(myTable,{ {'-check', 'false', 'Boolean setting.'},
                              {'-myNumber', 0, 'Number setting.'},
                              {'-myMessage', 'message', 'String setting.'}})


print(torch.type(s.check)) -- boolean

print('s --> ' .. s.myMessage .. ': ' .. s.myNumber)
print(s.notAnArgSetting) -- nil

print('myTable --> ' .. s.myMessage .. ': ' .. s.myNumber)
-- myTable still has all other settings (unless overwritten due to a name clash)
print(myTable.notAnArgSetting) -- yes

-- Also parses dlt settings
print(s.nGPU)
print(myTable.batchSize)
```
Example run:
```bash
th settings.lua -myMessage "My Number is" -myNumber 42
```

## Optimizer

A thin wrapper around [optim](https://github.com/torch/optim).

Its main functionality is consistent saving of state when used in conjunction with [dlt.Trainer](trainer.md) (conversion of Tensors to/from GPU).

The default optimizer is [Adam](https://github.com/torch/optim/blob/master/adam.lua).
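
For instance (a sketch based on the constructor in src/core/optimizer.lua; normally a Trainer manages this object for you):
```lua
local dlt = require('dlt')
-- name/config/hook are all optional; the hook can adjust the state per epoch
local optimizer = dlt.Optimizer({
    name = 'sgd',
    config = { learningRate = 0.01 },
    hook = function(epoch,loss,current) return current end
})
-- feval(params) should return loss, gradParams (as in optim)
-- optimizer:step(feval,params)
```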

## Logger

When *dlt* is loaded, a logger is created as `dlt.log`:
* Has 6 verbose levels:
    * [1-6]: error, warning, yell, say, detail, debug
* Use `dlt.log:say('Something')`
* If `dlt.log:error('message')` is used, execution is terminated
* Set the level with `dlt.log:setLevel(level)`
* Prints output in terminal friendly boxes
* Can create sections:
    * `dlt.log:section('My Section')`, `dlt.log:endSection()`
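
A minimal sketch using the calls listed above:
```lua
local dlt = require('dlt')
dlt.log:setLevel(5) -- up to 'detail'
dlt.log:section('Setup')
dlt.log:say('Loading data.')
dlt.log:warning('Validation set is empty.')
dlt.log:endSection()
-- dlt.log:error('Fatal.') -- would terminate execution
```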

## Colorspaces

Supports colorspace conversions (not the ones from [image](https://github.com/torch/image)) for XYZ, IPT, LMS, Lαβ.

Example:
```lua
local dlt = require('dlt')
local img = image.load('image.jpg')
iptImage = dlt.color.rgb2ipt(img)
```
--------------------------------------------------------------------------------
/src/models/squeezenet.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

-- Squeezenet paper: http://arxiv.org/abs/1602.07360

-- Adapted from
-- https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/squeezenet.lua

local function bypass(net)
    local cat = nn.ConcatTable():add(net):add(nn.Identity())
    local ret = nn.Sequential():add(cat):add(nn.CAddTable(true))
    return ret
end

local function fire(inChannels, midChannels, outChannels1, outChannels2)
    local net = nn.Sequential()
                    :add(nn.SpatialConvolution(inChannels, midChannels,
                                               1, 1))
                    :add(nn.ReLU(true))
    local exp = nn.Concat(2)
                    :add(nn.SpatialConvolution(midChannels, outChannels1,
                                               1, 1))
                    :add(nn.SpatialConvolution(midChannels, outChannels2,
                                               3, 3, 1, 1, 1, 1))
    return net:add(exp):add(nn.ReLU(true))
end

function dlt.models.squeezenet(nClasses,LogSoftMax)
    nClasses = nClasses or 1000
    local net = nn.Sequential()
    -- conv1
    net:add(nn.SpatialConvolution(3, 96, 7, 7, 2, 2, 0, 0))
    net:add(nn.ReLU(true))
    net:add(nn.SpatialMaxPooling(3, 3, 2, 2))
    -- fire2
    net:add(fire(96, 16, 64, 64))
    -- fire3
    net:add(bypass(fire(128, 16, 64, 64)))
    -- fire4
    net:add(fire(128, 32, 128, 128))
    net:add(nn.SpatialMaxPooling(3, 3, 2, 2))
    -- fire5
    net:add(bypass(fire(256, 32, 128, 128)))
    -- fire6
    net:add(fire(256, 48, 192, 192))
    -- fire7
    net:add(bypass(fire(384, 48, 192, 192)))
    -- fire8
    net:add(fire(384, 64, 256, 256))
    net:add(nn.SpatialMaxPooling(3, 3, 2, 2))
    -- fire9
    net:add(bypass(fire(512, 64, 256, 256)))
    net:add(nn.Dropout())
    -- conv10
    net:add(nn.SpatialConvolution(512, nClasses, 1, 1, 1, 1, 1, 1))
    net:add(nn.ReLU(true))
    net:add(nn.SpatialAveragePooling(14, 14, 1, 1))
    net:add(nn.View(nClasses))
    if LogSoftMax then
        net:add(nn.LogSoftMax())
    else
        net:add(nn.Sigmoid())
    end
    return net
end
--------------------------------------------------------------------------------
/src/loaders/pix2pix.lua:
--------------------------------------------------------------------------------
-- All data from https://github.com/phillipi/pix2pix
local dlt = require('dlt._env')
local P,parent = torch.class('dlt.Pix2pix','dlt.Loader',dlt)

-- Loads images and labels for training and validation sets in memory
-- (byte [0-255]).
------ Settings
-- Required:
--   s.path
-- Optional
--   s.assignPoint = function(point,iBatchMember,img,cls)
--   s.shuffle (defaults to true)
--   s.name (cityscapes,maps,facades,edges2handbags,edges2shoes)
--     defaults to cityscapes
function P:__init(s)
    if dlt.help.inTable({'cityscapes','maps','facades',
                         'edges2handbags','edges2shoes'},s.name) then
        self.name = s.name
    elseif s.name == nil then
        self.name = 'cityscapes'
    else dlt.log:error('Unknown pix2pix name: ' .. s.name) end
    parent.__init(self,s)
    s.path = paths.concat(s.path,self.name)
    if not paths.dirp(s.path) then
        dlt.log:error('Could not find: ' .. s.path)
    end
    -- Default to byte, matching the doc comment above
    self.type = s.type or 'byte'

    -- Internals
    self.sizes = {
        cityscapes = {training = 2975, validation = 500},
        maps = {training = 1096, validation = 1098},
        facades = {training = 400, validation = 100, testing = 106},
        edges2handbags = {training = 138567, validation = 200},
        edges2shoes = {training = 49825, validation = 200}
    }

    if self.name ~= 'facades' then self.sets.testing = nil end
    -- paths
    self.path = {
        training = self.sets.training and paths.concat(s.path,'train')
                   or nil,
        validation = self.sets.validation and paths.concat(s.path,'val')
                     or nil,
        testing = self.sets.testing and paths.concat(s.path,'test')
                  or nil
    }

    if self.name == 'edges2handbags' or self.name == 'edges2shoes' then
        self.format = '%d_AB.jpg'
    else
        self.format = '%d.jpg'
    end
end

function P:dataPoint(index,setName)
    local file = paths.concat(self.path[setName],
                              string.format(self.format,index))
    return image.load(file,3,self.type)
end

function P:initInstance(setName)
    self.set = self.set or {}
    self.set[setName] = {nPoints = self.sizes[self.name][setName]}
end
--------------------------------------------------------------------------------
/src/core/optimizer.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

local O,parent = torch.class('Optimizer',dlt)

-- Settings
-- [opt] = { [name = 'sgd', config = {}, hook = function(epoch,loss,current) ]}
-- [tensorType = 'float']
-- [useGPU = false]
-- [optimFile = nil]
function O:__init(opt,tensorType,useGPU,optimFile)
    opt = opt or {}

    self.tensorType = tensorType or 'float'
    self.useGPU = useGPU or false

    -- Set up optimizer, defaults to adam
    if opt.name and optim[opt.name] == nil then
        dlt.log:error('Unknown optim type ' .. opt.name)
    end
    self.optim = opt.name and optim[opt.name] or optim['adam']
    -- Get optimizer state
    self.optimState = optimFile and torch.load(optimFile) or opt.config
    -- Make it an empty table if opt.config is nil
    self.optimState = self.optimState or {}
    -- Hook for updating the optimizer state
    self.optimHook = opt.hook or function(epoch,loss,current)
        return current
    end
    -- Transfer state to gpu
    -- (could have been loaded from checkpoint, saved on gpu)
    if self.useGPU then
        self:gpu()
    end

end

function O:cpu()
    nn.utils.recursiveType(self.optimState,
                'torch.' .. dlt.help.tensorList.cpu[self.tensorType] )
end

function O:gpu()
    nn.utils.recursiveType(self.optimState,
                'torch.' .. dlt.help.tensorList.gpu[self.tensorType] )
end

function O:updateState(epoch,loss)
    self.optimState = self.optimHook(epoch,loss,self.optimState)
end

function O:save(filename)
    -- To save, first transfer to cpu and then recast
    if self.useGPU then
        self:cpu()
    end
    torch.save(filename, self.optimState)
    if self.useGPU then
        self:gpu()
    end
end

-- Sets all numbers/tensors in optimizer state to defaults or zero
function O:resetState(defaults)
    for key,val in pairs(self.optimState) do
        if defaults[key] then self.optimState[key] = defaults[key]
        else
            if torch.type(self.optimState[key]) == 'number' then
                self.optimState[key] = 0
            else
                if self.optimState[key].zero then
                    self.optimState[key]:zero()
                end
            end
        end
    end
end

function O:step(f,param)
    self.optim(f,param,self.optimState)
end
--------------------------------------------------------------------------------
/doc/dispatcher.md:
--------------------------------------------------------------------------------
# Dispatcher

## Usage
```lua
dispatcher = dlt.Dispatcher(experimentFunction [,extras])
```
* `experimentFunction` is a function that runs an experiment.
* `extras` extra settings to parse. Must be provided if extra arguments were parsed already.

Useful for creating self-contained directories of a pre-configured experiment (with a slurm scheduler script ready to submit, if required).

You must provide the `experimentName` and `runRoot` settings. The experiment will be created/run in *runRoot/experimentName*.

## Example

File dispatch.lua
```lua
local dlt = require('dlt')
-- Can easily add extra settings to parse
local extras = {{'-localRun','true','Whether we run or create slurm script'},
                {'-dataPath','none','Path for data'}}

-- Get settings to use in closure
local s = dlt.parse(nil,extras)

-- Dispatcher needs a function that runs an experiment (could be doing anything really)
local function experiment()
    -- MUST get local reference to dlt
    local dlt = require('dlt')
    torch.setdefaulttensortype('torch.FloatTensor')
    -- Make experiment table
    local exp = { model = { create = dlt.models.lenet5 },
                  loader = dlt.Mnist{ path = s.dataPath,
                      transform = function(images) return images:float():div(255) end,
                      assignPoint = function(batch,i,img,cls)
                          batch.input[i]:copy(img)
                          batch.output[i] = cls
                      end
                  },
                  pointSize = {input = {1,32,32}, output = {}},
                  criterion = nn.CrossEntropyCriterion()
    }
    -- Run trainer with given experiment
    dlt.Trainer(exp):run()
end

-- Create the dispatcher (MUST pass the extra settings in this case)
local dispatcher = dlt.Dispatcher(experiment,extras)
-- Run on local machine or make slurm script if we are on HPC machine
if s.localRun then
    -- will add this string before running the script
    dispatcher:run('export THC_CACHING_ALLOCATOR=1')
    -- If we need to use qlua (e.g. for image.display()) then use
    -- dispatcher:run('export THC_CACHING_ALLOCATOR=1',true)
else
    -- Will only make the slurm script in 'runRoot/experimentName/job'
    -- Remember to pass slurm script arguments when invoking this
    dispatcher:makeSlurm()
end
```

Possible run:
```bash
th dispatch.lua -runRoot ~/results -experimentName dispatcherTest -nGPU 1 -defGPU 2 -dataPath ~/data/mnist -batchSize 1000
```
--------------------------------------------------------------------------------
/src/models/unet.lua:
--------------------------------------------------------------------------------
-- https://arxiv.org/pdf/1505.04597.pdf
-- https://arxiv.org/pdf/1611.07004v1.pdf

local dlt = require('dlt._env')

function dlt.models.unet(layers,nInput,nOutput,nDrop,noPadding,noTanh,useSbn)
    if useSbn == nil then useSbn = true end
    local pad = noPadding and 0 or 1
    if not layers then
        dlt.log:error('Unet requires specification of layers.')
    end
    nOutput = nOutput or 3
    nInput = nInput or 3
    nDrop = nDrop or 3
    local function LSS(nIn,nOut,sbn)
        if sbn == nil then sbn = true end
        local ret = nn.Sequential()
        ret:add(nn.LeakyReLU(0.2,true))
           :add(nn.SpatialConvolution(nIn, nOut, 4, 4, 2, 2, pad, pad))
        if sbn then ret:add(nn.SpatialBatchNormalization(nOut)) end
        return ret
    end
    local dropCount = 1

    local function RFS(nIn,nOut,sbn)
        if sbn == nil then sbn = true end
        local ret = nn.Sequential()
        ret:add(nn.ReLU(true))
           :add(nn.SpatialFullConvolution(nIn, nOut, 4, 4, 2, 2, pad, pad))
        if sbn then ret:add(nn.SpatialBatchNormalization(nOut)) end
        if dropCount <= nDrop then
            ret:add(nn.Dropout(0.5))
            dropCount = dropCount + 1
        end
        return ret
    end
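
    -- Recursively builds the symmetric encoder-decoder: each level
    -- downsamples (LSS), recurses into the next level, upsamples (RFS),
    -- and joins the result with its own input (the U-Net skip connection).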
    local function recurse(current,next,...)
        if #{...} == 0 then
            return nn.Sequential()
                    :add(nn.ConcatTable()
                        :add( nn.Sequential()
                                :add(LSS(current,next,false))
                                :add(RFS(next,current))
                        )
                        :add(nn.Identity())
                    )
                    :add(nn.JoinTable(1,3))
        else
            return nn.Sequential()
                    :add(nn.ConcatTable()
                        :add( nn.Sequential()
                                :add(LSS(current,next))
                                :add(recurse(next,...))
                                :add(RFS(next*2,current))
                        )
                        :add(nn.Identity())
                    )
                    :add(nn.JoinTable(1,3))
        end
    end

    local model = nn.Sequential()
        :add(nn.SpatialConvolution(nInput, layers[1], 4, 4, 2, 2, pad, pad))
        :add(recurse(unpack(layers)))
        :add(nn.ReLU(true))
        :add(nn.SpatialFullConvolution(layers[1]*2, nOutput, 4, 4, 2, 2,
                                       pad, pad))
    if not noTanh then model:add(nn.Tanh(true)) end

    return model
end
--------------------------------------------------------------------------------
/src/loaders/cifar.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')
local C,parent = torch.class('dlt.Cifar','dlt.Loader',dlt)

-- Loads images and labels for training and validation sets in memory
-- (byte [0-255]).
------ Settings
-- Required:
--   s.path (Must contain cifar10-train.t7, cifar10-test.t7
--     or cifar100-train.t7, cifar100-test.t7)
-- Optional
--   s.assignPoint = function(point,iBatchMember,img,cls)
--   s.shuffle (defaults to true)
--   s.transform is a function of the whole dataset
--     (since it's small and loaded once in memory)
--   s.name (100, 10) defaults to 10
--   s.download [true] if true and the path does not contain the datasets then
--     https://github.com/soumith/cifar.torch is used to get the data

local function download(name,path)
    local fname = name .. 'BinToTensor.lua'
    local code =
        'https://raw.githubusercontent.com/soumith/cifar.torch/master/'
        .. fname
    os.execute('wget ' .. code .. ' --directory-prefix=' .. path)
    os.execute('cd ' .. path .. '\n th ' .. fname)
end

function C:__init(s)

    if dlt.help.inTable({'10',10,'cifar10','Cifar10'},s.name)
        or s.name == nil then
        self.name = 'cifar10'
    elseif dlt.help.inTable({'100',100,'cifar100','Cifar100'},s.name) then
        self.name = 'cifar100'
    else dlt.log:error('Unknown cifar name: ' .. s.name) end
    parent.__init(self,s)
    local train = paths.concat(s.path, self.name .. '-train.t7')
    local val = paths.concat(s.path, self.name .. '-test.t7')
    self.path = {training = train, validation = val}

    if s.download and (not paths.filep(self.path.training)
                       or not paths.filep(self.path.validation)) then
        download('C' .. self.name:sub(2,-1),s.path)
    end
    if not paths.filep(self.path.training) then
        dlt.log:error('Could not find ' .. self.path.training)
    end
    if not paths.filep(self.path.validation) then
        dlt.log:error('Could not find ' .. self.path.validation)
    end
    -- Transformation
    self.transform = s.transform or function(imagesTensor)
        return imagesTensor
    end
    -- Internals
    self.sets.testing = nil
end

function C:initInstance(setName)
    local f = torch.load(self.path[setName])
    self.set = self.set or {}
    self.set[setName] = {
        images = self.transform(f.data),
        labels = f.label,
        nPoints = f.label:size(1)
    }
end

function C:dataPoint(index,setName)
    return self.set[setName].images[index],
           self.set[setName].labels[index]
end
--------------------------------------------------------------------------------
/src/models/vgg.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

-- Original paper
-- https://arxiv.org/pdf/1409.1556.pdf
-- Adapted from
-- https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/vggbn.lua

function dlt.models.vgg(modelType, nClasses,bn,dropout,w,h)
    modelType = modelType or 'A'
    nClasses = nClasses or 205
    -- 'x or true' would ignore an explicit false, so check for nil
    if bn == nil then bn = true end
    if dropout == nil then dropout = true end
    w = w or 224
    h = h or 224

    -- Create tables describing VGG configurations A, B, D, E
    local cfg = {}
    if modelType == 'A' then
        cfg = {64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'}
    elseif modelType == 'B' then
        cfg = {64, 64, 'M', 128, 128, 'M', 256, 256, 'M',
               512, 512, 'M', 512, 512, 'M'}
    elseif modelType == 'D' then
        cfg = {64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512,
               'M', 512, 512, 512, 'M'}
    elseif modelType == 'E' then
        cfg = {64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512,
               512, 512, 'M', 512, 512, 512, 512, 'M'}
    else
        dlt.log:error('Unknown model type for VGG : ' .. modelType ..
                      '. Available types: [A,B,D,E]')
    end

    local currentW,currentH = w,h
    local features = nn.Sequential()
    do
        local iChannels = 3;
        for k,v in ipairs(cfg) do
            if v == 'M' then
                features:add(nn.SpatialMaxPooling(2,2,2,2))
                currentW,currentH = dlt.help.SpatialMaxPoolingSize(currentW,
                                                        currentH,2,2,2,2)
            else
                local oChannels = v;
                local conv3 = nn.SpatialConvolution(iChannels,oChannels,3,3,
                                                    1,1,1,1);
                currentW,currentH = dlt.help.SpatialConvolutionSize(currentW,
                                                        currentH,3,3,1,1,1,1)
                features:add(conv3)
                features:add(nn.ReLU(true))
                iChannels = oChannels;
            end
        end
    end


    local classifier = nn.Sequential()
    classifier:add(nn.View(512*currentW*currentH))
    classifier:add(nn.Linear(512*currentW*currentH, 4096))
    classifier:add(nn.ReLU(true))
    if bn then classifier:add(nn.BatchNormalization(4096, 1e-4)) end
    if dropout then classifier:add(nn.Dropout(0.5)) end
    classifier:add(nn.Linear(4096, 4096))
    classifier:add(nn.ReLU(true))
    if bn then classifier:add(nn.BatchNormalization(4096, 1e-4)) end
    if dropout then classifier:add(nn.Dropout(0.5)) end
    classifier:add(nn.Linear(4096, nClasses))
    classifier:add(nn.LogSoftMax())

    local model = nn.Sequential()
    model:add(features):add(classifier)

    return model
end
--------------------------------------------------------------------------------
/src/loaders/loader.lua:
--------------------------------------------------------------------------------
-- Abstract loader class
local dlt = require('dlt._env')
local M = torch.class('dlt.Loader',dlt)

function M:__init(s)
    -- Defaults
    self.assignPoint = s.assignPoint or function() end
    self.shuffle = s.shuffle == nil and true or s.shuffle
    if torch.type(self.shuffle) ~= 'boolean' then
        dlt.log:warning('shuffle must be boolean, setting to default (true).')
        self.shuffle = true
    end
    -- Paths
    self.name = self.name or 'data'
    if not s.path then
        dlt.log:error('Path not provided for ' .. self.name .. ' loader.')
    end
    s.path = dlt.help.checkHomePath(s.path)
    s.path = paths.concat(s.path)
    if not paths.dirp(s.path) then
        dlt.log:error('Path provided for ' .. self.name ..
                      ' loader does not exist. ' .. s.path)
    end
    -- Internals
    self.shuffler = {}
    self.currentSet = 'training'
    self.sets = {training = true, validation = true, testing = true}

end

function M:transformIndex(index)
    local np = self.set[self.currentSet].nPoints
    return self.shuffler[self.currentSet][(index - 1) % np + 1]
end

function M:size(set)
    set = set or self.currentSet
    return self.set[set].nPoints
end

function M:mode(setName)
    if self.sets[setName] then
        self.currentSet = setName
    else
        dlt.log:warning('Unknown mode: ' .. setName ..'. Keeping: '
                        .. self.currentSet .. '.')
    end
end
function M:reshuffle()
    self.shuffler[self.currentSet] = self.shuffle
        and torch.randperm(self.set[self.currentSet].nPoints):long()
        or self.shuffler[self.currentSet]
end
function M:get(index)
    return self:dataPoint(self:transformIndex(index),self.currentSet)
end

function M:assignBatch(batch,iDataPoint,nPoints)
    for i = iDataPoint,iDataPoint+nPoints - 1 do
        local iPnt = (i - 1) % self.set[self.currentSet].nPoints + 1
        self.assignPoint(batch,i - iDataPoint + 1,self:get(iPnt))
    end
end

local function initHelp(self,setName)
    if not self.initialized[setName] then
        self:initInstance(setName)
        if self.shuffle then
            local np = self.set[setName].nPoints
            self.shuffler[setName] = torch.randperm(np):int()
        else
            self.shuffler[setName] = {}
            setmetatable(self.shuffler[setName],
                         {__index = function(_,key) return key end})
        end
        self.initialized[setName] = true
    end
end

function M:init(set)
    self.initialized = self.initialized or {}
    if set then
        initHelp(self,set)
        self:mode(set)
    else
        for setName,_ in pairs(self.sets) do
            initHelp(self,setName)
        end
        self:mode('training')
    end
    return self
end
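
-- Subclass contract (a sketch following mnist.lua/celeba.lua; MyLoader is
-- hypothetical):
--   function MyLoader:initInstance(setName)
--       self.set = self.set or {}
--       self.set[setName] = { nPoints = ... } -- plus any data handles
--   end
--   function MyLoader:dataPoint(index,setName)
--       return img, cls -- values forwarded to s.assignPoint
--   end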
--------------------------------------------------------------------------------
/src/core/dispatcher.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')
local D,parent = torch.class('dlt.Dispatcher',dlt)

function D:__init(experimentFunction,extras)
    -- Store experiment Function and parsed settings
    -- We do not store the actual experiment table
    -- but the function that returns it
    self.experiment = experimentFunction
    self.configuration = dlt.parse(nil,extras)
    -- First Make paths
    if not self.configuration.runRoot then
        dlt.log:error('Dispatcher: runRoot not given.')
    end
    if not self.configuration.experimentName then
        dlt.log:error('Dispatcher: experimentName not given.')
    end
    self.configuration.runRoot =
        dlt.help.checkHomePath(self.configuration.runRoot)
    dlt.help.checkMakeDir(self.configuration.runRoot,'runRoot')
    self.configuration.runPath = paths.concat(self.configuration.runRoot,
                                        self.configuration.experimentName)
    dlt.help.checkMakeDir(self.configuration.runPath,'runPath')
    self.configuration.savePath = paths.concat(self.configuration.runPath,
                                               self.configuration.savePath)
    dlt.help.checkMakeDir(self.configuration.savePath)

    -- Create lua script that runs dispatcher inside run directory
    dlt.writeSettingsToFile(self.configuration,
                            paths.concat(self.configuration.runPath,
                                         'settings.txt'))
    torch.save(paths.concat(self.configuration.runPath,'dispatcher.t7'),self)
    local toRun =
[[
local dlt = require('dlt')
local dispatcher = torch.load('dispatcher.t7')
dispatcher:runLoadedDispatcher()
]]
    local luaScriptPath = paths.concat(self.configuration.runPath,'script.lua')
    local file = io.open(luaScriptPath, 'w+')
    file:write(toRun)
    file:close()

end

function D:__call__(...) self:run(...) end

-- Do not use this, use D:run() if dispatcher was created
-- and not loaded in script
function D:runLoadedDispatcher()
    -- Create settings
    -- We configure here because dispatcher may be created
    -- on a machine without a gpu
    -- e.g. the login nodes of an HPC facility
    -- Set global configuration
    dlt.settings = self.configuration
    self.configuration = dlt.configure(self.configuration)
    -- Run experiment
    -- Default globals, tensor types, threads, seeds etc
    -- should be set at the beginning of the experiment function
    self.experiment()
end

function D:makeSlurm()
    self.batchScript = dlt.Slurm():createScript(self.configuration.runPath)
end

function D:run(preCommands,qlua)
    local launch = qlua and 'qlua' or 'th'
    os.execute('cd ' .. self.configuration.runPath .. '\n' .. preCommands ..
               '\n ' .. launch .. ' script.lua')
end

-- SLURM SUBMIT NEEDS FIXING
-- function D:submitSlurm()
--     slurm:submit(self.batchScript)
-- end
--------------------------------------------------------------------------------
/doc/data.md:
--------------------------------------------------------------------------------
# Data

## Usage
```lua
data = dlt.Data( loader, pointSize [, datasets, currentEpoch] )
```
* `loader` A data [loader](loader.md).
* `pointSize` Table of point elements and their sizes.
* `datasets` Table of datasets to use from loader.
* `currentEpoch` Useful for checkpointing and resuming runs.
* Main functionality is `data:iterate(callbacks)`.

## Example

File data.lua :
```lua
local dlt = require('dlt')
-- Use mnist loader as example
local mnist = dlt.Mnist{path = '~/data/mnist', shuffle = false,
    assignPoint = function(batch,i,img,cls)
        batch.img[i]:copy(img)
        batch.cls[i] = cls
    end,
    -- Mnist loads as ByteTensors, so we can use transform to convert all images to [0,1] floats.
    transform = function(images) return images:float():div(255) end}
-- input is 32x32 image with 1 channel.
-- class is dimensionless (might need to use {1} instead of {} depending on criterion)
local pointSize = {img = {1,32,32}, cls = {}}

-- Create data iterator for training and validation
local data = dlt.Data( mnist, pointSize, {'training','validation'})

-- Make closure variables
local trainClassSum, valClassSum = 0,0
local trainCount, valCount = 0,0
local didCheckpoint = false
local batchSize, batchType
-- Iterate datasets with checkpointing and termination conditions
data:iterate{
    training = function(batch)
        if didCheckpoint then return true, 'Did Checkpoint (Training)!' end -- return a termination statement
        -- Here we have access to the batch loaded from the dataset
        trainClassSum = trainClassSum + batch.cls:sum()
        trainCount = trainCount + batch.cls:nElement()
        batchSize = batch.img:size(1)
        batchType = torch.type(batch.img)
    end,
    validation = function(batch)
        if didCheckpoint then return true, 'Did Checkpoint (Validation)!' end -- return a termination statement
        valClassSum = valClassSum + batch.cls:sum()
        valCount = valCount + batch.cls:nElement()
    end,
    checkpoint = function() -- This is called at the end of EVERY iteration
        -- Stop at the first validation step
        if valCount > 0 then didCheckpoint = true end
    end
}
dlt.log:section('Results')
dlt.log:yell(string.format('Training Class Average for %d points: %.2f',trainCount,trainClassSum/trainCount ) )
dlt.log:yell(string.format('Validation Class Average for %d points: %.2f',valCount,valClassSum/valCount ) )
dlt.log:yell(string.format('Batch size %d, type %s',batchSize,batchType))
dlt.log:endSection()
```


Possible runs:
```bash
# Run on CPU only on master thread with batchSize of 16
th data.lua -nGPU 0 -nThreads 0 -batchSize 16
# Run on GPU no. 2 (callbacks batch will be on GPU 2) with 4 threads (loading of data)
# Note that for mnist this will not make much difference, (might actually be slower)
th data.lua -nGPU 1 -defGPU 2 -nThreads 4
# Use verbose 5 to get timings printed to console
# Use a batch with double precision
th data.lua -verbose 5 -tensorType double
```
--------------------------------------------------------------------------------
/src/util/plot.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')
dlt.plot = {}
local P = dlt.plot

local inputFile = dlt.help.checkHomePath('~/results/itmpixel2/save/training.log')

-- Loads csv file and returns an array of tables each containing the key/name
-- of the column and a 1D FloatTensor with the data.
function P.loadToTensors(file)
    local loaded = csvigo.load({path = file , verbose = false})
    local ret = {}
    for key,val in pairs(loaded) do
        ret[#ret + 1] = { data = torch.FloatTensor(dlt.help.apply(val,tonumber)),
                          key = key}
    end
    return ret
end

function P.saveToFile(t,fileName)
    fileName = fileName or paths.tmpname()
    local f = io.open(fileName,'w')
    for i=1,t:size(1) do
        f:write(string.format('%1.4e\n',t[i]))
    end
    f:close()
    return fileName
end

-- Creates a simple EPS plot of a (transformed) log file column using gnuplot
function P.createEPS(filename,callback,epsName,title,ylabel,xlabel)
    epsName = epsName or 'plot.eps'
    callback = callback or function(t) return t[1].data end
    local outFile = P.saveToFile(callback(P.loadToTensors(filename)))
    local command = string.format( [[
        set terminal postscript eps enhanced size 10in,7in
        set nokey
        set output "%s"
        set title "%s"
        set ylabel "%s"
        set xlabel "%s"
        plot "%s" using 1 w l
    ]],epsName, title, ylabel, xlabel, outFile)

    gnuplot.raw( command)
    return epsName
end

P.func = {}
local F = P.func

-- e.g. funcs = {{f1,f1paramtable}, {f2,f2paramtable} }
function F.compose(funcs)
    return function(t)
        for i,f in ipairs(funcs) do
            if type(f[2]) ~= 'table' then
                t = f[1](t,f[2])
            else
                t = f[1](t,unpack(f[2]))
            end
        end
        return t
    end

end

function F.getColumn(t,i)
    return t[i].data
end
function F.movingAvg(t,size)
    local size = size or 10
    local ret = torch.FloatTensor(t:size(1) - size + 1)
    local count = 0
    ret:apply(function()
        count = count + 1
        return t[{{count, count + size - 1}}]:mean()
    end )
    return ret
end
function F.avg(t,size)
    local size = size or 10
    local ret = torch.FloatTensor(math.floor(t:size(1)/size))
    local count = 0
    ret:apply(function()
        count = count + 1
        return t[{{(count-1)*size + 1, count*size}}]:mean()
    end )
    return ret
end
function F.movingVar(t,size)
    local size = size or 10
    local ret = torch.FloatTensor(t:size(1) - size + 1)
    local count = 0
    ret:apply(function()
        count = count + 1
        return t[{{count, count + size - 1}}]:var()
    end )
    return ret
end
function F.var(t,size)
    local size = size or 10
    local ret = torch.FloatTensor(math.floor(t:size(1)/size))
    local count = 0
    ret:apply(function()
        count = count + 1
        return t[{{(count-1)*size + 1, count*size}}]:var()
    end )
    return ret
end

function F.getTail(t,n)
    local length = t:size(1)
    return t[{{length - n + 1,length}}]
end
function F.removeTail(t,n)
    local length = t:size(1)
    return t[{{1,length - n}}]
end
function F.getHead(t,n)
    return t[{{1,n}}]
end
function F.removeHead(t,n)
    local length = t:size(1)
    return t[{{1+n,length}}]
end
--------------------------------------------------------------------------------
/src/core/model.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

local M,parent = torch.class('dlt.Model',dlt)


function M:__init(modelCreate,name,save)
    self.name = name or 'model'
    self.shouldSave = (save == nil) and true or save
    if torch.type(self.name) ~= 'string' then
        dlt.log:error('Model name must be a string.')
    end
    self.name = self.name:gsub("^%l", string.upper)
    dlt.log:section(self.name .. ' initialization')
    dlt.parse(self)
    dlt.configure(self)
    if torch.type(modelCreate) == 'string' then
        modelCreate = dlt.help.checkHomePath(modelCreate)
        if not paths.filep(modelCreate) then
            dlt.log:error('Could not find model ' .. modelCreate)
        end
        self.model = torch.load(modelCreate)
        dlt.log:yell('Loaded ' .. self.name .. ' from file.');
    elseif torch.type(modelCreate) == 'function' then
        self.model = modelCreate()
        dlt.log:yell('Created ' .. self.name .. '.')
    else
        dlt.log:error('dlt.Model parameter (modelCreate)' ..
                      ' must be a string or a function.')
    end

    self:processModel()

    if self.parameters:size():size() ~= 0 then
        dlt.log:yell(string.format('%s parameters: %d.',
                                   self.name, self.parameters:size(1)))
    end
    collectgarbage()
    dlt.log:endSection()
end

function M:getCleanModel()
    -- Clear State
    self.model:clearState()
    -- Remove DataParallelTable
    self.model = torch.type(self.model) == 'nn.DataParallelTable'
                    and self.model:get(1)
                    or self.model
    -- Remove cudnn
    if self.useCudnn then
        cudnn.convert(self.model,nn)
        cutorch.synchronizeAll()
    end
    -- Return CPU model
    return self:cpu()
end

function M:cpu()
    return self.model:type('torch.' .. dlt.help.tensorList.cpu[self.tensorType])
end
function M:gpu()
    return self.model:type('torch.' .. dlt.help.tensorList.gpu[self.tensorType])
end

function M:processModel()
    -- First cast to cpu
    self:cpu()

    if self.nGPU > 0 then
        -- Convert to cudnn
        if self.useCudnn then
            cudnn.fastest = self.cudnnFastest
            cudnn.benchmark = self.cudnnBenchmark
            cudnn.verbose = self.cudnnVerbose
            cudnn.convert(self.model,cudnn)
            self:sync()
        end

        -- Transfer to gpu
        self:gpu()

        -- If multiple GPUs wrap in DataParallelTable
        if self.nGPU > 1 then
            local dpt = nn.DataParallelTable(1, self.dptFlatten,self.dptNccl)
                            :add(self.model, torch.range(1, self.nGPU):totable())
            self.model = dpt
            self:sync()
        end
    end
    -- Default to training mode
    self:training()
    -- Reflatten parameters
    self.parameters, self.gradParameters = self.model:getParameters()
end

function M:sync()
    if self.nGPU > 0 then
        cutorch.synchronizeAll()
    end
end

function M:save(filename)
    if self.shouldSave then
        torch.save(filename, self:getCleanModel())
        self:processModel()
    end
end

function M:training()
    self.model:training()
end
function M:evaluate()
    self.model:evaluate()
end
function M:zeroGradParameters()
    self.model:zeroGradParameters()
end
function M:forward(input)
    return self.model:forward(input)
end
function M:backward(input,gradOutput)
    return self.model:backward(input,gradOutput)
end
function M:updateGradInput(input,gradOutput)
    return self.model:updateGradInput(input,gradOutput)
end
function M:__tostring__()
    return self.name .. '\n' .. self.model:__tostring()
end
--------------------------------------------------------------------------------
/src/loaders/places.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

local P,parent = torch.class('dlt.Places','dlt.Loader',dlt)


-- Settings
--   s.path
--   s.shuffle (defaults to true)
--   s.assignPoint = function(point,iBatchMember,img,cls)
--   s.type [byte]
--   s.name (365 or 205) defaults to 365

-- NOTE: Class is indexed from 0 so add 1 for criterion
function P:__init(s)
    if not dlt.have.csvigo then
        dlt.log:error('Places dataset requires csvigo package')
    end
    if dlt.help.inTable({'365',365,'2',2,'places2','Places2', 'places365',
                         'Places365'},s.name) or s.name == nil then
        self.name = 'places365'
    elseif dlt.help.inTable({'205',205,'1',1,'places','Places','places1',
                             'Places1','places205','Places205'},s.name)
        or not s.name then
        self.name = 'places205'
    else
        dlt.log:error('Unknown places name: ' .. s.name)
    end
    parent.__init(self,s)
    self.path = {}
    if self.name == 'places365' then

        for _,val in ipairs{'training','validation','testing'} do
            self.path[val] = paths.concat(s.path,val)
            if not paths.dirp(self.path[val]) then
                dlt.log:warning('Could not find '.. val .. ' path for '
                                .. self.name .. '. ' .. self.path[val])
            end
        end
        local train = paths.concat(self.path.training,
                                   'places365_train_standard.txt')
        local val = paths.concat(self.path.validation,'places365_val.txt')
        local test = paths.concat(self.path.testing,'places365_test.txt')
        self.fileList = s.fileList
                        or { training = train,
                             validation = val,
                             testing = test }
    else
        if not s.path then
            dlt.log:error('Path not provided for places205 loader.')
        end
        if not paths.dirp(s.path) then
            dlt.log:error('Path provided for places205 loader does not exist. '
                          .. s.path)
        end

        for _,val in ipairs{'training','validation'} do
            self.path[val] = s.path
        end
        local train = paths.concat(self.path.training,'train_places205.csv')
        local val = paths.concat(self.path.validation,'val_places205.csv')
        self.fileList = s.fileList
                        or { training = train,
                             validation = val}
        self.sets.testing = nil
    end

    self.type = s.type or 'byte'
end

-- Classes in places are 0 indexed, so we add 1.
function P:splitNameClass(index,setName)
    setName = setName or self.currentSet
    local imgName,cls = string.match(self.set[setName].list[index][1],
                                     "^/*([^%s]+)%s*(%d*)")
    return imgName,tonumber(cls) + 1
end

-- Potentially needs a faster implementation if it is to be used a lot
-- Only tested with 365
-- Made it to (quickly) look at some sample images
function P:sample(cls,setName)
    setName = setName or self.currentSet
    local index
    local nElem = self:size()
    while true do
        index = torch.random(1,nElem)
        local imgName,clsTest = self:splitNameClass(index,setName)
        if clsTest == cls then break end
    end
    local img,clsTest = self:dataPoint(index,setName)
    return img
end

function P:getFullPathAndClass(index,setName)
    setName = setName or self.currentSet
    local imgName,cls = self:splitNameClass(index,setName)
    return paths.concat(self.path[setName],imgName), cls
end

function P:dataPoint(index,setName)
    setName = setName or self.currentSet
    local imgName,cls = self:splitNameClass(index,setName)
    return image.load(paths.concat(self.path[setName],imgName),nil,self.type),
           cls
end

function P:initInstance(setName)
    self.set = self.set or {}
    setName = setName or self.currentSet
    local list = csvigo.load({path = self.fileList[setName] , mode = 'large',
                              verbose = false})
    self.set[setName] = {
        list = list,
        nPoints = #list
    }
end
--------------------------------------------------------------------------------
/src/util/slurm.lua:
--------------------------------------------------------------------------------
local dlt = require('dlt._env')

local S,parent = torch.class('Slurm',dlt)

-- Class for slurm scheduler support
-- Initialize Slurm object with settings
function S:__init()
    dlt.parse(self)
    -- No need to configure
    -- dlt.configure(self)
    self.sTh = dlt.help.checkHomePath(self.sTh)
    self.sPrecommands = dlt.help.checkHomePath(self.sPrecommands)
end

function S:createScript(runPath)
    -- default runPath is jobname
    runPath = runPath or self.sJobname
    runPath = dlt.help.checkHomePath(runPath)
    -- Make script
    local script = [[#!/bin/bash]] .. '\n\n'

    -- Job Name, time, nodes, tasks, partition
    script = script .. [[#SBATCH --job-name=]] .. self.sJobname .. '\n'
    script = script .. [[#SBATCH --time=]] .. self.sTime .. '\n'
    script = script .. [[#SBATCH --nodes=]] .. self.sNodes .. '\n'
    script = script .. [[#SBATCH --ntasks-per-node=]] .. self.sTasks .. '\n'

    -- Memory.
    -- If total memory is given and we are not using fat nodes
    -- use total, otherwise request mem-per-cpu
    -- This is a hack for the HPC facility I'm currently using. FIX
    if self.sMem ~= 0 and self.sPartition ~= 'fat' then
        script = script .. [[#SBATCH --mem=]] .. self.sMem .. '\n'
    else
        script = script .. [[#SBATCH --mem-per-cpu=]] ..
                 self.sMempercpu .. '\n'
    end

    -- Partition
    if self.sPartition ~= 'none' then
        script = script .. [[#SBATCH --partition=]] .. self.sPartition .. '\n'
    end

    -- Generic resources request
    if self.sGres ~= 'none' then
        script = script .. [[#SBATCH --gres=]] .. self.sGres .. '\n'
    end

    -- sExclude nodes
    if self.sExclude ~= 'none' then
        script = script .. [[#SBATCH --exclude=]] .. self.sExclude .. '\n'
    end
71 | function P:splitNameClass(index,setName) 72 | setName = setName or self.currentSet 73 | local imgName,cls = string.match(self.set[setName].list[index][1], 74 | "^/*([^%s]+)%s*(%d*)") 75 | return imgName,tonumber(cls) + 1 76 | end 77 | 78 | -- Potentially needs a faster implementation if to be used a lot 79 | -- Only tested with 365 80 | -- Made it to (quickly) look at some sample images 81 | function P:sample(cls,setName) 82 | setName = setName or self.currentSet 83 | local index 84 | local nElem = self:size() 85 | while true do 86 | index = torch.random(1,nElem) 87 | local imgName,clsTest = self:splitNameClass(index,setName) 88 | if clsTest == cls then break end 89 | end 90 | local img,clsTest = self:dataPoint(index,setName) 91 | return img 92 | end 93 | 94 | function P:getFullPathAndClass(index,setName) 95 | setName = setName or self.currentSet 96 | local imgName,cls = self:splitNameClass(index,setName) 97 | return paths.concat(self.path[setName],imgName), cls 98 | end 99 | 100 | function P:dataPoint(index,setName) 101 | setName = setName or self.currentSet 102 | local imgName,cls = self:splitNameClass(index,setName) 103 | return image.load(paths.concat(self.path[setName],imgName),nil,self.type), 104 | cls 105 | end 106 | 107 | function P:initInstance(setName) 108 | self.set = self.set or {} 109 | setName = setName or self.currentSet 110 | local list = csvigo.load({path = self.fileList[setName] , mode = 'large', 111 | verbose = false}) 112 | self.set[setName] = { 113 | list = list, 114 | nPoints = #list 115 | } 116 | end 117 | -------------------------------------------------------------------------------- /src/util/slurm.lua: -------------------------------------------------------------------------------- 1 | local dlt = require('dlt._env') 2 | 3 | local S,parent = torch.class('Slurm',dlt) 4 | 5 | -- Class for slurm scheduler support 6 | -- Initialize Slurm object with settings 7 | function S:__init() 8 | dlt.parse(self) 9 | -- No need to configure 10 | -- dlt.configure(self) 11 | self.sTh = dlt.help.checkHomePath(self.sTh) 12 | self.sPrecommands = dlt.help.checkHomePath(self.sPrecommands) 13 | end 14 | 15 | function S:createScript(runPath) 16 | -- default runPath is jobname 17 | runPath = runPath or self.sJobname 18 | runPath = dlt.help.checkHomePath(runPath) 19 | -- Make script 20 | local script = [[#!/bin/bash]] .. '\n\n' 21 | 22 | -- Job Name, time, nodes, tasks, partition 23 | script = script .. [[#SBATCH --job-name=]] .. self.sJobname .. '\n' 24 | script = script .. [[#SBATCH --time=]] .. self.sTime .. '\n' 25 | script = script .. [[#SBATCH --nodes=]] .. self.sNodes .. '\n' 26 | script = script .. [[#SBATCH --ntasks-per-node=]] .. self.sTasks .. '\n' 27 | 28 | -- Memory. 29 | -- If total memory is given and we are not using fat nodes 30 | -- use total, otherwise request mem-per-cpu 31 | -- This is a hack for the HPC facility i'm currently using. FIX 32 | if self.sMem ~= 0 and self.sPartition ~= 'fat' then 33 | script = script .. [[#SBATCH --mem=]] .. self.sMem .. '\n' 34 | else 35 | script = script .. [[#SBATCH --mem-per-cpu=]] .. 36 | self.sMempercpu .. '\n' 37 | end 38 | 39 | -- Partition 40 | if self.sPartition ~= 'none' then 41 | script = script .. [[#SBATCH --partition=]] .. self.sPartition .. '\n' 42 | end 43 | 44 | -- Generic resources request 45 | if self.sGres ~= 'none' then 46 | script = script .. [[#SBATCH --gres=]] .. self.sGres .. '\n' 47 | end 48 | 49 | -- sExclude nodes 50 | if self.sExclude ~= 'none' then 51 | script = script .. [[#SBATCH --exclude=]] .. 
self.sExclude .. '\n' 52 | end 53 | 54 | -- Request nodes 55 | if self.sRequest ~= 'none' then 56 | script = script .. [[#SBATCH --nodelist=]] .. self.sRequest .. '\n' 57 | end 58 | 59 | -- Output name 60 | if self.sOutname == 'default' then 61 | local outputString = 'slurm_' .. self.sJobname .. [[_%A]] 62 | script = script .. [[#SBATCH --output=]] .. outputString .. '\n' 63 | else 64 | script = script .. [[#SBATCH --output=]] .. self.sOutname .. '\n' 65 | end 66 | 67 | -- email 68 | if self.sEmail ~= 'none' then 69 | script = script .. [[#SBATCH --mail-type=ALL]] .. '\n' 70 | script = script .. [[#SBATCH --mail-user=]] .. self.sEmail .. '\n\n' 71 | end 72 | script = script .. '\n' 73 | 74 | -- pre-commands 75 | if self.sPrecommands ~= 'none' then 76 | -- check pre-commands 77 | if not paths.filep(paths.concat(runPath,self.sPrecommands)) then 78 | dlt.log:error('Could not find file with pre-commands ' .. 79 | paths.concat(runPath,self.sPrecommands) ) 80 | end 81 | local preFile = io.open(paths.concat(runPath,self.sPrecommands),'r') 82 | local commands = preFile:read('*all') 83 | preFile:close() 84 | script = script .. commands ..'\n\n' 85 | end 86 | 87 | -- check script 88 | if self.sTh ~= 'none' and 89 | not paths.filep(paths.concat(runPath,self.sTh)) then 90 | dlt.log:error('Could not find torch script ' 91 | .. paths.concat(runPath,self.sTh) ) 92 | end 93 | 94 | -- if runPath == nil then self.runPath = self.sJobname end 95 | dlt.help.checkMakeDir(runPath) 96 | 97 | -- cd to runPath in slurm script 98 | script = script .. 'cd ' .. paths.concat(runPath) .. '\n\n' 99 | -- invoke torch 100 | if self.sTh ~= 'none' then script = script .. 'th ' .. self.sTh .. '\n' end 101 | 102 | -- Write slurm script to runPath 103 | local fullScriptName = paths.concat(runPath,'job') 104 | local file = io.open(fullScriptName, 'w+') 105 | file:write(script) 106 | file:close() 107 | 108 | return fullScriptName, script 109 | 110 | end 111 | 112 | -- NEEDS FIXING 113 | -- For now submit manually 114 | -- function S:submit(scriptFile) 115 | -- if not paths.filep(scriptFile) then 116 | -- dlt.log:error('Could not find batch script ' .. scriptFile) 117 | -- end 118 | -- local scriptPath = paths.dirname(scriptFile) 119 | -- dlt.log:print('Submitting script') 120 | -- os.execute('cd ' .. scriptPath .. '\n' .. 'sbatch ' .. scriptFile ) 121 | -- end -------------------------------------------------------------------------------- /src/models/colornet.lua: -------------------------------------------------------------------------------- 1 | local dlt = require('dlt._env') 2 | 3 | 4 | -- Paper: Let there be Color!: 5 | -- Joint End-to-end Learning of Global and Local Image Priors 6 | -- for Automatic Image Colorization with Simultaneous Classification 7 | -- http://hi.cs.waseda.ac.jp/~iizuka/projects/colorization/data/colorization_sig2016.pdf 8 | 9 | -- Accepts a table with two inputs. 10 | -- The second is wxh, the first is not constrained 11 | -- (dimensions should be divisible by 8) 12 | -- Output is a table. 13 | -- First output is an image of the size of the first input. 
14 | -- Second output is the class probability vector 15 | -- Trained on 224x224 input 16 | -- (1 channel in (L from LUV) 17 | -- 2 channels out (ab from Lab - normalized to [0,1])) 18 | function dlt.models.colornet(w,h,nClasses,inChannels,outChannels,useBatchNorm) 19 | -- Short names and defaults 20 | local SBatchNorm = nn.SpatialBatchNormalization 21 | local BatchNorm = nn.BatchNormalization 22 | w = w or 224 23 | h = h or 224 24 | nClasses = nClasses or 205 25 | inChannels = inChannels or 1 26 | outChannels = outChannels or 2 27 | 28 | -- Helper function for adding layers 29 | local function layer(net,inputDepth,outputDepth,stride) 30 | net:add(nn.SpatialConvolution(inputDepth,outputDepth, 31 | 3,3,stride,stride,1,1)) 32 | if useBatchNorm then net:add(SBatchNorm(outputDepth)) end 33 | net:add(nn.ReLU(true)) 34 | end 35 | 36 | -- Local Features 37 | local lF = nn.Sequential() 38 | layer(lF,inChannels,64,2) 39 | layer(lF,64,128,1) 40 | layer(lF,128,128,2) 41 | layer(lF,128,256,1) 42 | layer(lF,256,256,2) 43 | layer(lF,256,512,1) 44 | 45 | -- Global Features start (before it splits for fusion and classifier) 46 | local gFstart = nn.Sequential() 47 | layer(gFstart,512,512,2) 48 | layer(gFstart,512,512,1) 49 | layer(gFstart,512,512,2) 50 | layer(gFstart,512,512,1) 51 | -- Change the view for fully connected layers 52 | local fCLength = w*h/2 53 | gFstart:add(nn.View(fCLength)) 54 | :add(nn.Linear(fCLength,1024)) 55 | if useBatchNorm then 56 | gFstart:add(BatchNorm(1024)) 57 | end 58 | gFstart:add(nn.ReLU(true)) 59 | gFstart:add(nn.Linear(1024,512)) 60 | if useBatchNorm then 61 | gFstart:add(BatchNorm(512)) 62 | end 63 | gFstart:add(nn.ReLU(true)) 64 | 65 | -- Global Features end for fusion 66 | local gFend = nn.Sequential() 67 | gFend:add(nn.Linear(512,256)) 68 | if useBatchNorm then 69 | gFend:add(BatchNorm(256)) 70 | end 71 | gFend:add(nn.ReLU(true)) 72 | gFend:add(nn.Replicate(w/8,3)) 73 | :add(nn.Replicate(h/8,4)) 74 | 75 | -- Mid Features 76 | local mF = nn.Sequential() 77 | layer(mF,512,512,1) 78 | layer(mF,512,256,1) 79 | 80 | -- largeNet is the beginning of the convolutional network 81 | -- where image size is not restricted 82 | -- up to fusion part 83 | local largeNet = nn.Sequential() 84 | largeNet:add(lF) 85 | :add(mF) 86 | 87 | -- Small is the classifier network (fusion branch) 88 | local smallNet = nn.Sequential() 89 | smallNet:add(lF):add(gFstart):add(gFend) 90 | 91 | -- Color model up to fusion part 92 | local colorModel = nn.Sequential() 93 | colorModel:add(nn.ParallelTable() 94 | :add(largeNet) 95 | :add(smallNet)) 96 | :add(nn.JoinTable(2)) 97 | 98 | -- Complete color model 99 | layer(colorModel,512,256,1) 100 | layer(colorModel,256,128,1) 101 | colorModel:add(nn.SpatialUpSamplingNearest(2)) 102 | layer(colorModel,128,64,1) 103 | layer(colorModel,64,64,1) 104 | colorModel:add(nn.SpatialUpSamplingNearest(2)) 105 | layer(colorModel,64,32,1) 106 | colorModel:add(nn.SpatialConvolution(32,outChannels,3,3,1,1,1,1)) 107 | colorModel:add(nn.Sigmoid()) 108 | 109 | -- Complete classifier 110 | local classifier = nn.Sequential():add(nn.SelectTable(2)) 111 | classifier:add(lF) 112 | :add(gFstart) 113 | :add(nn.Linear(512,256)) 114 | if useBatchNorm then 115 | classifier:add(BatchNorm(256)) 116 | end 117 | classifier:add(nn.ReLU(true)) 118 | :add(nn.Linear(256,nClasses)) 119 | if useBatchNorm then 120 | classifier:add(BatchNorm(nClasses)) 121 | end 122 | classifier:add(nn.Sigmoid()) 123 | 124 | -- Complete model 125 | local model = nn.ConcatTable() 126 | model:add(colorModel) 127 | 
:add(classifier) 128 | 129 | return model 130 | end -------------------------------------------------------------------------------- /src/models/dcgan.lua: -------------------------------------------------------------------------------- 1 | local dlt = require('dlt._env') 2 | 3 | -- Code adapted from https://github.com/soumith/dcgan.torch 4 | 5 | local function getClosures(net,sbn) 6 | local function Conv(nIn,nOut) 7 | net:add(nn.SpatialConvolution(nIn,nOut,4,4,2,2,1,1)) 8 | end 9 | local function FConv(nIn,nOut) 10 | net:add(nn.SpatialFullConvolution(nIn,nOut,4,4,2,2,1,1)) 11 | end 12 | local function SBN(nf) 13 | if sbn then net:add(nn.SpatialBatchNormalization(nf)) end 14 | end 15 | local function LReLU() net:add(nn.LeakyReLU(0.2, true)) end 16 | local function ReLU() net:add(nn.ReLU(true)) end 17 | return Conv,SBN,LReLU,FConv,ReLU 18 | end 19 | 20 | dlt.models.dcgan = {} 21 | 22 | function dlt.models.dcgan.initWeights(m) 23 | local name = torch.type(m) 24 | if name:find('Convolution') then 25 | m.weight:normal(0.0, 0.02) 26 | m:noBias() 27 | elseif name:find('BatchNormalization') then 28 | if m.weight then m.weight:normal(1.0, 0.02) end 29 | if m.bias then m.bias:fill(0) end 30 | end 31 | end 32 | 33 | function dlt.models.dcgan.generator64(nz,ngf,nc,init,sbn) 34 | nz = nz or 100 35 | ngf = ngf or 64 36 | nc = nc or 3 37 | if init == nil then init = true end 38 | local netG = nn.Sequential() 39 | local _,SBN,_,FConv,ReLU = getClosures(netG,sbn) 40 | 41 | netG:add(nn.View(nz,1,1)) 42 | 43 | netG:add(nn.SpatialFullConvolution(nz, ngf * 8, 4, 4)) 44 | SBN(ngf*8) 45 | ReLU() 46 | -- state size: (ngf*8) x 4 x 4 47 | FConv(ngf * 8, ngf * 4) 48 | SBN(ngf*4) 49 | ReLU() 50 | -- state size: (ngf*4) x 8 x 8 51 | FConv(ngf * 4, ngf * 2) 52 | SBN(ngf*2) 53 | ReLU() 54 | -- state size: (ngf*2) x 16 x 16 55 | FConv(ngf * 2, ngf) 56 | SBN(ngf) 57 | ReLU() 58 | -- state size: (ngf) x 32 x 32 59 | FConv(ngf, nc) 60 | netG:add(nn.Tanh()) 61 | -- state size: (nc) x 64 x 64 62 | if init then netG:apply(dlt.models.dcgan.initWeights) end 63 | return netG 64 | end 65 | 66 | function dlt.models.dcgan.generator32(nz,ngf,nc,init,sbn) 67 | nz = nz or 100 68 | ngf = ngf or 64 69 | nc = nc or 3 70 | if init == nil then init = true end 71 | local netG = nn.Sequential() 72 | local _,SBN,_,FConv,ReLU = getClosures(netG,sbn) 73 | 74 | netG:add(nn.View(nz,1,1)) 75 | 76 | netG:add(nn.SpatialFullConvolution(nz, ngf * 4, 4, 4)) 77 | SBN(ngf*4) 78 | ReLU() 79 | -- state size: (ngf*4) x 4 x 4 80 | FConv(ngf * 4, ngf * 2) 81 | SBN(ngf*2) 82 | ReLU() 83 | -- state size: (ngf*2) x 8 x 8 84 | FConv(ngf * 2, ngf) 85 | SBN(ngf) 86 | ReLU() 87 | -- state size: (ngf) x 16 x 16 88 | FConv(ngf, nc) 89 | netG:add(nn.Tanh()) 90 | -- state size: (nc) x 32 x 32 91 | if init then netG:apply(dlt.models.dcgan.initWeights) end 92 | return netG 93 | end 94 | 95 | function dlt.models.dcgan.discriminator64(ndf,nc,sigmoid,init,sbn) 96 | ndf = ndf or 64 97 | nc = nc or 3 98 | if init == nil then init = true end 99 | local netD = nn.Sequential() 100 | local Conv,SBN,LReLU = getClosures(netD,sbn) 101 | -- input is (nc) x 64 x 64 102 | Conv(nc, ndf) 103 | LReLU() 104 | -- state size: (ndf) x 32 x 32 105 | Conv(ndf, ndf * 2) 106 | SBN(ndf*2) 107 | LReLU() 108 | -- state size: (ndf*2) x 16 x 16 109 | Conv(ndf * 2, ndf * 4) 110 | SBN(ndf*4) 111 | LReLU() 112 | -- state size: (ndf*4) x 8 x 8 113 | Conv(ndf * 4, ndf * 8) 114 | SBN(ndf*8) 115 | LReLU() 116 | -- state size: (ndf*8) x 4 x 4 117 | netD:add(nn.SpatialConvolution(ndf * 8, 1, 4, 4)) 118 | if sigmoid then netD:add(nn.Sigmoid()) end 119 | -- state size: 1 x 1 x 1 120 | netD:add(nn.View(1):setNumInputDims(3)) 121 | -- state size: 1 122 | if init then netD:apply(dlt.models.dcgan.initWeights) end 123 | return netD 124 | end 125 | 126 | function dlt.models.dcgan.discriminator32(ndf,nc,sigmoid,init,sbn) 127 | ndf = ndf or 64 128 | nc = nc or 3 129 | if init == nil then init = true end 130 | local netD = nn.Sequential() 131 | local Conv,SBN,LReLU = getClosures(netD,sbn) 132 | -- input is (nc) x 32 x 32 133 | Conv(nc, ndf) 134 | LReLU() 135 | -- state size: (ndf) x 16 x 16 136 | Conv(ndf, ndf * 2) 137 | SBN(ndf*2) 138 | LReLU() 139 | -- state size: (ndf*2) x 8 x 8 140 | Conv(ndf * 2, ndf * 4) 141 | SBN(ndf*4) 142 | LReLU() 143 | -- state size: (ndf*4) x 4 x 4 144 | netD:add(nn.SpatialConvolution(ndf * 4, 1, 4, 4)) 145 | -- state size: 1 x 1 x 1 146 | netD:add(nn.View(1):setNumInputDims(3)) 147 | if sigmoid then netD:add(nn.Sigmoid()) end 148 | -- state size: 1 149 | if init then netD:apply(dlt.models.dcgan.initWeights) end 150 | return netD 151 | end -------------------------------------------------------------------------------- /src/models/alexnet.lua: -------------------------------------------------------------------------------- 1 | local dlt = require('dlt._env') 2 | 3 | -- Adapted from https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/alexnet.lua 4 | -- Note, parallelism was removed, 5 | -- but if using multiGPU then DPT will wrap the whole model 6 | 7 | local function makeFeatures(w,h,inChannels,featureList,bn) 8 | local currentW,currentH = w,h 9 | local feat = nn.Sequential() 10 | -- 224 -> 55 11 | feat:add(nn.SpatialConvolution(inChannels,featureList[1],11,11,4,4,2,2)) 12 | if bn then feat:add(nn.SpatialBatchNormalization(featureList[1],1e-4)) end 13 | currentW,currentH = dlt.help.SpatialConvolutionSize(currentW,currentH, 14 | 11,11,4,4,2,2) 15 | feat:add(nn.ReLU(true)) 16 | -- 55 -> 27 17 | feat:add(nn.SpatialMaxPooling(3,3,2,2)) 18 | currentW,currentH = dlt.help.SpatialMaxPoolingSize(currentW,currentH, 19 | 3,3,2,2) 20 | -- 27 -> 27 21 | feat:add(nn.SpatialConvolution(featureList[1],featureList[2],5,5,1,1,2,2)) 22 | if bn then feat:add(nn.SpatialBatchNormalization(featureList[2],1e-4)) end 23 | currentW,currentH = dlt.help.SpatialConvolutionSize(currentW,currentH, 24 | 5,5,1,1,2,2) 25 | feat:add(nn.ReLU(true)) 26 | -- 27 -> 13 27 | feat:add(nn.SpatialMaxPooling(3,3,2,2)) 28 | currentW,currentH = dlt.help.SpatialMaxPoolingSize(currentW,currentH, 29 | 3,3,2,2) 30 | -- 13 -> 13 31 | feat:add(nn.SpatialConvolution(featureList[2],featureList[3],3,3,1,1,1,1)) 32 | if bn then feat:add(nn.SpatialBatchNormalization(featureList[3],1e-4)) end 33 | currentW,currentH = dlt.help.SpatialConvolutionSize(currentW,currentH, 34 | 3,3,1,1,1,1) 35 | feat:add(nn.ReLU(true)) 36 | -- 13 -> 13 37 | feat:add(nn.SpatialConvolution(featureList[3],featureList[4],3,3,1,1,1,1)) 38 | if bn then feat:add(nn.SpatialBatchNormalization(featureList[4],1e-4)) end 39 | currentW,currentH = dlt.help.SpatialConvolutionSize(currentW,currentH, 40 | 3,3,1,1,1,1) 41 | feat:add(nn.ReLU(true)) 42 | -- 13 -> 13 43 | feat:add(nn.SpatialConvolution(featureList[4],featureList[5],3,3,1,1,1,1)) 44 | if bn then feat:add(nn.SpatialBatchNormalization(featureList[5],1e-4)) end 45 | currentW,currentH = dlt.help.SpatialConvolutionSize(currentW,currentH 46 | ,3,3,1,1,1,1) 47 | feat:add(nn.ReLU(true)) 48 | -- 13 -> 6 49 | feat:add(nn.SpatialMaxPooling(3,3,2,2)) 50 | currentW,currentH =
dlt.help.SpatialMaxPoolingSize(currentW,currentH 51 | ,3,3,2,2) 52 | return feat, currentW, currentH 53 | end 54 | 55 | local function makeClassifier(currentW,currentH,nClasses,dropout,bn) 56 | local classifier = nn.Sequential() 57 | classifier:add(nn.View(256*currentW*currentH)) 58 | if dropout then classifier:add(nn.Dropout(0.5)) end 59 | classifier:add(nn.Linear(256*currentW*currentH, 4096)) 60 | if bn then classifier:add(nn.BatchNormalization(4096, 1e-4)) end 61 | classifier:add(nn.ReLU(true)) 62 | if dropout then classifier:add(nn.Dropout(0.5)) end 63 | classifier:add(nn.Linear(4096, 4096)) 64 | if bn then classifier:add(nn.BatchNormalization(4096, 1e-4)) end 65 | classifier:add(nn.ReLU(true)) 66 | classifier:add(nn.Linear(4096, nClasses)) 67 | classifier:add(nn.LogSoftMax()) 68 | return classifier 69 | end 70 | 71 | function dlt.models.alexnet(w,h,inChannels,nClasses) 72 | w = w or 224 73 | h = h or 224 74 | inChannels = inChannels or 3 75 | nClasses = nClasses or 1000 76 | local features = nn.Concat(2) 77 | -- branch 1 78 | local fb1,currentW,currentH = makeFeatures(w,h,inChannels, 79 | {48,128,192,192,128},false) 80 | -- branch 2 81 | local fb2 = fb1:clone() 82 | -- reset branch 2's weights 83 | for k,v in ipairs(fb2:findModules('nn.SpatialConvolution')) do 84 | v:reset() 85 | end 86 | features:add(fb1):add(fb2) 87 | -- 1.3. Create Classifier (fully connected layers) 88 | local classifier = makeClassifier(currentW,currentH,nClasses,true,false) 89 | -- 1.4. Combine 1.1 and 1.3 to produce final model 90 | local model = nn.Sequential():add(features):add(classifier) 91 | return model 92 | end 93 | 94 | -- this is AlexNet that was presented in the One Weird Trick paper. 95 | -- http://arxiv.org/abs/1404.5997 96 | function dlt.models.alexnet2(w,h,inChannels,nClasses) 97 | w = w or 224 98 | h = h or 224 99 | inChannels = inChannels or 3 100 | nClasses = nClasses or 1000 101 | local features,currentW,currentH = makeFeatures(w,h,inChannels, 102 | {64,192,384,256,256},true) 103 | local classifier = makeClassifier(currentW,currentH,nClasses,true,true) 104 | local model = nn.Sequential():add(features):add(classifier) 105 | return model 106 | end 107 | -------------------------------------------------------------------------------- /doc/loader.md: -------------------------------------------------------------------------------- 1 | # Loader 2 | Loaders are to be used with [`dlt.Data`](data.md) but can also be used on their own. 3 | 4 | Implemented loaders: *MNIST*, *CIFAR*, *CelebA*, *Places*, *pix2pix*. 5 | 6 | It's straightforward to create a loader for a new dataset that's compatible with the rest of the toolbox. 7 | 8 | ## Usage 9 | ```lua 10 | -- LoaderName is a placeholder name here. 11 | data = dlt.LoaderName( s ) 12 | ``` 13 | s is a table of settings. Some commonly used settings are: 14 | 15 | * `path` Path to data. 16 | * `shuffle` Whether to shuffle the indices when loading. 17 | * `assignPoint` Function that assigns a loaded datapoint into a given batch. 18 | 19 | Some loaders may accept additional settings, e.g: 20 | 21 | * `transform` Function that is applied on the dataset (for datasets that are loaded on initialization). 22 | * `name` String for loaders that handle more than one dataset (e.g. CIFAR may be `10` or `100`). 23 | * `type` For datasets that use image.load at every call to :get(). 24 | * `download` For datasets that are easily downloaded (*MNIST*, *CIFAR*) 25 | * `fileList` For *Places* datasets. Files that contain custom lists (e.g. 
the list for the colored images only), e.g. 26 | * `fileList = {training = '/path/to/col_places.csv'}` 27 | 28 | ## Directory structures 29 | ### MNIST 30 | Path should contain *train_32x32.t7* and *test_32x32.t7* from extracting [this](https://s3.amazonaws.com/torch7/data/mnist.t7.tgz). 31 | Can use the `download` setting to automatically get the data (the provided path must already exist). 32 | ### CIFAR 33 | Path should contain *cifar10-train.t7*, *cifar100-train.t7*, *cifar10-test.t7*, *cifar100-test.t7* created using [cifar.torch](https://github.com/soumith/cifar.torch). 34 | Can use the `download` setting to automatically get the data (the provided path must already exist). 35 | ### CelebA 36 | Path must contain all the original images. 37 | ### Places (Places205) 38 | Path must contain the directories a,b,c... as well as *train_places205.csv* and *val_places205.csv* (unless custom lists are provided through `fileList`). 39 | ### Places2 (Places365) 40 | Path must contain the (renamed) directories *training*, *validation*, *testing*. The contents of each directory are unchanged from the extracted archives (standard 256x256); only the top-level directories are renamed. The three directories must also contain *places365_train_standard.txt*, *places365_val.txt* and *places365_test.txt* respectively (unless custom lists are provided through `fileList`). The *.txt* lists are found by extracting `filelist_places365-standard.tar`. 41 | ### pix2pix 42 | Path must contain the directories *cityscapes*, *maps*, *facades*, *edges2handbags*, *edges2shoes* as downloaded from [here](https://github.com/phillipi/pix2pix). 43 | 44 | ## Methods 45 | 46 | ### `init([setName])` 47 | Initializes `setName` (*training, validation, testing*). If not provided, all available sets are initialized. 48 | 49 | **NOTE**: This MUST be called right after the creation of the loader, before any other functions are used. It is separated from the `__init()` method so that it can be called independently on multiple threads (e.g. to initialize non-serializable objects). 50 | 51 | ### `mode(setName)` 52 | Changes the current set (*training, validation, testing*). 53 | 54 | ### `get(index)` 55 | Returns the datapoint at `index` from the current set. If shuffle is on, the index is mapped through the shuffled ordering. 56 | 57 | ### `[s] size([setName])` 58 | Returns the size of `setName` (or of the current set if not provided). 59 | 60 | ### `reshuffle()` 61 | Reshuffles! 62 | 63 | ### `assignBatch(batch,iDataPoint,n)` 64 | Fills the given batch with `n` consecutive points starting from `iDataPoint`, according to the `assignPoint` function (provided on initialization).
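## Writing a custom loader

A minimal sketch of a new loader (hypothetical names; assumes a flat directory of *jpg* images), following the same pattern as the CelebA and HDR loaders:

```lua
local dlt = require('dlt._env')

local M,parent = torch.class('dlt.MyImages','dlt.Loader',dlt)

-- Settings: s.path, s.shuffle, s.assignPoint, s.type [byte]
function M:__init(s)
    self.name = 'MyImages'
    parent.__init(self,s)
    self.path = s.path
    self.type = s.type or 'byte'
    -- This sketch only provides a training set
    self.sets.validation = nil
    self.sets.testing = nil
end

-- Loads a single datapoint from the file list
function M:dataPoint(index,setName)
    setName = setName or self.currentSet
    return image.load(self.set[setName].list[index],nil,self.type)
end

-- Builds the file list (called through init(), possibly once per thread)
function M:initInstance(setName)
    self.set = self.set or {}
    setName = setName or self.currentSet
    local list = dlt.help.getFiles(self.path,{jpg = true},false)
    self.set[setName] = { list = list, nPoints = #list }
end
```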
65 | 66 | ## Example 1 67 | 68 | ```lua 69 | local places = dlt.Places{ path = '~/data/places365', shuffle = true, type = 'byte'} 70 | 71 | -- Must call init 72 | -- Initializes all (training,validation,testing for places) 73 | -- current mode is training by default 74 | places:init() 75 | 76 | -- get returns image and class for Places2 77 | local img, cls = places:get(1) 78 | image.display(img) 79 | print(cls) 80 | 81 | -- Reshuffling 82 | places:reshuffle() 83 | img, cls = places:get(1) -- should be different 84 | image.display(img) 85 | print(cls) 86 | 87 | -- Get validation 88 | places:mode('validation') 89 | img, cls = places:get(1) -- should be from validation 90 | image.display(img) 91 | print(cls) 92 | ``` 93 | 94 | (Run with qlua) 95 | 96 | ## Example 2 97 | ```lua 98 | local dlt = require('dlt') 99 | local places = dlt.Places{ path = '~/data/places365', 100 | shuffle = true, 101 | type = 'byte', 102 | -- assignPoint describes the rule that gets a loaded image and class and puts it into a batch 103 | assignPoint = function(batch,iBatchMember,img,cls) 104 | if img:size(1) == 1 then img = img:repeatTensor(3,1,1) end -- Greyscale images 105 | img = image.scale(img,64,64) 106 | img = img:float():div(255) -- This conversion can be avoided by just using type = 'float' 107 | batch.discriminatorInput[iBatchMember]:copy(img) 108 | batch.generatorInput[1][iBatchMember]:copy(img:mul(3):clamp(0,1)) 109 | batch.generatorInput[2][iBatchMember] = cls 110 | end } 111 | 112 | places:init('training') -- Initialize only training 113 | local batchSize = 4 114 | local myBatch = {discriminatorInput = torch.Tensor(batchSize,3,64,64), 115 | generatorInput = { torch.Tensor(batchSize,3,64,64), torch.Tensor(batchSize) } } 116 | 117 | places:assignBatch(myBatch,100,batchSize) 118 | image.display(myBatch.discriminatorInput) 119 | image.display(myBatch.generatorInput[1]) 120 | ``` 121 | 122 | (Run with qlua) -------------------------------------------------------------------------------- /src/util/logger.lua: -------------------------------------------------------------------------------- 1 | local dlt = require('dlt._env') 2 | 3 | -- I did not really need a stack, 4 | -- but sometimes you just feel like writing one. 5 | local stackC = torch.class('Stack') 6 | function stackC:__init() 7 | self.tab = {} 8 | self.push = function(sel,...) 9 | if ... then 10 | for _,val in ipairs({...}) do 11 | table.insert(sel.tab, val) 12 | end 13 | end 14 | end 15 | self.size = function(sel) return #sel.tab end 16 | self.pop = function(sel,num) 17 | num = num or 1 18 | num = num < 1 and 1 or num 19 | num = num > #sel.tab and #sel.tab or num 20 | if num == 0 then return end 21 | local ret = {} 22 | for i = 1, num do 23 | table.insert(ret, sel.tab[#sel.tab]) 24 | table.remove(sel.tab) 25 | end 26 | return unpack(ret) 27 | end 28 | end 29 | 30 | -- Logger for dlt 31 | -- Has verbose levels 32 | -- Can be set to write to file 33 | -- Prints things in boxes -- Almost surely unnecessary, but I like it. 34 | -- Maybe should put a setting to be able to turn boxes off 35 | local L,parent = torch.class('Logger',dlt) 36 | function L:__init(level,filename) 37 | if filename then self:setFile(filename) end 38 | level = level or 3 39 | if level < 1 or level > 6 then 40 | print('Verbose level needs to be between 1 and 6. Setting to 3') 41 | level = 3 42 | end 43 | self.vlevel = level 44 | self.levels = { 45 | error = 1, 46 | warning = 2, 47 | yell = 3, 48 | say = 4, 49 | detail = 5, 50 | debug = 6, 51 | section = 3 52 | } 53 | self.width = 78 54 | self.sectionStack = Stack() 55 | return self 56 | end 57 | 58 | function L:setLevel(level) self.vlevel = level end 59 | function L:getLevel() return self.vlevel end 60 | 61 | function L:print(message,level) 62 | level = level or 3 63 | if level == 1 then 64 | print(message) 65 | os.exit() 66 | elseif level <= self.vlevel then 67 | print(message) 68 | end 69 | if self.toFile then 70 | local file = io.open(self.loggerFileName,'a') 71 | file:write(message .. '\n') 72 | file:close() 73 | end 74 | return self 75 | end 76 | 77 | function L:setFile(filename) 78 | dlt.help.checkMakeDir(paths.dirname(filename)) 79 | self.loggerFileName = filename 80 | self.toFile = true 81 | local exists = paths.filep(self.loggerFileName) 82 | local file = io.open(self.loggerFileName,'a') 83 | if exists then 84 | file:write('\n\nRESTARTING\n\n') 85 | end 86 | file:close() 87 | return self 88 | end 89 | 90 | function L:underdash(level,length) 91 | length = length or self.width 92 | level = level or self.levels['section'] 93 | self:print(' ' .. string.rep('_',length),level) 94 | return self 95 | end 96 | 97 | function L:paddedText(message,padding,level,length) 98 | padding = padding or ' ' 99 | length = length or self.width 100 | level = level or self.levels['section'] 101 | local leftPad = string.rep(padding,torch.floor((length - #message) / 2)) 102 | local rightPad = string.rep(padding,torch.ceil((length - #message) / 2)) 103 | local toPrint = '|' .. leftPad .. message .. rightPad .. '|' 104 | self:print(toPrint,level) 105 | return self 106 | end 107 | 108 | function L:box(message,level,length) 109 | level = level or self.levels['section'] 110 | length = length or self.width 111 | self:underdash(level,length) 112 | self:paddedText(message,'_',level,length) 113 | self:padPrint(' ',self.levels['section']) 114 | return self 115 | end 116 | 117 | local function split(str, delim) 118 | -- Eliminate bad cases... 119 | if string.find(str, delim) == nil then return { str } end 120 | 121 | local result,pat,lastPos = {},"(.-)" .. delim .. "()",nil 122 | for part, pos in string.gmatch(str, pat) do table.insert(result, part); lastPos = pos end 123 | table.insert(result, string.sub(str, lastPos)) 124 | return result 125 | end 126 | 127 | function L:lineSplit(message,length) 128 | length = length or self.width 129 | local words = split(message,' ') 130 | local lines = {''} 131 | local count = 1 132 | for _,val in ipairs(words) do 133 | if #lines[count] + #tostring(val) + 4 >= length then 134 | count = count + 1 135 | lines[count] = ' ' 136 | end 137 | lines[count] = lines[count] .. ' ' .. val 138 | end 139 | return lines 140 | end 141 | 142 | function L:padPrint(message,level) 143 | local lines = self:lineSplit(message) 144 | for _,val in ipairs(lines) do 145 | local padding = self.width - #val - 1 146 | self:print('| ' .. val .. string.rep(' ',padding) .. '|',level) 147 | end 148 | return self 149 | end 150 | 151 | function L:error(message) 152 | self:print('ERROR: ' .. message .. '\nABORTING',self.levels['error']) 153 | end 154 | function L:warning(message) 155 | self:print('WARNING: ' ..
message,self.levels['warning']) 156 | end 157 | function L:yell(message) 158 | self:padPrint(message,self.levels['yell']) 159 | end 160 | function L:say(message) 161 | self:padPrint(message,self.levels['say']) 162 | end 163 | function L:detail(message) 164 | self:padPrint(message,self.levels['detail']) 165 | end 166 | function L:debug(message) 167 | self:print('DEBUG: ' .. message,self.levels['debug']) 168 | end 169 | 170 | 171 | function L:section(name) 172 | self.sectionStack:push(name) 173 | self:box(name) 174 | return self 175 | end 176 | function L:endSection() 177 | local name = self.sectionStack:pop(1) 178 | self:padPrint(name .. ' done.',self.levels['section']) 179 | self:paddedText('','_') 180 | return self 181 | end 182 | -------------------------------------------------------------------------------- /src/util/color.lua: -------------------------------------------------------------------------------- 1 | -- Colorspace conversions 2 | -- Assumptions: 3 | -- RGB is CIE RGB 4 | -- RGB is float and 5 | -- image is depth * width * height (e.g. 3*1920*1080 for FHD) 6 | -- Images are all colored (3 channels) 7 | -- For each colorspace, indices correspond to the letters 8 | -- (e.g. im[1] = R, im[2] = G, im[3] = B) 9 | local dlt = require('dlt._env') 10 | 11 | dlt.color = {} 12 | local C = dlt.color 13 | 14 | local rt2, rt3, rt6 = math.sqrt(2), math.sqrt(3), math.sqrt(6) 15 | local irt2, irt3, irt6 = 1/rt2, 1/rt3, 1/rt6 16 | local epsilon = 1e-15 17 | 18 | C.mat = { 19 | 20 | rgb2xyz = torch.Tensor({{ 0.4887180, 0.3106803, 0.2006017 }, 21 | { 0.1762044, 0.8129847, 0.0108109 }, 22 | { 0.0000000, 0.0102048, 0.9897952 }}), 23 | 24 | xyz2rgb = torch.Tensor({{ 2.3706743, -0.9000405, -0.4706338 }, 25 | {-0.5138850, 1.4253036, 0.0885814 }, 26 | { 0.0052982, -0.0146949, 1.0093968 }}), 27 | 28 | xyz2lms = torch.Tensor({{ 0.3897100, 0.6889800, -0.0786800 }, 29 | {-0.2298100, 1.1834000, 0.0464100 }, 30 | { 0.0000000, 0.0000000, 1.0000000 }}), 31 | 32 | lms2xyz = torch.Tensor({{ 1.9102000, -1.1121200, 0.2019080 }, 33 | { 0.3709500, 0.6290540, 0.0000000 }, 34 | { 0.0000000, 0.0000000, 1.0000000 }}), 35 | -- CIE XYZ to LMS D65 36 | xyz2lmsD65 = torch.Tensor({{ 0.4002000, 0.7075000, -0.0807000 }, 37 | {-0.2280000, 1.1500000, 0.0612000 }, 38 | { 0.0000000, 0.0000000, 0.9184000 }}), 39 | -- L'M'S' TO IPT 40 | lpmpsp2ipt = torch.Tensor({{ 0.4000000, 0.4000000, 0.2000000 }, 41 | { 4.4550000, -4.8510000, 0.3960000 }, 42 | { 0.8056000, 0.3572000, -1.1628000 }}), 43 | -- IPT to L'M'S' 44 | ipt2lpmpsp = torch.Tensor({{ 1.0000000, 0.0975689, 0.2052260 }, 45 | { 1.0000000, -0.1138760, 0.1332170 }, 46 | { 1.0000000, 0.0326151, -0.6768870 }}), 47 | -- LMS D65 to CIE XYZ 48 | lmsD652xyz = torch.Tensor({{ 1.8502400, -1.1383000, 0.2384350 }, 49 | { 0.3668310, 0.6438850, -0.0106734 }, 50 | { 0.0000000, 0.0000000, 1.0888500 }}), 51 | 52 | loglms2lalphabeta = torch.Tensor({{ irt3 , irt3 , irt3 }, 53 | { irt6 , irt6 , -2*irt6 }, 54 | { irt2 , -irt2 , 0 }}), 55 | 56 | lalphabeta2loglms = torch.Tensor({{ irt3 , irt6 , irt2 }, 57 | { irt3 , irt6 , -irt2 }, 58 | { irt3 , -2*irt6 , 0 }}) 59 | } 60 | 61 | -- There must be a better way to do this 62 | -- Multiplies each pixel of input (3,w,h) with matrix mat 63 | function C.matrixMultiply(input,mat) 64 | local output = input.new():resizeAs(input):zero() 65 | output[1]:add(mat[1][1],input[1]) 66 | :add(mat[1][2],input[2]) 67 | :add(mat[1][3],input[3]) 68 | output[2]:add(mat[2][1],input[1]) 69 | :add(mat[2][2],input[2]) 70 | :add(mat[2][3],input[3]) 71 | 
output[3]:add(mat[3][1],input[1]) 72 | :add(mat[3][2],input[2]) 73 | :add(mat[3][3],input[3]) 74 | return output 75 | end 76 | 77 | -- CIE RGB - CIE XYZ 78 | function C.rgb2xyz(im) 79 | return C.matrixMultiply(im,C.mat.rgb2xyz) 80 | end 81 | -- CIE XYZ - CIE RGB 82 | function C.xyz2rgb(im) 83 | return C.matrixMultiply(im,C.mat.xyz2rgb) 84 | end 85 | -- CIE XYZ - LMS (equal energy) 86 | function C.xyz2lms(im) 87 | return C.matrixMultiply(im,C.mat.xyz2lms) 88 | end 89 | -- LMS (equal energy) - CIE XYZ 90 | function C.lms2xyz(im) 91 | return C.matrixMultiply(im,C.mat.lms2xyz) 92 | end 93 | -- CIE RGB - LMS (equal energy) 94 | function C.rgb2lms(im) 95 | return C.xyz2lms(C.rgb2xyz(im)) 96 | end 97 | -- LMS (equal energy) - CIE RGB 98 | function C.lms2rgb(im) 99 | return C.xyz2rgb(C.lms2xyz(im)) 100 | end 101 | -- LMS - Lαβ 102 | function C.lms2lalphabeta(im) 103 | return C.matrixMultiply(torch.log(im+epsilon),C.mat.loglms2lalphabeta) 104 | end 105 | -- Lαβ - LMS 106 | function C.lalphabeta2lms(im) 107 | return torch.exp(C.matrixMultiply(im,C.mat.lalphabeta2loglms)) 108 | end 109 | -- CIE RGB - Lαβ 110 | function C.rgb2lalphabeta(im) 111 | return C.lms2lalphabeta(C.rgb2lms(im)) 112 | end 113 | -- Lαβ - CIE RGB 114 | function C.lalphabeta2rgb(im) 115 | return C.lms2rgb(C.lalphabeta2lms(im)) 116 | end 117 | -- CIE XYZ - LMS D65 118 | function C.xyz2lmsD65(im) 119 | return C.matrixMultiply(im,C.mat.xyz2lmsD65) 120 | end 121 | -- LMS D65 - CIE XYZ 122 | function C.lmsD652xyz(im) 123 | return C.matrixMultiply(im,C.mat.lmsD652xyz) 124 | end 125 | -- L'M'S' - IPT 126 | function C.lpmpsp2ipt(im) 127 | return C.matrixMultiply(im,C.mat.lpmpsp2ipt) 128 | end 129 | -- IPT - L'M'S' 130 | function C.ipt2lpmpsp(im) 131 | return C.matrixMultiply(im,C.mat.ipt2lpmpsp) 132 | end 133 | -- LMS D65 - L'M'S' 134 | function C.lmsD652lpmpsp(im) 135 | local res = torch.abs(im:clone()) 136 | res:pow(0.43) 137 | res:cmul(torch.sign(im)) 138 | return res 139 | end 140 | -- L'M'S' - LMS D65 141 | function C.lpmpsp2lmsD65(im) 142 | local res = torch.abs(im:clone()) 143 | res:pow(1/0.43) 144 | res:cmul(torch.sign(im)) 145 | return res 146 | end 147 | 148 | -- CIE XYZ - IPT 149 | function C.xyz2ipt(im) 150 | return C.lpmpsp2ipt(C.lmsD652lpmpsp(C.xyz2lmsD65(im))) 151 | end 152 | -- IPT - CIE XYZ 153 | function C.ipt2xyz(im) 154 | return C.lmsD652xyz(C.lpmpsp2lmsD65(C.ipt2lpmpsp(im))) 155 | end 156 | -- CIE RGB - IPT 157 | function C.rgb2ipt(im) 158 | return C.xyz2ipt(C.rgb2xyz(im)) 159 | end 160 | -- IPT - CIE RGB 161 | function C.ipt2rgb(im) 162 | return C.xyz2rgb(C.ipt2xyz(im)) 163 | end 164 | -- CIE RGB - LMS D65 165 | function C.rgb2lmsD65(im) 166 | return C.xyz2lmsD65(C.rgb2xyz(im)) 167 | end 168 | -- LMS D65 - CIE RGB 169 | function C.lmsD652rgb(im) 170 | return C.xyz2rgb(C.lmsD652xyz(im)) 171 | end 172 | -- CIE RGB - L'M'S' 173 | function C.rgb2lpmpsp(im) 174 | return C.lmsD652lpmpsp(C.rgb2lmsD65(im)) 175 | end 176 | -- L'M'S' - CIE RGB 177 | function C.lpmpsp2rgb(im) 178 | return C.lmsD652rgb(C.lpmpsp2lmsD65(im)) 179 | end 180 | 181 | -- MISC 182 | function C.linearizeSRGB_(img) 183 | if not C.linearLookup then 184 | C.linearLookup = torch.FloatTensor(256) 185 | count = 0 186 | C.linearLookup:apply( function() 187 | local x = count/255 188 | count = count + 1 189 | if x <= 0.04045 then 190 | return x/12.92 191 | else 192 | return torch.pow((x + 0.055) / (1.055),2.4) 193 | end 194 | end ) 195 | end 196 | return img:apply(function(x) 197 | return C.linearLookup[x*255 + 1] 198 | end ) 199 | end 200 | 201 | 202 | return C 
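-- Usage sketch (illustrative only; assumes the dlt package is installed):
--   local dlt = require('dlt')
--   local im = torch.rand(3,64,64)            -- float CIE RGB image, 3 x w x h
--   local lab = dlt.color.rgb2lalphabeta(im)  -- CIE RGB -> Lαβ
--   local rec = dlt.color.lalphabeta2rgb(lab) -- inverse; recovers im up to epsilon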
-------------------------------------------------------------------------------- /src/core/data.lua: -------------------------------------------------------------------------------- 1 | local dlt = require('dlt._env') 2 | 3 | local D,parent = torch.class('dlt.Data',dlt) 4 | 5 | -- dlt.Data objects utilize loaders to iterate over datasets 6 | -- using single or multiple threads 7 | -- Requires a loader that implements init(), assignBatch(), mode(), size(), 8 | -- [reshuffle()] (with or without arguments) 9 | function D:__init(loader,pointSize, datasets, currentEpoch) 10 | dlt.parse(self) 11 | dlt.configure(self) 12 | if loader == nil then 13 | dlt.log:error('No loader provided for data.') 14 | end 15 | if pointSize == nil then 16 | dlt.log:error('No pointSize provided for data.') 17 | end 18 | self.datasets = datasets or {'training'} 19 | self.currentEpoch = currentEpoch or 1 20 | if self.currentEpoch > self.maxEpochs then 21 | dlt.log:error('Max epochs exceeded (' .. self.maxEpochs .. ').') 22 | end 23 | 24 | dlt.log:section('Data initialization') 25 | -- Launch Donkeys in threads or on master thread 26 | if self.nThreads > 0 then 27 | dlt.log:yell(string.format('Initializing %d thread(s)',self.nThreads)) 28 | -- local threads = require('threads') 29 | threads.Threads.serialization('threads.sharedserialize') 30 | local mid = threads.Mutex():id() 31 | self.datathreads = threads.Threads( 32 | self.nThreads, function() 33 | dl = require('dlt') 34 | torch.setdefaulttensortype('torch.FloatTensor') 35 | end, 36 | function(idx) 37 | tid = idx 38 | torch.manualSeed(self.seed) 39 | dlt = dl 40 | local t = require('threads') 41 | _donkey = dlt.Donkey(loader, pointSize, self.batchSize, 42 | self.useLocks, self.collectGarbage, 43 | self.tensorType) 44 | _donkey.loader:mode('training') 45 | mutex = t.Mutex(mid) 46 | end 47 | ) 48 | else 49 | _donkey = dlt.Donkey(loader,pointSize,self.batchSize,self.useLocks, 50 | self.collectGarbage,self.tensorType) 51 | _donkey.loader:mode('training') 52 | self.datathreads = {} 53 | function self.datathreads:addjob(f1, f2) f2(f1()) end 54 | function self.datathreads:synchronize() end 55 | end 56 | -- Get nPoints and nBatches for datasets 57 | for _,datasetName in ipairs(self.datasets) do 58 | self[datasetName] = {} 59 | self.datathreads:addjob( 60 | function() 61 | return _donkey.loader:size(datasetName) 62 | end, 63 | function(nPoints) 64 | self[datasetName].nPoints = nPoints 65 | self[datasetName].nBatches = 66 | math.ceil(self[datasetName].nPoints / self.batchSize) 67 | end ) 68 | end 69 | self:syncThreads() 70 | -- Create batch of master thread (or gpu memory) 71 | local device = self.nGPU > 0 and 'gpu' or 'cpu' 72 | self.batch = dlt.help.createBatch(self.batchSize, pointSize, 73 | self.tensorType, device) 74 | 75 | -- Create Timers 76 | self.iterationTimer = torch.Timer() 77 | self.epochTimer = torch.Timer() 78 | self.transferTimer = torch.Timer() 79 | self.computeTimer = torch.Timer() 80 | -- Initializations 81 | self.currentSetID = 1 82 | self.currentPoint = {} 83 | for _,set in ipairs(self.datasets) do 84 | self.currentPoint[set] = 1 85 | end 86 | 87 | -- Report dataset sizes 88 | for _,set in ipairs(self.datasets) do 89 | dlt.log:yell( string.gsub(set,'^%l', string.upper) .. ' dataset: ' 90 | .. self[set].nPoints .. ' points, ' 91 | .. self[set].nBatches ..
' batches.') 92 | end 93 | dlt.log:endSection() 94 | end 95 | 96 | function D:iterate(callbacks) 97 | self.callbacks = callbacks 98 | self:syncGPU() 99 | self.iterationTimer:reset() 100 | repeat self:next() until self.terminate 101 | self:syncGPU() 102 | self:syncThreads() 103 | dlt.log:yell(self.terminateMessage) 104 | end 105 | 106 | 107 | 108 | function D:next() 109 | local currentSet = self:currentSet() 110 | local iPoint = self.currentPoint[currentSet] 111 | self:addjob(function() return _donkey:getBatch(iPoint,currentSet) end, 112 | function(donkeyBatch,donkeyTime) 113 | 114 | self.transferTimer:reset() 115 | 116 | self:syncGPU() 117 | self:transfer(donkeyBatch) 118 | 119 | local trt = self.transferTimer:time().real 120 | self.computeTimer:reset() 121 | self.terminate, self.terminateMessage = 122 | self.callbacks[self:currentSet()](self.batch) 123 | 124 | self:syncGPU() 125 | local computeTime = self.computeTimer:time().real 126 | local iterTime = self.iterationTimer:time().real 127 | self.iterationTimer:reset() 128 | dlt.log:detail(string.format( 129 | 'load: %.3fs, iteration: %.3fs,' .. 130 | ' transfer: %.3fs, compute: %.3fs.', 131 | donkeyTime, iterTime, trt, computeTime 132 | )) 133 | end) 134 | self:nextPoint() 135 | if self.callbacks.checkpoint then self.callbacks.checkpoint() end 136 | end 137 | 138 | -- Increase counter and raise flags (for epoch change, checkpointing) 139 | function D:nextPoint() 140 | local curP = self.currentPoint[self:currentSet()] + self.batchSize 141 | self.currentPoint[self:currentSet()] = curP 142 | 143 | if curP > self[self:currentSet()].nPoints then 144 | curP = (curP - 1) % self[self:currentSet()].nPoints + 1 145 | self.currentPoint[self:currentSet()] = curP 146 | self:nextSet() 147 | if self.currentSetID == 1 then self:nextEpoch() end 148 | end 149 | end 150 | 151 | function D:nextSet() 152 | self:syncThreads() 153 | self.currentSetID = self.currentSetID % #self.datasets + 1 154 | local currentSet = self:currentSet() 155 | if self.epochReshuffle then 156 | self:runOnAllThreads(function() 157 | _donkey.loader:reshuffle() 158 | end) 159 | end 160 | self:runOnAllThreads(function() 161 | _donkey.loader:mode(currentSet) 162 | end) 163 | end 164 | 165 | function D:currentSet() 166 | return self.datasets[self.currentSetID] 167 | end 168 | 169 | function D:nextEpoch() 170 | self:syncThreads() 171 | dlt.log:yell(string.format('Epoch %d took %.3fs to complete.', 172 | self.currentEpoch, self.epochTimer:time().real)) 173 | self.epochTimer:reset() 174 | self.currentEpoch = self.currentEpoch + 1 175 | -- if epochs are over then signal termination 176 | if self.currentEpoch > self.maxEpochs then 177 | self.terminate = true 178 | self.terminateMessage = 'Done with all epochs!' 
179 | end 180 | end 181 | 182 | function D:transfer(donkeyBatch) 183 | if self.nGPU > 0 then dlt.help.copyPoint(donkeyBatch,self.batch) 184 | else self.batch = donkeyBatch end 185 | end 186 | 187 | function D:runOnAllThreads(fun,callback) 188 | if self.nThreads > 0 then 189 | self.datathreads:specific(true) 190 | for i=1,self.nThreads do self.datathreads:addjob(i,fun) end 191 | self.datathreads:specific(false) 192 | else 193 | callback = callback or function() end 194 | self.datathreads:addjob(fun,callback) 195 | end 196 | end 197 | 198 | function D:addjob(fun,callback) 199 | self.datathreads:addjob(fun,callback) 200 | end 201 | function D:syncThreads() 202 | self.datathreads:synchronize() 203 | end 204 | function D:syncGPU() 205 | if self.nGPU > 0 then 206 | cutorch.synchronizeAll() 207 | end 208 | end 209 | function D:getEpoch() 210 | return self.currentEpoch 211 | end -------------------------------------------------------------------------------- /src/core/settings.lua: -------------------------------------------------------------------------------- 1 | local dlt = require('dlt._env') 2 | 3 | -- Helps with parsing arguments and setting defaults 4 | function dlt.parse(out,extra,onlyExtra) 5 | if dlt.settings == nil then 6 | local seed = torch.random() 7 | local cmd = torch.CmdLine() 8 | -- First print extra settings 9 | if extra then 10 | cmd:text('User provided settings:') 11 | for _,val in ipairs(extra) do cmd:option(unpack(val)) end 12 | end 13 | cmd:text() 14 | if not onlyExtra then 15 | cmd:text('Global(ish) Settings:') 16 | cmd:option('-verbose', 3, 'Verbose level.') 17 | cmd:option('-makeLogFile', 'false', 'Whether log output is saved to file.') 18 | cmd:option('-defGPU', 1, 'Default GPU.') 19 | cmd:option('-tensorType', 'float', 'Tensor Type for model and optimizer. 
') 20 | cmd:option('-batchSize', 128, 'Batch size.') 21 | cmd:option('-maxEpochs', 1000, 'Maximum number of epochs.') 22 | cmd:text() 23 | cmd:text('Data:') 24 | cmd:option('-useLocks', 'false', 'Whether to use locks before and after loading data in threads.') 25 | cmd:option('-collectGarbage', 50, 'Garbage collection frequency (per iteration) for each loader thread.') 26 | cmd:option('-nGPU', 0, 'Number of gpus.') 27 | cmd:option('-nThreads', 0, 'Number of threads.') 28 | cmd:option('-seed', seed, 'Seed (default is random).') 29 | cmd:option('-epochReshuffle', 'true', 'Whether to reshuffle every epoch.') 30 | cmd:text() 31 | cmd:text('Model:') 32 | cmd:option('-useCudnn', 'true', 'Whether to use cudnn.') 33 | cmd:option('-cudnnFastest', 'true', 'Whether to use cudnn Fastest.') 34 | cmd:option('-cudnnBenchmark', 'true', 'Whether to use cudnn Benchmark.') 35 | cmd:option('-cudnnVerbose', 'false', 'Whether cudnn is verbose.') 36 | cmd:option('-dptFlatten', 'true', 'Whether to use DPT flattenParameters.') 37 | cmd:option('-dptNccl', 'false', 'Whether to use DPT NCCL.') 38 | cmd:text() 39 | cmd:text('Trainer:') 40 | cmd:option('-savePath', 'save', 'Directory name for saved progress.') 41 | cmd:option('-saveAll', 'false', 'Whether to keep all saved checkpoints.') 42 | cmd:text() 43 | cmd:text('Dispatcher:') 44 | cmd:option('-experimentName', 'experiment', 'Name of experiment.') 45 | cmd:option('-runRoot', 'runRoot', 'Root path for runs.') 46 | cmd:text() 47 | cmd:text('Slurm:') 48 | cmd:option('-sTime', '48:00:00', 'Requested time hh:mm:ss.') 49 | cmd:option('-sNodes', 1, 'Nodes to request.') 50 | cmd:option('-sTasks', 1, 'Tasks to request.') 51 | cmd:option('-sPartition', 'gpu', 'Partition on cluster.') 52 | cmd:option('-sMempercpu', 32240, 'Memory per task.') 53 | cmd:option('-sMem', 62112, 'Total Memory.') 54 | cmd:option('-sGres', 'none', 'Generic resource to request.') 55 | cmd:option('-sExclude', 'none', 'Nodes to exclude.') 56 | cmd:option('-sRequest', 'none', 'Nodes to request.') 57 | cmd:option('-sJobname', 'job', 'Name of job.') 58 | cmd:option('-sOutname', 'default', 'Name of output slurm file.') 59 | cmd:option('-sEmail', 'none', 'Email address for notifications.') 60 | cmd:option('-sPrecommands', 'none', 'Commands to run before main script.') 61 | cmd:option('-sTh', 'none', 'Torch script to run with full path.') 62 | cmd:text() 63 | end 64 | dlt.settings = cmd:parse(arg) 65 | 66 | -- Handle booleans (convert strings 'true' 'false' to boolean) 67 | for _,val in ipairs{'useCudnn','cudnnFastest','cudnnBenchmark', 68 | 'cudnnVerbose','dptFlatten','dptNccl','saveAll', 69 | 'useLocks', 'makeLogFile'} do 70 | if dlt.settings[val] ~= nil then 71 | dlt.settings[val] = dlt.settings[val] == 'true' 72 | end 73 | end 74 | if extra then 75 | for _,val in ipairs(extra) do 76 | if val[2] == 'false' or val[2] == 'true' then 77 | dlt.settings[val[1]:sub(2,-1)] = 78 | dlt.settings[val[1]:sub(2,-1)] == 'true' 79 | end 80 | end 81 | end 82 | end 83 | if out then for key,val in pairs(dlt.settings) do out[key] = val end end 84 | 85 | return dlt.settings 86 | end 87 | 88 | function dlt.configure(s) 89 | 90 | -- Set verbose level 91 | dlt.log:setLevel(s.verbose) 92 | 93 | -- Make log file 94 | if s.makeLogFile and not dlt.__setLoggerFile then 95 | dlt.log:setFile(paths.concat(s.savePath,'log')) 96 | dlt.__setLoggerFile = true 97 | end 98 | 99 | -- Check GPU 100 | local availGPU = dlt.have.cutorch and cutorch.getDeviceCount() or 0 101 | if s.nGPU > availGPU then 102 | dlt.log:warning(string.format( 103 | 
'Available GPUs are %d, setting nGPU to %d', 104 | availGPU,availGPU)) 105 | s.nGPU = availGPU 106 | end 107 | 108 | -- Check cudnn 109 | if s.useCudnn and not dlt.have.cudnn then 110 | dlt.log:warning('Cudnn could not be loaded make sure it is' .. 111 | ' installed. Switching cudnn use off.') 112 | s.useCudnn = false 113 | end 114 | 115 | -- Set seeds 116 | torch.manualSeed(s.seed) 117 | if s.nGPU > 0 then cutorch.manualSeedAll(s.seed) end 118 | 119 | -- Set default GPU 120 | if s.nGPU > 0 then 121 | if s.nGPU > 1 and s.defGPU ~= 1 then 122 | dlt.log:warning('For multi-GPU use GPU 1 as default. ' .. 123 | 'Setting defGPU = 1.') 124 | s.defGPU = 1 125 | end 126 | cutorch.setDevice(s.defGPU) 127 | end 128 | return s 129 | end 130 | 131 | function dlt.reportExperiment(settings) 132 | dlt.log:section('Settings for ' .. settings.experimentName) 133 | dlt.log:yell('Verbose level: ' .. settings.verbose) 134 | dlt.log:yell('Seed: ' .. settings.seed) 135 | dlt.log:yell('Max Epochs: ' .. settings.maxEpochs) 136 | if settings.nGPU > 0 then 137 | dlt.log:yell('Running on GPUs (' .. settings.nGPU .. ')') 138 | dlt.help.logAllGPUMemory(settings.nGPU) 139 | if settings.useCudnn then 140 | dlt.log:yell(string.format('Using cudnn with: fastest = %s,' 141 | ..' benchmark = %s, verbose = %s', 142 | tostring(settings.cudnnFastest), 143 | tostring(settings.cudnnBenchmark), 144 | tostring(settings.cudnnVerbose) ) ) 145 | else 146 | dlt.log:yell('Not using cudnn') 147 | end 148 | if settings.nGPU > 1 then 149 | dlt.log:yell(string.format( 150 | 'GPU parallelism using DataParallelTable with:' 151 | .. ' flattenParameters = %s, NCCL = %s', 152 | tostring(settings.dptFlatten), 153 | tostring(settings.dptNccl))) 154 | end 155 | else 156 | dlt.log:yell('Running on CPU') 157 | end 158 | 159 | if settings.saveAll then 160 | dlt.log:yell('Will save all checkpoints') 161 | else 162 | dlt.log:yell('Will only save latest checkpoint') 163 | end 164 | dlt.log:yell('Save path: ' .. settings.savePath) 165 | dlt.log:yell('Mini-batch size: ' .. settings.batchSize) 166 | if settings.nThreads == 1 then 167 | dlt.log:yell('Will be loading data using 1 thread.') 168 | elseif settings.nThreads > 1 then 169 | dlt.log:yell('Will be loading data using ' .. settings.nThreads 170 | .. ' threads.') 171 | else 172 | dlt.log:yell('Will not be using threads to load data.') 173 | end 174 | if settings.useLocks then 175 | dlt.log:yell('Locks are turned on.') 176 | else 177 | dlt.log:yell('Locks are turned off.') 178 | end 179 | dlt.log:endSection() 180 | end 181 | 182 | function dlt.writeSettingsToFile(s,fileName) 183 | local file = io.open(fileName, 'w+') 184 | for set,val in pairs(s) do 185 | file:write('-' .. set .. ' ' .. tostring(val) .. '\n') 186 | end 187 | file:close() 188 | end -------------------------------------------------------------------------------- /src/util/helper.lua: -------------------------------------------------------------------------------- 1 | local dlt = require('dlt._env') 2 | 3 | -- Helper functions 4 | dlt.help = {} 5 | local H = dlt.help 6 | 7 | -- Checks if path exists, tries to create it if not 8 | -- Returns useful error and terminates if it can not create it 9 | function H.checkMakeDir(path) 10 | if not paths.dirp(path) and not paths.mkdir(path) then 11 | dlt.log:error('Unable to create directory: ' .. path .. 
'\n') 12 | end 13 | end 14 | 15 | -- Checks if path given is or starts with '~' 16 | -- and replaces '~' with full path 17 | function H.checkHomePath(path) 18 | path = path:match('^~/?') and paths.concat(paths.home,path:sub(3,-1)) 19 | or path 20 | return path 21 | end 22 | 23 | -- Handling of types 24 | H.tensorList = { 25 | cpu = { 26 | byte = 'ByteTensor', 27 | char = 'CharTensor', 28 | short = 'ShortTensor', 29 | int = 'IntTensor', 30 | long = 'LongTensor', 31 | half = 'HalfTensor', 32 | float = 'FloatTensor', 33 | double = 'DoubleTensor' 34 | }, 35 | gpu = { 36 | byte = 'CudaByteTensor', 37 | char = 'CudaCharTensor', 38 | short = 'CudaShortTensor', 39 | int = 'CudaIntTensor', 40 | long = 'CudaLongTensor', 41 | half = 'CudaHalfTensor', 42 | float = 'CudaTensor', 43 | double = 'CudaDoubleTensor' 44 | }, 45 | pinned = { 46 | byte = 'createCudaHostByteTensor', 47 | int = 'createCudaHostIntTensor', 48 | long = 'createCudaHostLongTensor', 49 | half = 'createCudaHostHalfTensor', 50 | float = 'createCudaHostTensor', 51 | double = 'createCudaHostDoubleTensor' 52 | } 53 | } 54 | 55 | local typeList = {'byte','char','short','int', 56 | 'long','half','float','double'} 57 | local pinnedList = {'byte','int', 'long','half','float','double'} 58 | 59 | -- Checks if value v is in table t 60 | function H.inTable(t,v) 61 | for _,x in pairs(t) do 62 | if x == v then 63 | return true 64 | end 65 | end 66 | return false 67 | end 68 | 69 | -- Apply a function to all elements in array 70 | function H.apply(t,f) 71 | for i,v in ipairs(t) do 72 | t[i] = f(v) 73 | end 74 | return t 75 | end 76 | 77 | -- Creates a batch of size batchSize with given dimensions 78 | -- dimensions must be a table of names with corresponding dimensions 79 | -- e.g. dimensions = {input = {3,32,32}, output = {}} 80 | -- If the corresponding dimensions is an empty table ({}) e.g. above, 81 | -- the output member will just be one dimensional of size batchSize. 82 | -- This is used for compatibility with classification criteria 83 | -- dimensions can be tables (up to one level deep) 84 | -- e.g. for models which take a table input we might have 85 | -- dimensions = { input = { {3,32,32}, {3,64,64} } } 86 | -- tensorType can be of byte,char,short,int,long,half,float,double 87 | -- defaults to float 88 | -- device can be 'cpu' or 'gpu' 89 | function H.createBatch(batchSize, dimensions,tensorType, device, pinned) 90 | -- Check dimensions 91 | if not dimensions then 92 | dlt.log:error('dataPoint dimensions of experiment' .. 93 | ' not provided for creation of batch.') 94 | end 95 | if torch.type(dimensions) ~= 'table' then 96 | dlt.log:error('dimensions must be a table for createBatch.') 97 | end 98 | if next(dimensions) == nil then 99 | dlt.log:error('dimensions must be a non-empty table for createBatch.') 100 | end 101 | -- Default tensor is 'float' 102 | tensorType = tensorType or 'float' 103 | if not H.inTable(typeList,tensorType) then 104 | dlt.log:error('Unsupported type ' .. tensorType) 105 | end 106 | -- Configure device 107 | device = device or 'cpu' 108 | if device ~= 'gpu' and device ~= 'cpu' then 109 | dlt.log:error('Device must be cpu or gpu for createBatch') 110 | end 111 | -- Pinned only if gpu 112 | local supportPinned = device == 'gpu' 113 | -- Pinned not supported for short and char 114 | if pinned and not H.inTable(pinnedList,tensorType) then 115 | dlt.log:warning('Pinned memory not supported for ' .. tensorType .. 116 | '. Setting to false.' 
) 117 | supportPinned = false 118 | end 119 | 120 | local retType 121 | -- Pinned 122 | if pinned and supportPinned then 123 | retType = cutorch[H.tensorList.pinned[tensorType]] 124 | else 125 | retType = torch[H.tensorList[device][tensorType]] 126 | end 127 | -- Create batch from dimensions 128 | local ret = {} 129 | for name,conf in pairs(dimensions) do 130 | if #conf == 0 then -- Classifier data described by empty table 131 | ret[name] = retType(batchSize) 132 | elseif torch.type(conf[1]) == 'table' then -- Subtables 133 | ret[name] = {} 134 | for i, val in ipairs(conf) do 135 | ret[name][i] = retType(batchSize,unpack(val)) 136 | end 137 | else 138 | ret[name] = retType(batchSize,unpack(conf)) 139 | end 140 | end 141 | return ret 142 | end 143 | 144 | -- Yells GPU memory for devID in GB 145 | function H.logGPUMemory(devID) 146 | local freeMemory, totalMemory = cutorch.getMemoryUsage(devID) 147 | local div = 1024*1024*1024 148 | local str = 'GPU %d: total - %.3fGB, free - %.3fGB.' 149 | dlt.log:yell(string.format(str, devID,totalMemory/div, freeMemory/div)) 150 | end 151 | 152 | function H.logAllGPUMemory(nGPU) 153 | for i = 1, nGPU do H.logGPUMemory(i) end 154 | end 155 | 156 | -- Copies a point to another (batches that were created by createBatch) 157 | -- Useful for transferring a batch from cpu to gpu 158 | function H.copyPoint(fromPoint,toPoint) 159 | for name,data in pairs(fromPoint) do 160 | if torch.type(data) == 'table' then 161 | for subname,subdata in pairs(data) do 162 | if toPoint[name][subname].copy then 163 | -- toPoint[name][subname]:resize(subdata:size()):copy(subdata) 164 | toPoint[name][subname]:copy(subdata) 165 | else 166 | toPoint[name][subname] = subdata 167 | end 168 | end 169 | else 170 | if toPoint[name].copy then 171 | -- toPoint[name]:resize(data:size()):copy(data) 172 | toPoint[name]:copy(data) 173 | else 174 | toPoint[name] = data 175 | end 176 | end 177 | end 178 | end 179 | 180 | -- Returns true if tensor has NaN 181 | function H.hasNaN(t) return t:ne(t):sum() ~= 0 end 182 | 183 | -- Returns the resulting dimensions of a spatial convolution 184 | function H.SpatialConvolutionSize(width,height,kW,kH,dW,dH,padW,padH) 185 | dW = dW or 1 186 | dH = dH or 1 187 | padW = padW or 0 188 | padH = padH or 0 189 | local owidth = torch.floor((width + 2*padW - kW) / dW + 1) 190 | local oheight = torch.floor((height + 2*padH - kH) / dH + 1) 191 | return owidth,oheight 192 | end 193 | function H.SpatialMaxPoolingSize(width,height,kW,kH,dW,dH,padW,padH) 194 | return H.SpatialConvolutionSize(width,height,kW,kH,dW,dH,padW,padH) 195 | end 196 | 197 | ---- tensor transformations 198 | ---- If function ends in _ then it is in-place 199 | 200 | -- Assumption: a < b 201 | function H.normalize_(t,a,b) 202 | a = a or 0 203 | b = b or 1 204 | local tmin,tmax = t:min(),t:max() 205 | return t:add(-tmin):div(math.max((tmax-tmin)/(b-a),1e-4)):add(a) 206 | end 207 | function H.normalize(t,a,b) 208 | return H.normalize_(t:clone(),a,b) 209 | end 210 | 211 | -- Assumptions: 0 <= clampA < clampB <= 1, see normalize_ 212 | function H.clampAndNormalize_(t,clampA,clampB,a,b) 213 | return H.normalize_(t:clamp(clampA,clampB),a,b) 214 | end 215 | function H.clampAndNormalize(t,clampA,clampB,a,b) 216 | return H.clampAndNormalize_(t:clone(),clampA,clampB,a,b) 217 | end 218 | 219 | -- Mean squared difference (t1 is changed for the in-place version) 220 | function H.mse_(t1,t2) 221 | return torch.sum(t1:add(-t2):pow(2):div(torch.numel(t1))) 222 | end 223 | function H.mse(t1,t2) 224 | return
H.mse_(t1:clone(),t2) 225 | end 226 | 227 | -- PSNR (t1 is changed for the in-place version) 228 | function H.psnr_(t1,t2) 229 | return -(10/torch.log(10))*torch.log(H.mse_(t1,t2)) 230 | end 231 | function H.psnr(t1,t2) 232 | return -(10/torch.log(10))*torch.log(H.mse(t1,t2)) 233 | end 234 | 235 | -- Assumptions: t has 3 dimensions and w,h do not exceed t:size(2),t:size(3) 236 | -- returns a copy 237 | function H.randomCrop(t,w,h) 238 | local wstart = torch.random(t:size(2) - w + 1) 239 | local hstart = torch.random(t:size(3) - h + 1) 240 | return t[{{},{wstart, wstart + w - 1},{hstart,hstart + h - 1}}]:clone() 241 | end 242 | 243 | -- Flips horizontally with probability p and crops randomly 244 | -- returns a copy 245 | function H.hflipAndRandomCrop(t,w,h,p) 246 | p = p or 0.5 247 | local ret = torch.uniform() < p and image.hflip(t) or t 248 | return H.randomCrop(ret,w,h) 249 | end 250 | 251 | -- Flips horizontally with probability p 252 | function H.randomHFlip(t,p) 253 | p = p or 0.5 254 | local ret = torch.uniform() < p and image.hflip(t) or t 255 | return ret 256 | end 257 | 258 | -- Returns the values at the low*100% and high*100% percentiles 259 | function H.getPercentClamping(img,low,high) 260 | local imgSize = img:size() 261 | local npix = imgSize[1]*imgSize[2]*imgSize[3] 262 | local oned = img:view(npix) 263 | local lowIndex = low*npix 264 | if lowIndex < 1 then lowIndex = 1 end 265 | local highIndex = high*npix 266 | if highIndex > npix then highIndex = npix end 267 | local lowRet = oned:kthvalue(lowIndex) 268 | local highRet = oned:kthvalue(highIndex) 269 | return lowRet[1], highRet[1] 270 | end 271 | 272 | 273 | 274 | -- local helper functions for getFiles 275 | -- clean removes '.' and '..' from an array of strings 276 | local function clean(files) 277 | for i = #files,1,-1 do 278 | if files[i]:sub(-1,-1) == '.' then 279 | table.remove(files,i) 280 | end 281 | end 282 | return files 283 | end 284 | -- calls paths.concat on all files to prepend the directory 285 | local function fullPath(directory,files) 286 | local ret = {} 287 | for i,val in pairs(files) do 288 | ret[i] = paths.concat(directory,files[i]) 289 | end 290 | return ret 291 | end 292 | -- Appends array l2 to l1 293 | local function mergeArrays(l1,l2) 294 | if #l1 == 0 then return l2 end 295 | if #l2 == 0 then return l1 end 296 | for _,val in ipairs(l2) do 297 | table.insert(l1,val) 298 | end 299 | return l1 300 | end 301 | 302 | -- getFiles gets all files from directory with given extensions 303 | -- (recursively too if the flag is set) 304 | function H.getFiles(directory,extensions,recursive) 305 | local fileList = fullPath(directory,clean(paths.dir( directory ))) 306 | local ret = clean(paths.dir( directory )) 307 | for i = #ret,1,-1 do 308 | if paths.dirp(ret[i]) or not extensions[ret[i]:sub(-3,-1)] then 309 | table.remove(ret,i) 310 | end 311 | end 312 | ret = fullPath(directory,ret) 313 | if recursive then 314 | for i = 1,#fileList do 315 | ret = paths.dirp(fileList[i]) and mergeArrays(ret,H.getFiles(fileList[i],extensions,recursive)) or ret 316 | end 317 | end 318 | return ret 319 | end
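-- Illustrative usage sketch (added note, not part of the original file), assuming
-- the helpers defined above and, for the GPU variant, that cutorch is loaded:
--   -- all '.hdr'/'.exr' files under '~/data', recursively
--   local files = H.getFiles(H.checkHomePath('~/data'), {hdr = true, exr = true}, true)
--   -- a float CPU batch of 16 points; the empty table is a class-label output
--   local cpuBatch = H.createBatch(16, {input = {3,32,32}, output = {}}, 'float', 'cpu')
--   -- a matching GPU batch, then transfer
--   local gpuBatch = H.createBatch(16, {input = {3,32,32}, output = {}}, 'float', 'gpu')
--   H.copyPoint(cpuBatch, gpuBatch)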
320 | 321 | 322 | 323 | return H -------------------------------------------------------------------------------- /doc/trainer.md: -------------------------------------------------------------------------------- 1 | # Trainer 2 | 3 | ## Usage 4 | ```lua 5 | trainer = dlt.Trainer(experiment) 6 | ``` 7 | * `experiment` Table with the experiment configuration 8 | * Main functionality is `trainer:run()`. 9 | * Automatically creates loss logs, saves model and optimizer state checkpoints. 10 | * Upon resume, automatically continues from the previous checkpoint. 11 | * Checkpoints every epoch AND according to `checkpointCondition`. 12 | 13 | ## Experiment Configuration 14 | 15 | ### `trainingCallback` 16 | A function `trainingCallback(state,batch)` that performs a training step on a batch. `state` gives access to the trainer (models, optimizers etc.). Overrides the default training callback set by `trainingType` (and is required when `trainingType = 'custom'`). 17 | e.g. 18 | ```lua 19 | local function train(state,batch) 20 | local net, crit, opt = state.model, state.criterion, state.optimizer 21 | -- If net, crit and opt are tables of models/criteria/optimizers: 22 | -- local netG, netD = net.generator, net.discriminator 23 | -- local critG, critD = crit.generator, crit.discriminator 24 | -- local optG, optD = opt.generator, opt.discriminator 25 | 26 | -- Do a normal step 27 | net.gradParameters:zero() 28 | -- Note: batch will have the fields set in pointSize 29 | local prediction = net:forward(batch.input) 30 | local loss = crit:forward(prediction, batch.output) 31 | local gradOutput = crit:backward(prediction, batch.output) 32 | net:backward(batch.input,gradOutput) 33 | opt:updateState(state.data.currentEpoch, loss) -- e.g. if we need to reduce the learning rate 34 | opt:step( function() return loss,net.gradParameters end, net.parameters ) 35 | -- Access to log 36 | state.log.training:log(loss) 37 | end 38 | ``` 39 | 40 | ### `loader` 41 | A [dlt.Loader](loader.md) (without `loader:init()` having been called). 42 | ```lua 43 | experiment.loader = dlt.Mnist{path = '~/data/mnist'} 44 | ``` 45 | ### `pointSize` 46 | Table that describes each data point. 47 | ```lua 48 | -- For class predictions use empty table {} 49 | experiment.pointSize = {input = {1,32,32}, output = {}} 50 | ``` 51 | 52 | ### `trainingType` 53 | String. Currently supports: 54 | 55 | * `'simple'`: Iterates training set, minimizes one model given a criterion. 56 | * `'validate'`: Same as simple but goes through the validation set too. 57 | * `'GAN'`: Generative adversarial networks (training set only) 58 | * `'WGAN'`: Wasserstein GAN (training set only) 59 | * `'BEGAN'`: Boundary Equilibrium GAN (training set only) 60 | * `'custom'`: Fully custom training; requires `trainingCallback` (`model`, `criterion` and `optim` may each be a single component or a table of named components). 61 | **NOTE:** `'simple'` and `'validate'` assume that dataPoints have fields *input* and *output*, while 62 | `'GAN'`,`'WGAN'`,`'BEGAN'` assume fields *input* (z goes into the generator), *sample* (x goes into the discriminator), 63 | *output* (y comes out of the discriminator). 64 | 65 | ### `model` 66 | Table. Contents depend on `trainingType`. 67 | 68 | * For `'simple'` or `'validate'` provide [model.create](model.md) [ and model.name ]. 69 | * For `'GAN'`,`'WGAN'`,`'BEGAN'` provide tables model.generator and model.discriminator, each with [.create](model.md) [ and .name ]. 70 | 71 | ```lua 72 | experiment.trainingType = 'simple' 73 | experiment.model = {create = functionThatCreatesModel, name = 'myAwesomeModel'} 74 | -- OR 75 | experiment.trainingType = 'GAN' 76 | experiment.model ={ generator = {create = makeGeneratorFunction, name = 'Generator'}, 77 | discriminator = {create = '~/savedModels/discriminatorOnDisk.t7', name = 'Discriminator'} } 78 | ``` 79 | 80 | ### `criterion` 81 | An nn.Criterion, or anything with correctly defined :forward(), :backward() and :type() methods. 82 | * For `'simple'` or `'validate'` provide just the criterion 83 | * For `'GAN'`,`'WGAN'`,`'BEGAN'` provide a table with criterion.discriminator and criterion.generator (useful for fancy GAN losses). 84 | ```lua 85 | experiment.trainingType = 'simple' 86 | experiment.criterion = nn.MSECriterion() 87 | -- OR 88 | experiment.trainingType = 'GAN' 89 | experiment.criterion = {discriminator = nn.CrossEntropyCriterion()} -- a generator criterion might not be needed 90 | ``` 91 | Note that `'WGAN'` and `'BEGAN'` do not need a criterion. 92 | 93 | ### `optim` 94 | * For `'simple'` or `'validate'` provide a table with a name (config and hook optional) 95 | * For `'GAN'`,`'WGAN'`,`'BEGAN'` provide a table with optim.discriminator and optim.generator tables 96 | ```lua 97 | experiment.trainingType = 'simple' 98 | experiment.optim = {name = 'adam', config = {beta1 = 0.5}, 99 | -- hook (if provided) must return the updated state. 100 | -- Called before each optimizer update ('simple'/'validate' modes) 101 | hook = function(epoch,loss,currentState) 102 | if epoch == 2 then currentState.beta1 = 0.2 end -- Or something that is actually useful 103 | return currentState 104 | end } 105 | -- OR 106 | experiment.trainingType = 'GAN' 107 | experiment.optim = {discriminator = {name = 'rmsprop', config = {learningRate = 5e-5}}, 108 | generator = {name = 'rmsprop', config = {learningRate = 1e-5}}} 109 | ``` 110 | 111 | ### `trainingHooks` 112 | Hooks for `'GAN'`,`'WGAN'`,`'BEGAN'` training, e.g. `onGeneratorTrainBegin(state)`. 113 | 114 | ### `checkpointCondition` 115 | 116 | A number or a function. 117 | 118 | * If a number, it is the checkpointing frequency in minutes. 119 | * If a function, it must return true when a checkpoint is required, otherwise false. 120 | * Takes one argument: `checkpointCondition(state)` 121 | * `state` is the trainer. Care must be taken not to change its internal state (see the sketch below). 122 | 
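For instance, a function-valued condition can implement schedules a plain number cannot. A minimal sketch (the batch counter here is our own illustration, not part of the trainer's state):
```lua
local nBatches = 0
experiment.checkpointCondition = function(state)
    -- state is the trainer; we only read from it (or, as here, ignore it)
    nBatches = nBatches + 1
    return nBatches % 1000 == 0 -- request a checkpoint every 1000 calls
end
```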
123 | ### Miscellaneous 124 | * `steps = {nDFew, nDMany, manyInitial, manyFrequency}` [{5,100,25,100}] to set the discriminator/generator training schedule for `WGAN` (from the [WGAN paper](https://arxiv.org/abs/1701.07875)) 125 | * `clamp = {clampMin, clampMax}` [{-0.01, 0.01}] for `WGAN` 126 | * `diversityRatio` [0.5], `ktVarInit` [0], `ktLearningRate` [0.001] and `loss` [nn.AbsCriterion()] for `BEGAN` 127 | 128 | ## Example 1 129 | 130 | File lenet.lua : 131 | ```lua 132 | local dlt = require('dlt') 133 | -- Train LeNet on MNIST. Yes, LeNet. Again. On MNIST. 134 | local experiment = { 135 | loader = dlt.Mnist{ 136 | path = '~/data/mnist', 137 | assignPoint = function(batch,i,img,cls) 138 | batch.input[i]:copy(img) 139 | batch.output[i] = cls 140 | end 141 | }, 142 | model = { create = dlt.models.lenet5, name = 'Lenet5' }, 143 | trainingType = 'validate', 144 | pointSize = {input = {1,32,32}, output = {}}, 145 | criterion = nn.CrossEntropyCriterion(), 146 | optim = {name = 'adadelta'} -- If not given, it defaults to adam 147 | } 148 | dlt.Trainer(experiment):run() 149 | ``` 150 | 151 | Possible run: 152 | ```bash 153 | # Run on 2 GPUs, load data only on master thread with batch size of 1000 154 | # Save training log, models, optim state and checkpoint.t7 in ~/save/examples/1 155 | th lenet.lua -nGPU 2 -nThreads 0 -batchSize 1000 -savePath ~/save/examples/1 156 | ``` 157 | ## Example 2 158 | File colornet.lua : 159 | ```lua 160 | local dlt = require('dlt') 161 | -- Train Colornet on Places2 (aka Places365).
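-- Added note: image.rgb2lab on sRGB input yields ab channels roughly in
-- [-108, 100], which is why assignPoint below maps them to [0,1] via
-- add(108):div(208) (the L channel is shifted too, but only ab is used).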
162 | -- Colornet paper: http://hi.cs.waseda.ac.jp/~iizuka/projects/colorization/data/colorization_sig2016.pdf 163 | local experiment = { 164 | loader = dlt.Places{ 165 | path = '~/data/places365', type = 'float', 166 | assignPoint = function(batch,i,img,cls) 167 | -- Scale to size used in paper 168 | img = image.scale(img,224,224) 169 | -- Get greyscale image 170 | local grey = image.rgb2y(img) 171 | -- Convert to Lab (also scaled to half width,height) 172 | local lab = image.rgb2lab(image.scale(img,112,112)) 173 | lab:add(108):div(208) -- normalize ab to [0,1] 174 | -- Model takes table input with two images (first is size invariant) 175 | batch.input[1][i]:copy(grey) 176 | batch.input[2][i]:copy(grey) 177 | -- Output is table with ab predictions and class 178 | batch.output[1][i]:copy(lab[{{2,3},{},{}}]) 179 | batch.output[2][i] = cls 180 | end 181 | }, 182 | model = { create = function() return dlt.models.colornet(224,224,365,1,2) end, name = 'Colornet' }, 183 | trainingType = 'simple', 184 | -- Pointsize supports tables! 185 | pointSize = { input = { {1,224,224} , {1,224,224} } , output = { {2,112,112}, {} } }, 186 | criterion = nn.ParallelCriterion():add(nn.MSECriterion()):add(nn.CrossEntropyCriterion(),1/300), 187 | optim = {name = 'adam'} 188 | } 189 | dlt.Trainer(experiment):run() 190 | ``` 191 | 192 | Possible run: 193 | ```bash 194 | # Run on GPU no. 2, load data on 8 threads, batch size 4 (In the paper they used 128 on a single K80 core) 195 | # Do not overwrite checkpoints (i.e. keep models and optim states from each checkpoint) 196 | # print timings 197 | th colornet.lua -nGPU 1 -defGPU 2 -nThreads 8 -batchSize 4 -saveAll true -verbose 5 198 | ``` 199 | ## Example 3 200 | File dcgan.lua : 201 | ```lua 202 | local dlt = require('dlt') 203 | -- Train DCGAN on CelebA 204 | -- Code adapted from https://github.com/soumith/dcgan.torch 205 | -- First define model creation (with weight init etc) 206 | local function weights_init(m) 207 | local name = torch.type(m) 208 | if name:find('Convolution') then 209 | m.weight:normal(0.0, 0.02) 210 | m:noBias() 211 | elseif name:find('BatchNormalization') then 212 | if m.weight then m.weight:normal(1.0, 0.02) end 213 | if m.bias then m.bias:fill(0) end 214 | end 215 | end 216 | local SBN, SConv, SFConv = nn.SpatialBatchNormalization, nn.SpatialConvolution, nn.SpatialFullConvolution 217 | local nc,nz,ndf,ngf = 3, 100, 64, 64 218 | local function makeGenerator() 219 | local netG = nn.Sequential() 220 | netG:add(nn.View(nz,1,1)) 221 | netG:add(SFConv(nz, ngf * 8, 4, 4)):add(SBN(ngf * 8)):add(nn.ReLU(true)) -- state size: (ngf*8) x 4 x 4 222 | :add(SFConv(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1)):add(SBN(ngf * 4)):add(nn.ReLU(true)) -- state size: (ngf*4) x 8 x 8 223 | :add(SFConv(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1)):add(SBN(ngf * 2)):add(nn.ReLU(true)) -- state size: (ngf*2) x 16 x 16 224 | :add(SFConv(ngf * 2, ngf, 4, 4, 2, 2, 1, 1)):add(SBN(ngf)):add(nn.ReLU(true)) -- state size: (ngf) x 32 x 32 225 | :add(SFConv(ngf , nc, 4, 4, 2, 2, 1, 1)):add(nn.Tanh()) -- state size: (nc) x 64 x 64 226 | netG:apply(weights_init) 227 | return netG 228 | end 229 | local function makeDiscriminator() 230 | local netD = nn.Sequential() -- input is (nc) x 64 x 64 231 | netD:add(SConv(nc, ndf, 4, 4, 2, 2, 1, 1)):add(nn.LeakyReLU(0.2, true)) -- state size: (ndf) x 32 x 32 232 | :add(SConv(ndf, ndf * 2, 4, 4, 2, 2, 1, 1)):add(SBN(ndf * 2)):add(nn.LeakyReLU(0.2, true)) -- state size: (ndf*2) x 16 x 16 233 | :add(SConv(ndf * 2, ndf * 4, 4, 4, 2, 2, 1, 1)):add(SBN(ndf * 4)):add(nn.LeakyReLU(0.2, true)) -- state size: (ndf*4) x 8 x 8
234 | :add(SConv(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1)):add(SBN(ndf * 8)):add(nn.LeakyReLU(0.2, true)) -- state size: (ndf*8) x 4 x 4 235 | :add(SConv(ndf * 8, 1, 4, 4)):add(nn.Sigmoid()) -- state size: 1 x 1 x 1 236 | :add(nn.View(1):setNumInputDims(3)) -- state size: 1 237 | netD:apply(weights_init) 238 | return netD 239 | end 240 | 241 | -- Define and run experiment 242 | local experiment = { 243 | loader = dlt.CelebA{ 244 | path = '~/data/celeba', 245 | type = 'float', 246 | assignPoint = function(batch,i,img,cls) 247 | img = image.scale(image.crop(img,'c',178,178),64,64) -- crop and resize 248 | batch.sample[i]:copy(img:mul(2):add(-1)) 249 | end 250 | }, 251 | trainingHooks = { getGeneratorInput = function(batch) return batch.input:normal(0,1) end }, 252 | model = { discriminator = { create = makeDiscriminator, name = 'Discriminator' }, 253 | generator = { create = makeGenerator, name = 'Generator' }}, 254 | trainingType = 'GAN', 255 | checkpointCondition = 1, -- Checkpoint every minute 256 | pointSize = { input = {nz} , sample = {nc,64,64}, output = {} }, 257 | criterion = {discriminator = nn.BCECriterion()}, 258 | optim = {discriminator = {name = 'adam', config = {learningRate = 2e-4, beta1 = 0.5}}, 259 | generator = { name = 'adam', config = {learningRate = 2e-4, beta1 = 0.5} } } 260 | } 261 | dlt.Trainer(experiment):run() 262 | ``` 263 | 264 | Possible run: 265 | ```bash 266 | # Run on GPU, load data on 4 threads, batch size 8 267 | # Do not overwrite checkpoints (i.e. keep models and optim states from each checkpoint) 268 | th dcgan.lua -nGPU 1 -nThreads 4 -batchSize 8 -saveAll true 269 | ``` -------------------------------------------------------------------------------- /src/core/trainer.lua: -------------------------------------------------------------------------------- 1 | local dlt = require('dlt._env') 2 | 3 | local T,parent = torch.class('dlt.Trainer',dlt) 4 | 5 | function T:__init(experiment) 6 | dlt.parse(self) 7 | dlt.configure(self) 8 | dlt.reportExperiment(self) 9 | dlt.log:section('Trainer initialization') 10 | 11 | -- Generic settings 12 | self.useGPU = self.nGPU > 0 13 | self.savePath = self.savePath 14 | -- format 15 | self.format = self.nGPU > 0 and 'gpu' or 'cpu' 16 | self.format = 'torch.' .. dlt.help.tensorList[self.format][self.tensorType] 17 | 18 | -- Configure training type 19 | local trainingTypes = {'simple', 'validate', 'GAN', 20 | 'WGAN','BEGAN','custom'} 21 | self.trainingType = experiment.trainingType or 'simple' 22 | if not dlt.help.inTable(trainingTypes,self.trainingType) then 23 | dlt.log:error('Unknown training type: ' ..
self.trainingType) 24 | end 25 | self.datasets = self.trainingType == 'validate' 26 | and {'training','validation'} 27 | or {'training'} 28 | 29 | -- Load checkpoint 30 | self:loadCheckpoint() 31 | 32 | -- Create data 33 | self.data = dlt.Data(experiment.loader,experiment.pointSize, 34 | self.datasets,self.currentEpoch) 35 | 36 | -- Setup training 37 | if self.trainingType == 'simple' or self.trainingType == 'validate' then 38 | self.trainingCallback = self.standardCallback 39 | local modelCreate = self.modelFile or experiment.model.create 40 | self.model = dlt.Model(modelCreate,experiment.model.name, 41 | experiment.model.save) 42 | self.optimizer = dlt.Optimizer(experiment.optim,self.tensorType, 43 | self.useGPU,self.optimFile) 44 | self.criterion = experiment.criterion 45 | self.criterion:type(self.format) 46 | self.log = { 47 | training = dlt.Trainlog('training', self.savePath) 48 | } 49 | if self.trainingType == 'validate' then 50 | self.log.validation = dlt.Trainlog('validation', self.savePath) 51 | end 52 | elseif dlt.help.inTable({'GAN','WGAN'},self.trainingType) then 53 | self.model, self.optimizer = {},{} 54 | self.trainingCallback = self[self.trainingType .. 'Callback'] 55 | 56 | if self.trainingType == 'WGAN' then 57 | local clamp = experiment.clamp or {} 58 | self.clampMin = clamp[1] or -0.01 59 | self.clampMax = clamp[2] or 0.01 60 | -- Training frequencies 61 | self.nDSteps, self.nGSteps = 0, 0 62 | local steps = experiment.steps or {} 63 | local defSteps = {5,100,25,100} 64 | for i,val in ipairs(defSteps) do steps[i] = steps[i] or val end 65 | self.trainGenerator = function() 66 | local ng = self.nGSteps 67 | local many = ((ng % steps[4] == 0) or (ng < steps[3])) 68 | local size = many and steps[2] or steps[1] 69 | if self.nDSteps == size - 1 then 70 | self.nGSteps = self.nGSteps + 1 71 | self.nDSteps = 0 72 | return true 73 | else 74 | self.nDSteps = self.nDSteps + 1 75 | return false 76 | end 77 | end 78 | 79 | end 80 | -- Create models,optimizers and criterions 81 | self.modelFile = self.modelFile or {} 82 | self.optimFile = self.optimFile or {} 83 | self.criterion = experiment.criterion 84 | for _,val in ipairs{'discriminator','generator'} do 85 | local modelCreate = self.modelFile[val] 86 | or experiment.model[val].create 87 | local modelName = experiment.model[val].name or val 88 | self.model[val] = dlt.Model(modelCreate,modelName, 89 | experiment.model[val].save) 90 | self.optimizer[val] = dlt.Optimizer(experiment.optim[val], 91 | self.tensorType,self.useGPU, 92 | self.optimFile[val]) 93 | if self.criterion and self.criterion[val] then 94 | self.criterion[val]:type(self.format) 95 | end 96 | end 97 | self.log = { 98 | discriminator = dlt.Trainlog('discriminator', self.savePath), 99 | generator = dlt.Trainlog('generator', self.savePath) 100 | } 101 | elseif self.trainingType == 'BEGAN' then 102 | self.model, self.optimizer = {},{} 103 | self.trainingCallback = self.BEGANCallback 104 | self.modelFile = self.modelFile or {} 105 | self.optimFile = self.optimFile or {} 106 | for _,val in ipairs{'discriminator','generator'} do 107 | local modelCreate = self.modelFile[val] 108 | or experiment.model[val].create 109 | local modelName = experiment.model[val].name or val 110 | self.model[val] = dlt.Model(modelCreate,modelName, 111 | experiment.model[val].save) 112 | self.optimizer[val] = dlt.Optimizer(experiment.optim[val], 113 | self.tensorType,self.useGPU, 114 | self.optimFile[val]) 115 | end 116 | 117 | self.diversityRatio = experiment.diversityRatio or 0.5 118 | self.ktVar 
= experiment.ktVarInit or 0 119 | self.ktLearningRate = experiment.ktLearningRate or 0.001 120 | self.loss = experiment.loss or nn.AbsCriterion() 121 | self.loss:type(self.format) 122 | 123 | self.log = {} 124 | for _,val in ipairs{'discriminator','generator','autoencoder', 125 | 'kt','convergence'} do 126 | self.log[val] = dlt.Trainlog(val,self.savePath) 127 | end 128 | elseif self.trainingType == 'custom' then 129 | -- Replace default callbacks if given 130 | if experiment.trainingCallback == nil then 131 | dlt.log:error('Trainer custom mode requires training callback.') 132 | end 133 | self.trainingCallback = experiment.trainingCallback 134 | -- Model 135 | if experiment.model.create then 136 | local modelCreate = self.modelFile or experiment.model.create 137 | self.model = dlt.Model(modelCreate,experiment.model.name, 138 | experiment.model.save) 139 | else -- multiple 140 | self.model = {} 141 | self.modelFile = self.modelFile or {} 142 | for name,val in pairs(experiment.model) do 143 | local modelCreate = self.modelFile[name] or val.create 144 | local modelName = val.name or name 145 | self.model[name] = dlt.Model(modelCreate,modelName,val.save) 146 | end 147 | end 148 | -- Criterion 149 | if torch.type(experiment.criterion) ~= 'table' then 150 | self.criterion = experiment.criterion 151 | self.criterion:type(self.format) 152 | else -- multiple 153 | self.criterion = {} 154 | for name,val in pairs(experiment.criterion) do 155 | self.criterion[name] = val 156 | self.criterion[name]:type(self.format) 157 | end 158 | end 159 | -- Optimizer 160 | if not experiment.optim or experiment.optim.name then 161 | self.optimizer = dlt.Optimizer(experiment.optim,self.tensorType, 162 | self.useGPU,self.optimFile) 163 | else -- multiple 164 | self.optimizer = {} 165 | self.optimFile = self.optimFile or {} 166 | for name,val in pairs(experiment.optim) do 167 | self.optimizer[name] = dlt.Optimizer(val,self.tensorType, 168 | self.useGPU, 169 | self.optimFile[name]) 170 | end 171 | end 172 | -- Logs 173 | if experiment.log then 174 | self.log = {} 175 | for _,name in pairs(experiment.log) do 176 | self.log[name] = dlt.Trainlog(name,self.savePath) 177 | end 178 | end 179 | 180 | end 181 | 182 | -- Replace default callbacks if given 183 | if experiment.trainingCallback then 184 | dlt.log:yell('Replacing default training callback with given.') 185 | self.trainingCallback = experiment.trainingCallback 186 | end 187 | 188 | -- Checkpoint condition 189 | self.checkpointCondition = experiment.checkpointCondition 190 | or function() return false end 191 | 192 | if torch.type(self.checkpointCondition) == 'number' 193 | and self.checkpointCondition <=0 then 194 | dlt.log:error('Checkpoint Condition must be positive number' .. 195 | ' (minutes) or function') 196 | end 197 | if torch.type(self.checkpointCondition) ~= 'function' 198 | and torch.type(self.checkpointCondition) ~= 'number' then 199 | dlt.log:error('Checkpoint Condition must be positive number' .. 
200 | ' (minutes) or function') 201 | end 202 | if torch.type(self.checkpointCondition) == 'number' then 203 | local checkpointTimer = torch.Timer() 204 | local checkpointFrequency = self.checkpointCondition 205 | self.checkpointCondition = function() 206 | local ret = checkpointTimer:time().real > 60*checkpointFrequency 207 | if ret then checkpointTimer:reset() end 208 | return ret 209 | end 210 | end 211 | 212 | dlt.log:endSection() 213 | end 214 | 215 | function T:run() self:train() end 216 | 217 | function T:train() 218 | dlt.log:section('Training') 219 | self.data:iterate( { 220 | training = function(batch) return self:trainingCallback(batch) end, 221 | validation = function(batch) return self:validationCallback(batch) end, 222 | checkpoint = function() self:checkpoint() end 223 | } ) 224 | dlt.log:endSection() 225 | end 226 | 227 | 228 | function T:validationCallback(batch) 229 | self.model:evaluate() 230 | local prediction = self.model:forward(batch.input) 231 | self.log.validation:log(self.criterion:forward(prediction,batch.output)) 232 | end 233 | 234 | -- THIS IS THE DEFAULT MAIN OPTIMIZATION FUNCTION! 235 | -- requires batch = {input = ..., output = ...} 236 | function T:standardCallback(batch) 237 | self.model:zeroGradParameters() 238 | local prediction = self.model:forward(batch.input) 239 | local loss = self.criterion:forward(prediction, batch.output) 240 | self.log.training:log(loss) 241 | local gradOutput = self.criterion:backward(prediction, batch.output) 242 | self.model.model:backward(batch.input,gradOutput) 243 | self.optimizer:updateState(self.data.currentEpoch, loss) 244 | self.optimizer:step( function() 245 | return loss,self.model.gradParameters 246 | end, 247 | self.model.parameters ) 248 | end 249 | 250 | 251 | -- GAN 252 | -- Assumes batch is 'input' --> G --> 'sample' --> D --> 'output' 253 | function T:GANCallback(batch) 254 | -- Some shortcuts 255 | local D = self.model.discriminator 256 | local critD = self.criterion.discriminator 257 | local optD = self.optimizer.discriminator 258 | local G = self.model.generator 259 | local optG = self.optimizer.generator 260 | 261 | D:training() 262 | G:training() 263 | 264 | -- Discriminator optimization 265 | -- Real data 266 | D.gradParameters:zero() 267 | local rPred = D:forward(batch.sample) 268 | local rDLoss, dGradOut = critD(rPred, batch.output:fill(1)) 269 | D:backward(batch.sample,dGradOut) 270 | 271 | -- Fake data 272 | local gPred = G:forward(batch.input:uniform(-1,1)) 273 | local fPred = D:forward(gPred) 274 | local fDLoss, dGradOut = critD(fPred, batch.output:fill(0)) 275 | D:backward(gPred,dGradOut) 276 | 277 | local dLoss = rDLoss + fDLoss 278 | self.log.discriminator:log(dLoss) 279 | optD:step(function() return dLoss, D.gradParameters end, D.parameters) 280 | 281 | -- Generator optimization 282 | G.gradParameters:zero() 283 | -- Invert labels (to minimize positive instead of maximizing negative); D still holds the forward state from gPred, so updateGradInput below is valid 284 | local gLoss, dGradOut = critD(fPred, batch.output:fill(1)) 285 | local gGradOut = D:updateGradInput(gPred,dGradOut) 286 | G:backward(batch.input,gGradOut) 287 | self.log.generator:log(gLoss) 288 | optG:step(function() return gLoss, G.gradParameters end, G.parameters) 289 | 290 | end
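-- Added note (our reading of the code below, not an original comment):
-- WGANCallback trains D to maximize mean(D(x)) - mean(D(G(z))). The two
-- backward passes with targets filled with 1/-1 accumulate the gradient of
-- sum(rPred) - sum(fPred), so D.gradParameters:mul(-1/batchSize) turns it
-- into the gradient of -wDistance for the minimizing optimizer, and the
-- weight clamp after the step enforces the Lipschitz constraint.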
291 | 292 | 293 | -- Wasserstein GAN 294 | -- Assumes batch is 'input' --> G --> 'sample' --> D --> 'output' 295 | function T:WGANCallback(batch) 296 | -- Some shortcuts 297 | local D = self.model.discriminator 298 | local optD = self.optimizer.discriminator 299 | local G = self.model.generator 300 | local optG = self.optimizer.generator 301 | local batchSize = batch.sample:size(1) 302 | -- Make sure we are in training mode 303 | D:training() 304 | G:training() 305 | 306 | -- Discriminator optimization 307 | -- Real data 308 | D.gradParameters:zero() 309 | local rPred = D:forward(batch.sample) 310 | D:backward(batch.sample,batch.output:fill(1)) 311 | local rDLoss = rPred:mean() 312 | 313 | -- Fake data 314 | local gPred = G:forward(batch.input:normal(0,1)) 315 | local fPred = D:forward(gPred) 316 | local gGradOut = D:backward(gPred,batch.output:fill(-1)) 317 | local fDLoss = fPred:mean() 318 | 319 | local wDistance = rDLoss - fDLoss 320 | self.log.discriminator:log(wDistance) 321 | D.gradParameters:mul(-1/batchSize) 322 | optD:step(function() return -wDistance, D.gradParameters end, D.parameters) 323 | D.parameters:clamp(self.clampMin,self.clampMax) 324 | -- Generator optimization 325 | if self.trainGenerator() then 326 | G.gradParameters:zero() 327 | G:backward(batch.input,gGradOut) 328 | local gLoss = fDLoss -- logged value; the accumulated gradient corresponds to -fDLoss 329 | self.log.generator:log(gLoss) 330 | G.gradParameters:div(batchSize) 331 | optG:step(function() return gLoss, G.gradParameters end, G.parameters) 332 | end 333 | end 334 | 335 | -- Boundary Equilibrium GAN (BEGAN) 336 | function T:BEGANCallback(batch) 337 | -- Some shortcuts 338 | local D = self.model.discriminator 339 | local optD = self.optimizer.discriminator 340 | local G = self.model.generator 341 | local optG = self.optimizer.generator 342 | 343 | G:training() 344 | D:training() 345 | 346 | D.gradParameters:zero() 347 | self.dGradParamBuffer = self.dGradParamBuffer or D.gradParameters:clone() 348 | self.dGradParamBuffer:zero() 349 | -- Discriminator loss 350 | -- real 351 | local rPred = D:forward(batch.sample) 352 | local rLoss, dGradOut = self.loss(rPred,batch.sample) 353 | D:backward(batch.sample,dGradOut) 354 | self.dGradParamBuffer:add(D.gradParameters) 355 | -- fake 356 | D.gradParameters:zero() 357 | G.gradParameters:zero() 358 | local gPred = G:forward(batch.input:uniform(-1,1)) 359 | local fPred = D:forward(gPred) 360 | local fLoss, dGradOut = self.loss(fPred,gPred) 361 | local gGradOut = D:backward(gPred,dGradOut) 362 | self.dGradParamBuffer:add(D.gradParameters:mul(-self.ktVar)) 363 | 364 | local dLoss = rLoss - self.ktVar*fLoss 365 | 366 | -- Generator loss 367 | -- SHOULD I RESAMPLE Z? 368 | local gLoss = fLoss 369 | G:backward(batch.input,gGradOut) 370 | 371 | -- Log 372 | self.log.autoencoder:log(rLoss) 373 | self.log.discriminator:log(dLoss) 374 | self.log.generator:log(gLoss) 375 | self.log.kt:log(self.ktVar) 376 | local balance = self.diversityRatio*rLoss - gLoss 377 | self.log.convergence:log(rLoss + torch.abs(balance)) 378 | 379 | -- Perform updates 380 | optD:step(function() return dLoss, self.dGradParamBuffer end, D.parameters) 381 | optG:step(function() return gLoss, G.gradParameters end, G.parameters) 382 | self.ktVar = self.ktVar + self.ktLearningRate*balance 383 | self.ktVar = math.max(math.min(self.ktVar,1),0) 384 | end 385 | 386 | --- Checkpointing functionality 387 | function T:checkpoint() 388 | if self.currentEpoch ~= self.data.currentEpoch 389 | or self.checkpointCondition(self) then 390 | self:saveCheckpoint() 391 | end 392 | self.currentEpoch = self.data.currentEpoch 393 | end 394 | 395 | function T:makeFilename(name) 396 | local ret = name 397 | if self.saveAll then 398 | ret = ret .. string.format(self.chkpFormat,self.chkpCount) 399 | end 400 | return paths.concat(self.savePath,ret ..
'.t7') 401 | end 402 | 403 | function T:loadCheckpoint() 404 | self.chkpFile = paths.concat(self.savePath,'checkpoint.t7') 405 | self.chkpFormat = '%05d' 406 | if paths.filep(self.chkpFile) then 407 | local latest = torch.load(self.chkpFile) 408 | if latest.currentEpoch >= self.maxEpochs then 409 | dlt.log:error('Already did all Epochs!') 410 | end 411 | self.chkpCount = latest.chkpCount + 1 412 | self.modelFile = latest.model 413 | self.optimFile = latest.optimizer 414 | self.currentEpoch = latest.currentEpoch 415 | else 416 | self.currentEpoch = 1 417 | self.chkpCount = 1 418 | end 419 | end 420 | 421 | function T:saveCheckpoint() 422 | self.data:syncThreads() 423 | self.data:syncGPU() 424 | local checkpoint = {} 425 | for _,val in ipairs({'model','optimizer'}) do 426 | if torch.type(self[val]) == 'table' then 427 | checkpoint[val] = {} 428 | for objName,obj in pairs(self[val]) do 429 | local fileName = self:makeFilename(objName .. '_' .. val) 430 | obj:save(fileName) 431 | checkpoint[val][objName] = fileName 432 | end 433 | else 434 | local fileName = self:makeFilename(val) 435 | self[val]:save(fileName) 436 | checkpoint[val] = fileName 437 | end 438 | end 439 | checkpoint.chkpCount = self.chkpCount 440 | checkpoint.currentEpoch = self.data.currentEpoch 441 | torch.save(self.chkpFile,checkpoint) 442 | dlt.log:yell('Saved checkpoint ' .. self.chkpCount) 443 | self.chkpCount = self.chkpCount + 1 444 | collectgarbage() 445 | collectgarbage() 446 | end --------------------------------------------------------------------------------