├── .gitignore ├── .luacov ├── .wercker.yml ├── CMakeLists.txt ├── README.md ├── docs ├── graph │ ├── DirectedGraph.md │ ├── Graph.md │ └── UndirectedGraph.md ├── index.md ├── list │ ├── ArrayList.md │ ├── Heap.md │ ├── LinkedList.md │ ├── List.md │ ├── Queue.md │ └── Stack.md ├── map │ ├── Counter.md │ ├── HashMap.md │ └── Map.md ├── ml │ ├── Dataset.md │ ├── Experiment.md │ ├── GloveVocab.md │ ├── Model.md │ ├── ProbTable.md │ ├── Scorer.md │ ├── VariableTensor.md │ └── Vocab.md ├── set │ └── Set.md ├── tree │ ├── BinarySearchTree.md │ ├── BinaryTree.md │ └── Tree.md └── util │ ├── Download.md │ ├── global.md │ ├── string.md │ └── table.md ├── init.lua ├── mkdocs.yml ├── src ├── graph │ ├── DirectedGraph.lua │ ├── Graph.lua │ └── UndirectedGraph.lua ├── list │ ├── ArrayList.lua │ ├── Heap.lua │ ├── LinkedList.lua │ ├── List.lua │ ├── Queue.lua │ └── Stack.lua ├── map │ ├── Counter.lua │ ├── HashMap.lua │ └── Map.lua ├── ml │ ├── Dataset.lua │ ├── Experiment.lua │ ├── GloveVocab.lua │ ├── Model.lua │ ├── ProbTable.lua │ ├── Scorer.lua │ ├── VariableTensor.lua │ └── Vocab.lua ├── set │ └── Set.lua ├── tree │ ├── BinarySearchTree.lua │ ├── BinaryTree.lua │ └── Tree.lua └── util │ ├── Download.lua │ ├── global.lua │ ├── string.lua │ └── table.lua ├── test.lua ├── test ├── mock │ └── conll.mock ├── test_counter.lua ├── test_dataset.lua ├── test_download.lua ├── test_graph.lua ├── test_heap.lua ├── test_list.lua ├── test_map.lua ├── test_model.lua ├── test_prob_table.lua ├── test_queue.lua ├── test_scorer.lua ├── test_set.lua ├── test_stack.lua ├── test_tree.lua ├── test_util.lua ├── test_variable_tensor.lua └── test_vocab.lua └── torchlib-scm-1.rockspec /.gitignore: -------------------------------------------------------------------------------- 1 | # Object files 2 | *.o 3 | *.ko 4 | *.obj 5 | *.elf 6 | 7 | # Libraries 8 | *.lib 9 | *.a 10 | 11 | # Shared objects (inc. 
Windows DLLs) 12 | *.dll 13 | *.so 14 | *.so.* 15 | *.dylib 16 | 17 | # Executables 18 | *.exe 19 | *.out 20 | *.app 21 | *.i*86 22 | *.x86_64 23 | *.hex 24 | 25 | # build 26 | build/ 27 | 28 | # docs 29 | site/ 30 | 31 | glove/ 32 | .DS_Store 33 | -------------------------------------------------------------------------------- /.luacov: -------------------------------------------------------------------------------- 1 | --- Global configuration file. Copy, customize and store in your 2 | -- project folder as '.luacov' for project specific configuration. 3 | -- If some options are missing, their default values from this file 4 | -- will be used. 5 | -- @class module 6 | -- @name luacov.defaults 7 | return { 8 | 9 | -- default filename to load for config options if not provided 10 | -- only has effect in 'luacov.defaults.lua' 11 | ["configfile"] = ".luacov", 12 | 13 | -- filename to store stats collected 14 | ["statsfile"] = "luacov.stats.out", 15 | 16 | -- filename to store report 17 | ["reportfile"] = "luacov.report.out", 18 | 19 | -- luacov.stats file updating frequency. 20 | -- The lower this value - the more frequently results will be written out to luacov.stats 21 | -- You may want to reduce this value for short lived scripts (to for example 2) to avoid losing coverage data. 22 | ["savestepsize"] = 100, 23 | 24 | -- Run reporter on completion? (won't work for ticks) 25 | runreport = true, 26 | 27 | -- Delete stats file after reporting? 
28 | deletestats = false, 29 | 30 | -- Process Lua code loaded from raw strings 31 | -- (that is, when the 'source' field in the debug info 32 | -- does not start with '@') 33 | codefromstrings = false, 34 | 35 | -- Patterns for files to include when reporting 36 | -- all will be included if nothing is listed 37 | -- (exclude overrules include, do not include 38 | -- the .lua extension, path separator is always '/') 39 | ["include"] = {'torchlib'}, 40 | 41 | -- Patterns for files to exclude when reporting 42 | -- all will be included if nothing is listed 43 | -- (exclude overrules include, do not include 44 | -- the .lua extension, path separator is always '/') 45 | ["exclude"] = { 46 | "luacov$", 47 | "luacov/reporter$", 48 | "luacov/defaults$", 49 | "luacov/runner$", 50 | "luacov/stats$", 51 | "luacov/tick$", 52 | }, 53 | 54 | 55 | } 56 | -------------------------------------------------------------------------------- /.wercker.yml: -------------------------------------------------------------------------------- 1 | # This references a standard debian container from the 2 | # Docker Hub https://registry.hub.docker.com/_/debian/ 3 | # Read more about containers on our dev center 4 | # http://devcenter.wercker.com/docs/containers/index.html 5 | box: kaixhin/cuda-torch 6 | # You can also use services such as databases. Read more on our dev center: 7 | # http://devcenter.wercker.com/docs/services/index.html 8 | # services: 9 | # - postgres 10 | # http://devcenter.wercker.com/docs/services/postgresql.html 11 | 12 | # - mongodb 13 | # http://devcenter.wercker.com/docs/services/mongodb.html 14 | 15 | # This is the build pipeline. 
Pipelines are the core of wercker 16 | # Read more about pipelines on our dev center 17 | # http://devcenter.wercker.com/docs/pipelines/index.html 18 | build: 19 | # Steps make up the actions in your pipeline 20 | # Read more about steps on our dev center: 21 | # http://devcenter.wercker.com/docs/steps/index.html 22 | steps: 23 | - script: 24 | name: install 25 | code: | 26 | apt-get -yqq update 27 | apt-get install -yqq wget 28 | luarocks make 29 | - script: 30 | name: test and generate coverage 31 | code: | 32 | luarocks install luacov 33 | th -lluacov test.lua 34 | bash <(curl -s https://codecov.io/bash) || echo "Codecov did not collect coverage reports" 35 | - script: 36 | name: generate markdown 37 | code: | 38 | git clone https://github.com/vzhong/docroc.git 39 | cd docroc && luarocks make && cd - 40 | th docroc/docroc src docs --index README.md --config mkdocs.yml --github_src_dir http://github.com/vzhong/torchlib/blob/master/src/ 41 | deploy: 42 | steps: 43 | - script: 44 | name: install dependencies 45 | code: | 46 | apt-get -yqq update 47 | apt-get install -yqq python python-pip 48 | apt-get install -yqq git 49 | pip install mkdocs 50 | - script: 51 | name: build docs 52 | code: | 53 | mkdocs build --clean 54 | touch site/.nojekyll 55 | - lukevivier/gh-pages: 56 | token: $GH_TOKEN 57 | basedir: site 58 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) 3 | CMAKE_POLICY(VERSION 2.6) 4 | IF(LUAROCKS_PREFIX) 5 | MESSAGE(STATUS "Installing Torch through Luarocks") 6 | STRING(REGEX REPLACE "(.*)lib/luarocks/rocks.*" "\\1" CMAKE_INSTALL_PREFIX "${LUAROCKS_PREFIX}") 7 | MESSAGE(STATUS "Prefix inferred from Luarocks: ${CMAKE_INSTALL_PREFIX}") 8 | ENDIF() 9 | FIND_PACKAGE(Torch REQUIRED) 10 | 11 | FILE(GLOB luasrc *.lua) 12 | ADD_TORCH_PACKAGE(torchlib "" "${luasrc}" "Torch NLP") 
13 | 14 | INSTALL(DIRECTORY "src" DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/torchlib") 15 | FILE(GLOB luasrc *.lua) 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torchlib 2 | 3 | [![wercker status](https://app.wercker.com/status/c7bd97d06535598d96937e0cf5ace629/s/master "wercker status")](https://app.wercker.com/project/bykey/c7bd97d06535598d96937e0cf5ace629) 4 | [![codecov](https://codecov.io/gh/vzhong/torchlib/branch/master/graph/badge.svg)](https://codecov.io/gh/vzhong/torchlib) 5 | 6 | [View documentation](http://torchlib.github.io). 7 | 8 | Data structures and libraries for Torch. All instances are Torch serializable with `torch.save` and `torch.load`. 9 | 10 | 11 | ## Installation 12 | 13 | You can install `torchlib` as follows: 14 | 15 | `git clone https://github.com/vzhong/torchlib.git && cd torchlib && luarocks make` 16 | 17 | Torchlib is namespaced locally. To use it: 18 | 19 | ```lua 20 | local tl = require 'torchlib' 21 | 22 | local m = tl.DirectedGraph() 23 | ... 24 | ``` 25 | 26 | Examples and use cases are shown in the documentation. 27 | 28 | 29 | ## Documentation 30 | 31 | The documentation is hosted [here](http://www.victorzhong.com/torchlib). 32 | Alternatively you can build your own documentation with `docroc`, which you can get [here](https://github.com/vzhong/docroc). 33 | 34 | 35 | ## Overview 36 | 37 | Torchlib can be divided into categories based on use cases. 38 | 39 | ### Basic Datastructures and Algorithms 40 | 41 | - Graphs 42 | - Lists, heaps, queues and stacks 43 | - Maps and counters 44 | - Sets 45 | - Trees 46 | 47 | ### Machine Learning 48 | 49 | The machine learning package contains utilities that facilitate the training and evaluation of machine learning models. These include: 50 | 51 | - Dataset, which provides mechanisms for subsampling, shuffling, batching of arbitrary examples. 
52 | - Vocab, for mapping between indices and words. 53 | - Model, an abstract class to facilitate the training of Torch based machine learning models. 54 | - Scorer, for evaluating precision/recall metrics. 55 | - ProbTable, for modeling probability distributions. 56 | - Experiment, for logging experiment progress to a postgres instance. 57 | 58 | ### Utilities 59 | 60 | - Downloader, for downloading content via http. 61 | - Global, global convenience functions namespaced under `tl`. 62 | - String, string convenience functions namespaced under `tl.string` and monkeypatched into `string`. 63 | - Table, table convenience functions namespaced under `tl.table` and monkeypatched into `table`. 64 | 65 | 66 | ## Contribution 67 | 68 | Pull requests are welcome! Torchlib is unit tested with the default Torch testing framework. Continuous integration is hosted on [Wercker](http://wercker.com/) which also automatically builds the documentation and deploys it on Github pages (of this repo). 69 | -------------------------------------------------------------------------------- /docs/graph/DirectedGraph.md: -------------------------------------------------------------------------------- 1 | # DirectedGraph 2 | A directed graph implementation. 3 | This is a subclass of `Graph`. 4 | 5 | 6 | 7 | 8 | ## DirectedGraph:connect(nodeA, nodeB) 9 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/DirectedGraph.lua#L12) 10 | 11 | Connects two nodes. 12 | 13 | Arguments: 14 | 15 | - `nodeA ` (`Graph.Node`): starting node. 16 | - `nodeB ` (`Graph.Node`): ending node. 
17 | 18 | 19 | ## DirectedGraph:topologicalSort() 20 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/DirectedGraph.lua#L20) 21 | 22 | Returns nodes in this graph in topologically sorted order 23 | 24 | Returns: 25 | 26 | - (`table`) 27 | 28 | ## DirectedGraph:hasCycle() 29 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/DirectedGraph.lua#L31) 30 | 31 | Returns whether the graph has a cycle 32 | 33 | Returns: 34 | 35 | - (`boolean`) 36 | 37 | ## DirectedGraph:transpose() 38 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/DirectedGraph.lua#L63) 39 | 40 | Returns a transpose of this graph (eg. with the edges reversed) 41 | 42 | Returns: 43 | 44 | - (`DirectedGraph`) 45 | 46 | ## DirectedGraph:stronglyConnectedComponents() 47 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/DirectedGraph.lua#L87) 48 | 49 | Returns strongly connected components. 50 | 51 | Each strongly connected component is itself a table. 52 | 53 | Returns: 54 | 55 | - (`table[table]`) a table of strongly connected components. 56 | 57 | -------------------------------------------------------------------------------- /docs/graph/Graph.md: -------------------------------------------------------------------------------- 1 | # Graph 2 | Abstract graph implementation. 3 | 4 | A `Graph` consists of `GraphNode`s. Each `GraphNode` can be in three states: 5 | - `UNDISCOVERED` 6 | - `VISITED` 7 | - `FINISHED` 8 | 9 | 10 | 11 | 12 | ## GraphNode:\_\_init(val) 13 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L26) 14 | 15 | Constructor. 16 | 17 | Arguments: 18 | 19 | - `val ` (`any`): value for the new node. 
20 | 21 | 22 | ## GraphNode:\_\_tostring\_\_() 23 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L31) 24 | 25 | 26 | 27 | Returns: 28 | 29 | - (`string`) string representation 30 | 31 | ## Graph:\_\_init() 32 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L36) 33 | 34 | Constructor. 35 | 36 | 37 | ## Graph:size() 38 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L41) 39 | 40 | 41 | 42 | Returns: 43 | 44 | - (`int`) number of nodes in the graph. 45 | 46 | ## Graph:assertValidNode(node) 47 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L47) 48 | 49 | Verifies that the node is in the graph 50 | 51 | Arguments: 52 | 53 | - `node ` (`Graph.Node`): the node to verify. 54 | 55 | 56 | ## Graph:addNode(val) 57 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L54) 58 | 59 | Adds a node with given value to the graph. 60 | 61 | Arguments: 62 | 63 | - `val ` (`any`): value for the new node. 64 | 65 | Returns: 66 | 67 | - (`Graph.Node`) 68 | 69 | ## Graph:connectionsOf(node) 70 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L63) 71 | 72 | Returns neighbours of a given node. 73 | 74 | Arguments: 75 | 76 | - `node ` (`Graph.Node`): the node to find neighbours for. 77 | 78 | Returns: 79 | 80 | - (`table(Graph.Node)`) 81 | 82 | ## Graph:nodeSet() 83 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L70) 84 | 85 | Returns a set of nodes in the graph. 86 | 87 | Returns: 88 | 89 | - (`Set(Graph.Node)`) 90 | 91 | ## Graph:resetState() 92 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L78) 93 | 94 | Initializes all nodes to `Graph.state.UNDISCOVERED`. 
95 | 96 | Returns: 97 | 98 | - (`Graph`) 99 | 100 | The graph will be returned 101 | 102 | ## Graph:breadthFirstSearch(source, callbacks) 103 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L99) 104 | 105 | Performs breadth first search. 106 | 107 | Arguments: 108 | 109 | - `source ` (`Graph.Node`): the source node to start BFS. 110 | - `callbacks ` (`table[string:function]`): a map with optional callbacks 111 | 112 | Available callbacks:. Optional. 113 | 114 | - `discover = function(Graph.Node)`: called when a node is initially encountered 115 | 116 | - `finish = function(Graph.Node)`: called when a node has been fully explored (eg. its connected nodes have all been visited) 117 | 118 | ## Graph:shortestPath(source, destination, skipBFS) 119 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L138) 120 | 121 | Returns the shortest path from source to destination 122 | 123 | Arguments: 124 | 125 | - `source ` (`Graph.Node`): starting node. 126 | - `destination ` (`Graph.Node`): end node. 127 | - `skipBFS ` (`boolean`): whether BFS has already been performed. Optional. 128 | 129 | Note: This function relies on the results from a BFS call. By default, a BFS is performed before 130 | 131 | retrieving the shortest path. Alternatively, if the caller has already performed BFS, then 132 | 133 | this BFS can be skipped by passing in `skipBFS = true`. 134 | 135 | ## Graph:depthFirstSearch(nodes, callbacks) 136 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/Graph.lua#L167) 137 | 138 | Performs depth first search. 139 | 140 | Arguments: 141 | 142 | - `nodes ` (`table[Graph.Node]`): the table of nodes on which to perform DFS. If not set, then all nodes in the graph are used. 143 | - `callbacks ` (`table[string:function]`): a map with optional callbacks 144 | 145 | Available callbacks:. Optional. 
146 | 147 | - `discover = function(Graph.Node)`: called when a node is initially encountered 148 | 149 | - `finish = function(Graph.Node)`: called when a node has been fully explored (eg. its connected nodes have all been visited) 150 | 151 | -------------------------------------------------------------------------------- /docs/graph/UndirectedGraph.md: -------------------------------------------------------------------------------- 1 | # UndirectedGraph 2 | Undirected graph implementation 3 | This is a subclass of `Graph`. 4 | 5 | 6 | 7 | 8 | ## UndirectedGraph:connect(nodeA, nodeB) 9 | [View source](http://github.com/vzhong/torchlib/blob/master/src//graph/UndirectedGraph.lua#L11) 10 | 11 | Connects two nodes. 12 | 13 | Arguments: 14 | 15 | - `nodeA ` (`Graph.Node`): starting node. 16 | - `nodeB ` (`Graph.Node`): ending node. 17 | 18 | 19 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Torchlib 2 | 3 | [![wercker status](https://app.wercker.com/status/c7bd97d06535598d96937e0cf5ace629/s/master "wercker status")](https://app.wercker.com/project/bykey/c7bd97d06535598d96937e0cf5ace629) 4 | [![codecov](https://codecov.io/gh/vzhong/torchlib/branch/master/graph/badge.svg)](https://codecov.io/gh/vzhong/torchlib) 5 | 6 | [View documentation](http://www.victorzhong.com/torchlib). 7 | 8 | Data structures and libraries for Torch. All instances are Torch serializable with `torch.save` and `torch.load`. 9 | 10 | 11 | ## Installation 12 | 13 | You can install `torchlib` as follows: 14 | 15 | `git clone https://github.com/vzhong/torchlib.git && cd torchlib && luarocks make` 16 | 17 | Torchlib is namespaced locally. To use it: 18 | 19 | ```lua 20 | local tl = require 'torchlib' 21 | 22 | local m = tl.DirectedGraph() 23 | ... 24 | ``` 25 | 26 | Examples and use cases are shown in the documentation. 
27 | 28 | 29 | ## Documentation 30 | 31 | The documentation is hosted [here](http://www.victorzhong.com/torchlib). 32 | Alternatively you can build your own documentation with `docroc`, which you can get [here](https://github.com/vzhong/docroc). 33 | 34 | 35 | ## Overview 36 | 37 | Torchlib can be divided into categories based on use cases. 38 | 39 | ### Basic Datastructures and Algorithms 40 | 41 | - Graphs 42 | - Lists, heaps, queues and stacks 43 | - Maps and counters 44 | - Sets 45 | - Trees 46 | 47 | ### Machine Learning 48 | 49 | The machine learning package contains utilities that facilitate the training and evaluation of machine learning models. These include: 50 | 51 | - Dataset, which provides mechanisms for subsampling, shuffling, batching of arbitrary examples. 52 | - Vocab, for mapping between indices and words. 53 | - Model, an abstract class to facilitate the training of Torch based machine learning models. 54 | - Scorer, for evaluating precision/recall metrics. 55 | - ProbTable, for modeling probability distributions. 56 | - Experiment, for logging experiment progress to a postgres instance. 57 | 58 | ### Utilities 59 | 60 | - Downloader, for downloading content via http. 61 | - Global, global convenience functions namespaced under `tl`. 62 | - String, string convenience functions namespaced under `tl.string` and monkeypatched into `string`. 63 | - Table, table convenience functions namespaced under `tl.table` and monkeypatched into `table`. 64 | 65 | 66 | ## Contribution 67 | 68 | Pull requests are welcome! Torchlib is unit tested with the default Torch testing framework. Continuous integration is hosted on [Wercker](http://wercker.com/) which also automatically builds the documentation and deploys it on Github pages (of this repo). 
69 | -------------------------------------------------------------------------------- /docs/list/ArrayList.md: -------------------------------------------------------------------------------- 1 | # ArrayList 2 | Array list implementation. 3 | This is a subclass of `List`. 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /docs/list/Heap.md: -------------------------------------------------------------------------------- 1 | # Heap 2 | Max heap implementation. 3 | This is a subclass of `List`. 4 | 5 | 6 | 7 | 8 | ## Heap.parent(i) 9 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Heap.lua#L10) 10 | 11 | 12 | 13 | Arguments: 14 | 15 | - `i ` (`int`): index to compute parent for. 16 | 17 | Returns: 18 | 19 | - (`int`) parent index of `i` 20 | 21 | ## Heap.left(i) 22 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Heap.lua#L16) 23 | 24 | 25 | 26 | Arguments: 27 | 28 | - `i ` (`int`): index to compute left child for. 29 | 30 | Returns: 31 | 32 | - (`int`) left child index of `i` 33 | 34 | ## Heap.right(i) 35 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Heap.lua#L22) 36 | 37 | 38 | 39 | Arguments: 40 | 41 | - `i ` (`int`): index to compute right child for. 42 | 43 | Returns: 44 | 45 | - (`int`) right child index of `i` 46 | 47 | ## Heap:maxHeapify(i, effectiveSize) 48 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Heap.lua#L35) 49 | 50 | Restores max heap condition at the `i`th index. 51 | 52 | Arguments: 53 | 54 | - `i ` (`int`): index at which to restore max heap condition. 55 | - `effectiveSize ` (`int`): effective size of the heap (eg. number of valid elements). Optional, Default: `size`. 56 | 57 | Returns: 58 | 59 | - (`Heap`) modified heap 60 | 61 | Recursively swaps down the node at `i` until the max heap condition is restored at `a[i]`. 
62 | 63 | Note: this function assumes that the binary trees rooted at left and right are max heaps but 64 | 65 | `a[i]` may violate the max-heap condition. 66 | 67 | ## Heap:sort() 68 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Heap.lua#L65) 69 | 70 | Sorts the heap using heap sort. 71 | 72 | Returns: 73 | 74 | - (`Heap`) sorted heap 75 | 76 | ## Heap:push(key, val) 77 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Heap.lua#L79) 78 | 79 | Adds an element to the heap while keeping max heap property. 80 | 81 | Arguments: 82 | 83 | - `key ` (`number`): priority to add with. 84 | - `val ` (`any`): element to add to heap. 85 | 86 | Returns: 87 | 88 | - (`Heap`) modified heap 89 | 90 | ## Heap:pop() 91 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Heap.lua#L88) 92 | 93 | Removes and returns the max priority element from the heap. 94 | 95 | Returns: 96 | 97 | - (`any`) removed element 98 | 99 | ## Heap:peek() 100 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Heap.lua#L101) 101 | 102 | 103 | 104 | Returns: (`any`) max priority element from the heap 105 | 106 | Note: the element is not removed. 107 | 108 | -------------------------------------------------------------------------------- /docs/list/LinkedList.md: -------------------------------------------------------------------------------- 1 | # LinkedList 2 | Linked list implementation. 3 | This is a subclass of `List`. 4 | 5 | 6 | 7 | 8 | ## LinkedList:head() 9 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/LinkedList.lua#L33) 10 | 11 | 12 | 13 | Returns: 14 | 15 | - (`LinkedList.Node`) head of the linked list 16 | 17 | -------------------------------------------------------------------------------- /docs/list/List.md: -------------------------------------------------------------------------------- 1 | # List 2 | Abstract list implementation. 
3 | 4 | 5 | 6 | 7 | ## List:\_\_init(values) 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L9) 9 | 10 | Constructor. 11 | 12 | Arguments: 13 | 14 | - `values ` (`table[any]`): used to initialize the list. Optional. 15 | 16 | 17 | ## List:add(val, index) 18 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L17) 19 | 20 | Adds element to list. 21 | 22 | Arguments: 23 | 24 | - `val ` (`any`): value to add. 25 | - `index ` (`int`): index to add value at. Optional, Default: `end`. 26 | 27 | Returns: 28 | 29 | - (`List`) - modified list 30 | 31 | ## List:get(index) 32 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L25) 33 | 34 | 35 | 36 | Arguments: 37 | 38 | - `index ` (`int`): index to retrieve value for. 39 | 40 | Returns: 41 | 42 | - (`any`) - value at index 43 | 44 | Asserts error if `index` is out of bounds. 45 | 46 | ## List:set(index, val) 47 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L35) 48 | 49 | Sets the value at index. 50 | 51 | Arguments: 52 | 53 | - `index ` (`int`): index to set value for. 54 | - `val ` (`any`): value to set. 55 | 56 | Returns: 57 | 58 | - (`List`) - modified list 59 | 60 | Asserts error if `index` is out of bounds. 61 | 62 | ## List:remove(index) 63 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L44) 64 | 65 | 66 | 67 | Arguments: 68 | 69 | - `index ` (`int`): index to remove at. 70 | 71 | Returns: 72 | 73 | - (`any`) - value at index 74 | 75 | Elements after `index` will be shifted to the left by 1. 76 | 77 | Asserts error if `index` is out of bounds. 78 | 79 | ## List:equals(another) 80 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L53) 81 | 82 | Compares two lists. 83 | 84 | Arguments: 85 | 86 | - `another ` (`List`): another list to compare to. 
87 | 88 | Returns: 89 | 90 | - (`boolean`) whether this list is equal to `another` 91 | 92 | Lists are considered equal if their values match at every position. 93 | 94 | ## List:swap(i, j) 95 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L61) 96 | 97 | Swaps value at two indices. 98 | 99 | Arguments: 100 | 101 | - `i ` (`int`): first index. 102 | - `j ` (`int`): second index. 103 | 104 | Returns: 105 | 106 | - (`List`) - modified list 107 | 108 | ## List:totable() 109 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L67) 110 | 111 | Returns the list in table form. 112 | 113 | Returns: 114 | 115 | - (`table[any]`) a table containing the values in the list. 116 | 117 | ## List:assertValidIndex(index) 118 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L73) 119 | 120 | Asserts that index is inside the list. 121 | 122 | Arguments: 123 | 124 | - `index ` (`int`): index to check. 125 | 126 | 127 | ## List:size() 128 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L78) 129 | 130 | 131 | 132 | Returns: 133 | 134 | - (`int`) size of the list 135 | 136 | ## List:addMany(...) 137 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L85) 138 | 139 | Adds items to the list. 140 | 141 | Arguments: 142 | 143 | - `vararg ` (`vararg[any]`): values to add to the list. 144 | 145 | Returns: 146 | 147 | - (`List`) modified list 148 | 149 | ## List:contains(val) 150 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L96) 151 | 152 | Returns whether the list contains a value. 153 | 154 | Arguments: 155 | 156 | - `val ` (`any`): value to check. 
157 | 158 | Returns: 159 | 160 | - (`boolean`) whether `val` is in the list 161 | 162 | ## List:copy() 163 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L106) 164 | 165 | 166 | 167 | Returns: 168 | 169 | - (`List`) a copy of this list 170 | 171 | ## List:isEmpty() 172 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L111) 173 | 174 | 175 | 176 | Returns: 177 | 178 | - (`boolean`) whether the list is empty 179 | 180 | ## List:sublist(start, finish) 181 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L119) 182 | 183 | Returns a copy of a segment of this list. 184 | 185 | Arguments: 186 | 187 | - `start ` (`int`): start of the segment. 188 | - `finish ` (`int`): end of the segment. Optional, Default: `end`. 189 | 190 | 191 | ## List:sort(start, finish) 192 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L131) 193 | 194 | Sorts the list in place. 195 | 196 | Arguments: 197 | 198 | - `start ` (`int`): start index of the sort. Optional, Default: `1`. 199 | - `finish ` (`int`): end index of the sort. Optional, Default: `end`. 200 | 201 | 202 | ## List:\_\_tostring\_\_() 203 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/List.lua#L156) 204 | 205 | 206 | 207 | Returns: 208 | 209 | - (`string`) string representation 210 | 211 | -------------------------------------------------------------------------------- /docs/list/Queue.md: -------------------------------------------------------------------------------- 1 | # Queue 2 | Queue implementation. 3 | This is a subclass of `List`. 4 | 5 | 6 | 7 | 8 | ## Queue:enqueue(val) 9 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Queue.lua#L13) 10 | 11 | Adds a value to the queue. 12 | 13 | Arguments: 14 | 15 | - `val ` (`any`): value to add. 
16 | 17 | Returns: 18 | 19 | - (`Queue`) modified queue 20 | 21 | ## Queue:dequeue() 22 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Queue.lua#L20) 23 | 24 | Returns and removes the first value in the queue. 25 | 26 | Returns: 27 | 28 | - (`any`) removed value 29 | 30 | -------------------------------------------------------------------------------- /docs/list/Stack.md: -------------------------------------------------------------------------------- 1 | # Stack 2 | Stack implementation. 3 | This is a subclass of `List`. 4 | 5 | 6 | 7 | 8 | ## Stack:push(val) 9 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Stack.lua#L10) 10 | 11 | Adds a value to the stack. 12 | 13 | Arguments: 14 | 15 | - `val ` (`any`): value to add. 16 | 17 | Returns: 18 | 19 | - (`Stack`) modified stack 20 | 21 | ## Stack:pop() 22 | [View source](http://github.com/vzhong/torchlib/blob/master/src//list/Stack.lua#L17) 23 | 24 | Returns and removes the value at the top of the stack. 25 | 26 | Returns: 27 | 28 | - (`any`) removed value 29 | 30 | -------------------------------------------------------------------------------- /docs/map/Counter.md: -------------------------------------------------------------------------------- 1 | # Counter 2 | Implementation of a counter. 3 | 4 | 5 | 6 | 7 | ## Counter:\_\_init() 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Counter.lua#L8) 9 | 10 | Constructor. 11 | 12 | 13 | ## Counter:add(key, count) 14 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Counter.lua#L16) 15 | 16 | Increments the count for a key. 17 | 18 | Arguments: 19 | 20 | - `key ` (`any`): key to increment count for. 21 | - `count ` (`int`): how much to increment count by. 
22 | 23 | Returns: 24 | 25 | - (`int`) the new count 26 | 27 | ## Counter:get(key) 28 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Counter.lua#L27) 29 | 30 | 31 | 32 | Arguments: 33 | 34 | - `key ` (`any`): key to return count for. 35 | 36 | Returns: 37 | 38 | - (`int`) the count for the key 39 | 40 | If `key` has not been added to the counter, then returns 0. 41 | 42 | ## Counter:reset() 43 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Counter.lua#L33) 44 | 45 | Clears the counter. 46 | 47 | Returns: 48 | 49 | - (`Counter`) the modified counter 50 | 51 | -------------------------------------------------------------------------------- /docs/map/HashMap.md: -------------------------------------------------------------------------------- 1 | # HashMap 2 | Implementation of hash map. 3 | This is a subclass of `Map` 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /docs/map/Map.md: -------------------------------------------------------------------------------- 1 | # Map 2 | Abstract map implementation. 3 | 4 | 5 | 6 | 7 | ## Map:\_\_init(key\_values) 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L9) 9 | 10 | Constructor. 11 | 12 | Arguments: 13 | 14 | - `key_values ` (`table[any:any]`): used to initialize the map. Optional. 15 | 16 | 17 | ## Map:add(key, val) 18 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L16) 19 | 20 | Adds an entry to the map. 21 | 22 | Arguments: 23 | 24 | - `key ` (`any`): key to add. 25 | - `value ` (`any`): value to add. 26 | 27 | 28 | ## Map:addMany(tab) 29 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L22) 30 | 31 | Adds many entries to the map. 32 | 33 | Arguments: 34 | 35 | - `tab ` (`table[any:any]`): a map of key value pairs to add. Optional. 
36 | 37 | 38 | ## Map:copy() 39 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L27) 40 | 41 | 42 | 43 | Returns: 44 | 45 | - (`Map`) copy of this map 46 | 47 | ## Map:contains(key) 48 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L33) 49 | 50 | 51 | 52 | Arguments: 53 | 54 | - `key ` (`any`): key to check. 55 | 56 | Returns: 57 | 58 | - (`boolean`) whether the map contains the key 59 | 60 | ## Map:get(key, returnNilIfMissing) 61 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L44) 62 | 63 | Retrieves the value for a key. 64 | 65 | Arguments: 66 | 67 | - `key ` (`any`): key to retrieve. 68 | - `returnNilIfMissing ` (`boolean`): whether to tolerate missing keys. Optional. 69 | 70 | Returns: 71 | 72 | - (`any`) value corresponding to the key 73 | 74 | By default, asserts error if `key` is not found. If `returnNilIfMissing` is true, 75 | 76 | then a `nil` will be returned if `key` is not found. 77 | 78 | ## Map:remove(key) 79 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L53) 80 | 81 | Removes a key value pair 82 | 83 | Arguments: 84 | 85 | - `key ` (`any`): key to remove. 86 | 87 | Returns: 88 | 89 | - (`any`) the removed value 90 | 91 | Asserts error if `key` is not in the map. 92 | 93 | ## Map:keys() 94 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L58) 95 | 96 | 97 | 98 | Returns: 99 | 100 | - (`table[any]`) a table of the keys in the map 101 | 102 | ## Map:totable() 103 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L63) 104 | 105 | 106 | 107 | Returns: 108 | 109 | - (`table[any:any]`) the map in table form 110 | 111 | ## Map:equals(another) 112 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L71) 113 | 114 | 115 | 116 | Arguments: 117 | 118 | - `another ` (`Map`): another map to compare to.
119 | 120 | Returns: 121 | 122 | - (`boolean`) whether this map equals `another`. 123 | 124 | Maps are considered equal if all keys and corresponding values match. 125 | 126 | ## Map:size() 127 | [View source](http://github.com/vzhong/torchlib/blob/master/src//map/Map.lua#L76) 128 | 129 | 130 | 131 | Returns: 132 | 133 | - (`int`) number of key value pairs in the map 134 | 135 | -------------------------------------------------------------------------------- /docs/ml/Dataset.md: -------------------------------------------------------------------------------- 1 | # Dataset 2 | Implementation of dataset container. 3 | The goal of this class is to provide utilities for manipulating generic datasets. in particular, a 4 | dataset can be a list of examples, each with a fixed set of fields. 5 | 6 | 7 | 8 | 9 | ## Dataset:\_\_init(fields) 10 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L30) 11 | 12 | Constructor. 13 | 14 | Arguments: 15 | 16 | - `fields ` (`table[any:any]`): a table containing key value pairs 17 | 18 | Each value is a list of tensors and `value[i]` contains the value corresponding to the `i`th example. 19 | 20 | Example:. 21 | 22 | Suppose we have two examples, with fields `X` and `Y`. The first example has `X=[1, 2, 3], Y=1` while 23 | 24 | the second example has `X=[4, 5, 6, 7, 8}, Y=4`. To create a dataset: 25 | 26 | ```lua 27 | X = {torch.Tensor{1, 2, 3}, torch.Tensor{4, 5, 6, 7, 8}} 28 | Y = {1, 4} 29 | d = Dataset{X = X, Y = Y} 30 | 31 | Of course, in practice the fields can be arbitrary, so long as each field is a table and has an equal 32 | number of elements. 33 | ``` 34 | 35 | ## Dataset.from\_conll(fname) 36 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L69) 37 | 38 | Creates a dataset from CONLL format. 39 | 40 | Arguments: 41 | 42 | - `fname ` (`string`): path to CONLL file. 
43 | 44 | Returns: 45 | 46 | - (`Dataset`) loaded dataset 47 | 48 | The format is as follows: 49 | 50 | ```text 51 | # word subj subj_ner obj obj_ner stanford_pos stanford_ner stanford_dep_edge stanford_dep_governor 52 | per:city_of_birth 53 | - - - - - : O punct 1 54 | 20 - - - - CD DATE ROOT -1 55 | : - - - - : O punct 1 56 | Alexander SUBJECT PERSON - - NNP PERSON compound 4 57 | Haig SUBJECT PERSON - - NNP PERSON dep 1 58 | , - - - - , O punct 4 59 | US - - - - NNP LOCATION compound 7 60 | secretary - - - - NN O appos 4 61 | 62 | That is, the first line is a tab delimited header, followed by examples separated by a blank line. 63 | The first line of the example is the class label. The rest of the rows correspond to tokens and their associated attributes. 64 | 65 | Example: 66 | ``` 67 | 68 | ```lua 69 | dataset = Dataset.from_conll('data.conll') 70 | ``` 71 | 72 | ## Dataset:\_\_tostring\_\_() 73 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L112) 74 | 75 | 76 | 77 | Returns: 78 | 79 | - (`string`) string representation 80 | 81 | ## Dataset:size() 82 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L126) 83 | 84 | 85 | 86 | Returns: 87 | 88 | - (`int`) number of examples in the dataset 89 | 90 | ## Dataset:kfolds(k) 91 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L137) 92 | 93 | Returns a table of `k` folds of the dataset. 94 | 95 | Arguments: 96 | 97 | - `k ` (`int`): how many folds to create. 98 | 99 | Returns: 100 | 101 | - (`table[table]`) tables of indices corresponding to each fold 102 | 103 | Each fold consists of a random table of indices corresponding to the examples in the fold. 104 | 105 | ## Dataset:view(...) 106 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L155) 107 | 108 | Copies out a new Dataset which is a view into the current Dataset. 
109 | 110 | Arguments: 111 | 112 | - `vararg ` (`vararg`): each argument is a table of integer indices corresponding to a view. 113 | 114 | Returns: 115 | 116 | - (`vararg(Datasets)`) one dataset view for each list of indices 117 | 118 | Example: 119 | 120 | Suppose we already have a `dataset` and would like to split it into two datasets. We want 121 | the first dataset `a` to contain examples 1 and 3 of the original dataset. We want the 122 | second dataset `b` to contain examples 1, 2 and 3 (yes, duplicates are supported). 123 | 124 | ```lua 125 | a, b = dataset:view({1, 3}, {1, 2, 3}) 126 | ``` 127 | 128 | ## Dataset:train\_dev\_split(train\_indices) 129 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L181) 130 | 131 | Creates a train split and a test split given the train indices. 132 | 133 | Arguments: 134 | 135 | - `train_indices ` (`table[int]`): a table of integers corresponding to indices of training examples. 136 | 137 | Returns: 138 | 139 | - (`Dataset, Dataset`) train and test dataset views 140 | 141 | Other examples will be used as test examples. 142 | 143 | Example: 144 | 145 | Suppose we'd like to split a `dataset` and use its 1, 2, 4 and 5th examples for training. 146 | 147 | ```lua 148 | train, test = dataset:train_dev_split{1, 2, 4, 5} 149 | ``` 150 | 151 | ## Dataset:index(indices) 152 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L198) 153 | 154 | Reindexes the dataset according to the new indices. 155 | 156 | Arguments: 157 | 158 | - `indices ` (`table[int]`): indices to re-index the dataset with. 159 | 160 | Returns: 161 | 162 | - (`Dataset`) modified dataset 163 | 164 | Example: 165 | 166 | Suppose we have a `dataset` of 5 examples and want to swap example 1 with example 5.
167 | 168 | ```lua 169 | dataset:index{5, 2, 3, 4, 1} 170 | ``` 171 | 172 | ## Dataset:shuffle() 173 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L211) 174 | 175 | Shuffles the dataset in place 176 | 177 | Returns: 178 | 179 | - (`Dataset`) modified dataset 180 | 181 | ## Dataset:sort\_by\_length(field) 182 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L221) 183 | 184 | Sorts the examples in place by the length of the requested field. 185 | 186 | Arguments: 187 | 188 | - `field ` (`string`): field to sort with. 189 | 190 | Returns: 191 | 192 | - (`Dataset`) modified dataset 193 | 194 | It is assumed that the field contains torch Tensors. Sorts in ascending order. 195 | 196 | ## Dataset.pad(tensors, PAD) 197 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L244) 198 | 199 | Prepends shorter tensors in a table of tensors with `PAD` such that each tensor in the batch are of the same length. 200 | 201 | Arguments: 202 | 203 | - `tensors ` (`table[torch.Tensor]`): tensors of varying lengths. 204 | - `PAD ` (`int`): index to pad missing elements with. 205 | 206 | Example:. Optional, Default: `0`. 207 | 208 | ```lua 209 | X = {torch.Tensor{1, 2, 3}, torch.Tensor{4}} 210 | Y = Dataset.pad(X, 0) 211 | 212 | `Y` is now: 213 | ``` 214 | 215 | ```lua 216 | torch.Tensor{{1, 2, 3}, {0, 0, 4}} 217 | ``` 218 | 219 | ## Dataset:batches(batch\_size) 220 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L268) 221 | 222 | Creates a batch iterator over the dataset. 223 | 224 | Arguments: 225 | 226 | - `batch_size ` (`int`): maximum size of each batch 227 | 228 | Example:. 
229 | 230 | ```lua 231 | d = Dataset{X=X, Y=Y} 232 | for batch, batch_end in d:batches(5) do 233 | print(batch.X) 234 | print(batch.Y) 235 | end 236 | ``` 237 | 238 | ## Dataset:transform(transforms, in\_place) 239 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Dataset.lua#L306) 240 | 241 | Applies transformations to fields in the dataset. 242 | 243 | Arguments: 244 | 245 | - `transforms ` (`table[string:function]`): a key-value map where a key is a field in the dataset and the corresponding value 246 | is a function that is to be applied to the requested field for each example. 247 | - `in_place ` (`boolean`): whether to apply the transformation in place or return a new dataset. Optional. 248 | 249 | Example: 250 | 251 | ```lua 252 | dataset = Dataset{names={'alice', 'bob', 'charlie'}, id={1, 2, 3}} 253 | dataset2 = dataset:transform{names=string.upper, id=function(x) return x+1 end} 254 | ``` 255 | 256 | `dataset2` is now `Dataset{names={'ALICE', 'BOB', 'CHARLIE'}, id={2, 3, 4}}` while `dataset` remains unchanged. 257 | 258 | ```lua 259 | dataset = Dataset{names={'alice', 'bob', 'charlie'}, id={1, 2, 3}} 260 | dataset2 = dataset:transform({names=string.upper}, true) 261 | ``` 262 | 263 | `dataset` is now `Dataset{names={'ALICE', 'BOB', 'CHARLIE'}, id={1, 2, 3}}` and `dataset2` refers to `dataset`. 264 | 265 | -------------------------------------------------------------------------------- /docs/ml/Experiment.md: -------------------------------------------------------------------------------- 1 | # Experiment 2 | Experiment container that is backed up to a Postgres instance. 3 | 4 | Example: 5 | 6 | Suppose we have already made a postgres database called `myexp`. 
7 | 8 | 9 | 10 | ```lua 11 | local c = Experiment.new('myexp') 12 | local run = c:create_run{dataset='foobar', lr=1.0, n_hid=10} 13 | print(run:info()) 14 | run:submit_scores(1, {macro={f1=0.53, precision=0.52, recall=0.54}, micro={f1=0.10, precision=0.10, recall=0.10}}) 15 | run:submit_scores(2, {macro={f1=0.55, precision=0.55, recall=0.55}, micro={f1=0.10, precision=0.10, recall=0.10}}) 16 | print(run:scores()) 17 | 18 | run:submit_prediction(1, 'person', 'thing', {dataset='foobar'}) 19 | run:submit_prediction(2, 'person', 'person', {dataset='foobar'}) 20 | run:submit_prediction(3, 'person', 'thing', {dataset='foobar'}) 21 | run:submit_prediction(4, 'thing', 'thing', {dataset='foobar'}) 22 | print(run:predictions()) 23 | ``` 24 | 25 | ## Experiment:\_\_init(name, username, password, hostname, port) 26 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L38) 27 | 28 | Constructor. 29 | 30 | Arguments: 31 | 32 | - `name ` (`string`): name of the experiment. 33 | - ` ` (`username`): username for postgres. Optional. 34 | - ` ` (`hostname`): hostname for postgres. Optional. 35 | - ` ` (`port`): port for postgres 36 | 37 | It is assumed that a database with this name also exists and the user has permission to connect to it. 38 | For more information on the parameters for postgres, see: 39 | 40 | http://keplerproject.github.io/luasql/manual.html#postgres_extensions. Optional. 41 | 42 | 43 | ## Experiment:setup() 44 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L61) 45 | 46 | Creates relevant tables. 47 | 48 | 49 | ## Experiment:delete() 50 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L90) 51 | 52 | Drops tables for this experiment. 53 | 54 | 55 | ## Experiment:query(query, iterator) 56 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L119) 57 | 58 | Submits a query to the database and returns the result. 
59 | 60 | Arguments: 61 | 62 | - `query ` (`string`): query to run. 63 | - `iterator ` (`boolean`): whether to return the result as an iterator or as a table. Optional. 64 | 65 | Returns: 66 | 67 | - (`conditional`) if `iterator` is `true`, then an iterator will be returned. Otherwise, a table will be returned 68 | 69 | If the result is not a number, then results will be returned. If the result is a number, then no result will 70 | be returned. 71 | 72 | If the query fails, then an error will be raised. 73 | 74 | ## Experiment:create\_run(opt) 75 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L163) 76 | 77 | Creates a new run for the experiment. 78 | 79 | Arguments: 80 | 81 | - `opt ` (`table[any:any]`): options for the run. 82 | 83 | Returns: 84 | 85 | - (`table`) a Run object 86 | 87 | The Run object returned has the following functions: 88 | 89 | - `run.id`: the id for this run 90 | 91 | - `run:info()`: retrieves the row in the runs table. 92 | 93 | - `run:scores()`: retrieves the scores in the scores table. 94 | 95 | - `run:submit_scores(epoch, scores)`: submits scores. 96 | 97 | - `run:predictions(example_id)`: retrieves the predictions. 98 | 99 | - `run:submit_prediction(example_id, pred, gold, info)`: submits a single prediction. 100 | 101 | These are merely convenience functions mapping to corresponding methods in `Experiment`. 102 | 103 | They are convenient in the sense that one does not have to memorize the `id` of the Run to use them. 104 | 105 | ## Experiment:get\_run\_info(id) 106 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L181) 107 | 108 | Retrieves the options for the requested run. 109 | 110 | Arguments: 111 | 112 | - `id ` (`int`): ID of the run to retrieve. 113 | 114 | 115 | ## Experiment:submit\_scores(run\_id, epoch, scores) 116 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L191) 117 | 118 | Submits a score for the run.
119 | 120 | Arguments: 121 | 122 | - `run_id ` (`int`): ID of the run to submit scores for. 123 | - `epoch ` (`int`): epoch of the score. 124 | - `scores ` (`table[string:number]`): scores to submit. 125 | 126 | 127 | ## Experiment:get\_run\_scores(id) 128 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L203) 129 | 130 | Retrieves the scores for the requested run. 131 | 132 | Arguments: 133 | 134 | - `id ` (`int`): ID of the run to retrieve. 135 | 136 | 137 | ## Experiment:submit\_prediction(run\_id, example\_id, pred, gold, info) 138 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L214) 139 | 140 | Submits the prediction for a single example for a run. 141 | 142 | Arguments: 143 | 144 | - `run_id ` (`int`): ID for the run. 145 | - `example_id ` (`int`): ID for the example. 146 | - `pred ` (`any`): prediction. 147 | - `gold ` (`any`): ground truth. Optional. 148 | - `info ` (`table`): information for the run. Optional. 149 | 150 | 151 | ## Experiment:get\_predictions(run\_id, example\_id) 152 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Experiment.lua#L237) 153 | 154 | Retrieves the predictions for a run. 155 | 156 | Arguments: 157 | 158 | - `run_id ` (`int`): ID for the run. 159 | - `example_id ` (`int`): ID for the example. If this is given then only the prediction for this example is returned. Optional. 160 | 161 | 162 | -------------------------------------------------------------------------------- /docs/ml/GloveVocab.md: -------------------------------------------------------------------------------- 1 | # GloveVocab 2 | Vocab object prepopulated with Glove embeddings by Pennington, Socher, and Manning. 3 | This is a subclass of `Vocab`. 4 | For details, see: 5 | 6 | 7 | 8 | http://nlp.stanford.edu/projects/glove/. 9 | 10 | This only supports the 50-d wikipedia/Giga-word version.
11 | 12 | The download is from: 13 | 14 | https://dl.dropboxusercontent.com/u/9015381/datasets/torchnlp/glove.6B.50d.t7 15 | 16 | ## GloveVocab:load\_words() 17 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/GloveVocab.lua#L18) 18 | 19 | Retrieves the word list and populates the vocabulary. 20 | 21 | 22 | ## GloveVocab:embeddings() 23 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/GloveVocab.lua#L29) 24 | 25 | 26 | 27 | Returns: 28 | 29 | - (`torch.Tensor`) pretrained embeddings for words in the vocabulary 30 | 31 | ## GloveVocab:\_\_tostring\_\_() 32 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/GloveVocab.lua#L47) 33 | 34 | 35 | 36 | Returns: 37 | 38 | - (`string`) string representation 39 | 40 | -------------------------------------------------------------------------------- /docs/ml/ProbTable.md: -------------------------------------------------------------------------------- 1 | # ProbTable 2 | Implementation of probability table using Torch tensor 3 | 4 | 5 | 6 | 7 | ## ProbTable:\_\_init(P, names) 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/ProbTable.lua#L17) 9 | 10 | Constructor. 11 | 12 | Arguments: 13 | 14 | - `P ` (`torch.tensor`): probability Tensor, the `i`th dimension corresponds to the `i`th variable. 15 | - `names ` (`table[string]`): A table of names for the variables. By default these will be assigned using indices. Optional. 16 | 17 | Example:
18 | 19 | ```lua 20 | local t = ProbTable(torch.Tensor{{0.2, 0.8}, {0.4, 0.6}, {0.1, 0.9}}, {'a', 'b'}) 21 | t:query{a=1, b=2} 0.8 22 | t:query{a=2} Tensor{0.4, 0.6} 23 | ``` 24 | 25 | ## ProbTable:size() 26 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/ProbTable.lua#L35) 27 | 28 | 29 | 30 | Returns: 31 | 32 | - (`int`) number of variables in the table 33 | 34 | ## ProbTable:query(dict) 35 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/ProbTable.lua#L50) 36 | 37 | 38 | 39 | Arguments: 40 | 41 | - `dict ` (`table[string:int]`): an assignment to consider 42 | 43 | Example: 44 | 45 | Returns: 46 | 47 | - (`torch.Tensor`) probabilities for the assignments in `dict`. 48 | 49 | ```lua 50 | local t = ProbTable(torch.Tensor{{0.2, 0.8}, {0.4, 0.6}, {0.1, 0.9}}, {'a', 'b'}) 51 | t:query{a=1, b=2} 52 | t:query{a=2} 53 | ``` 54 | 55 | The first query is `0.8`. The second query is `Tensor{0.4, 0.6}` 56 | 57 | ## ProbTable:clone() 58 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/ProbTable.lua#L63) 59 | 60 | 61 | 62 | Returns: 63 | 64 | - (`ProbTable`) a copy 65 | 66 | ## ProbTable:\_\_tostring\_\_() 67 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/ProbTable.lua#L70) 68 | 69 | 70 | 71 | Returns: 72 | 73 | - (`string`) string representation 74 | 75 | ## ProbTable:mul(B) 76 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/ProbTable.lua#L94) 77 | 78 | Returns a new table that is the product of two tables. 79 | 80 | Arguments: 81 | 82 | - `B ` (`ProbTable`): another table. 83 | 84 | Returns: 85 | 86 | - (`ProbTable`) product of this and another table 87 | 88 | ## ProbTable:marginalize(name) 89 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/ProbTable.lua#L143) 90 | 91 | Marginalizes this probability table in place. 92 | 93 | Arguments: 94 | 95 | - `name ` (`string`): the variable to marginalize.
96 | 97 | Returns: 98 | 99 | - (`ProbTable`) this probability table with the variable `name` marginalized out 100 | 101 | ## ProbTable:marginal(name) 102 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/ProbTable.lua#L160) 103 | 104 | Marginalizes this probability table in place to calculate a marginal. 105 | 106 | Arguments: 107 | 108 | - `name ` (`string`): the variable to calculate. 109 | 110 | Returns: 111 | 112 | - (`ProbTable`) this probability table marginalizing all variables except `name` 113 | 114 | ## ProbTable:normalize() 115 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/ProbTable.lua#L175) 116 | 117 | Normalizes this table by dividing by the sum of all probabilities. 118 | 119 | Returns: 120 | 121 | - (`ProbTable`) normalized table 122 | 123 | -------------------------------------------------------------------------------- /docs/ml/Scorer.md: -------------------------------------------------------------------------------- 1 | # Scorer 2 | Implementation of a scorer to calculate precision/recall/f1. 3 | 4 | 5 | 6 | 7 | ## Scorer:\_\_init(gold\_log, pred\_log) 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Scorer.lua#L11) 9 | 10 | Constructor. 11 | 12 | Arguments: 13 | 14 | - `gold_log ` (`string`): if given, gold labels will be written to this file. Optional. 15 | - `pred_log ` (`string`): if given, predicted labels will be written to this file. 16 | 17 | 18 | ## Scorer:add\_pred(gold, pred, id) 19 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Scorer.lua#L27) 20 | 21 | Adds a prediction/ground truth pair to the scorer. 22 | 23 | Arguments: 24 | 25 | - `gold ` (`string`): ground truth label. 26 | - `pred ` (`string`): corresponding predicted label. 27 | - `id ` (`string`): corresponding identifier for this example. Optional. 28 | 29 | If the scorer was given the gold log and the pred log, then the pair will be written to their respective log files.
30 | 31 | 32 | ## Scorer:reset() 33 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Scorer.lua#L45) 34 | 35 | Removes all remembered statistics from the scorer. 36 | 37 | 38 | ## Scorer:precision\_recall\_f1(ignore) 39 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Scorer.lua#L77) 40 | 41 | Computes the precision/recall/f1 statistics for the current batch of elements. 42 | 43 | Arguments: 44 | 45 | - `ignore ` (`string`): if given, `ignore` is taken to be the "negative" class and its statistics will be withheld 46 | from the computation. Optional. 47 | 48 | Returns: 49 | 50 | - (`table, table, table`) micro, macro, and class scores 51 | 52 | Example: 53 | 54 | ```lua 55 | local s = Scorer() 56 | s:add_pred('a', 'b', 1) 57 | s:add_pred('b', 'b', 2) 58 | s:add_pred('c', 'a', 3) 59 | local micro, macro, all_stats = s:precision_recall_f1(ignore) 60 | ``` 61 | 62 | Returns the following 63 | 64 | - `micro`: the micro averaged precision/recall/f1 statistics 65 | 66 | - `macro`: the macro averaged precision/recall/f1 statistics 67 | 68 | - `class_stats`: the precision/recall/f1 for each class 69 | 70 | -------------------------------------------------------------------------------- /docs/ml/VariableTensor.md: -------------------------------------------------------------------------------- 1 | # VariableTensor 2 | Implementation of a variable tensor class to efficiently store tensors of varying lengths. 3 | 4 | 5 | 6 | 7 | ## VariableTensor:\_\_init(opt) 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/VariableTensor.lua#L12) 9 | 10 | Constructor. 11 | 12 | Arguments: 13 | 14 | - `preinit_size ` (`int`): how many indices to preallocate for. Optional, Default: `1`. 15 | - `preinit_store_size ` (`int`): how many elements to preallocate for. Optional, Default: `1`. 
16 | 17 | 18 | ## VariableTensor:cuda() 19 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/VariableTensor.lua#L30) 20 | 21 | Moves the storage to cuda 22 | 23 | Returns: 24 | 25 | - (`VariableTensor`) modified tensor 26 | 27 | ## VariableTensor:size() 28 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/VariableTensor.lua#L36) 29 | 30 | 31 | 32 | Returns: 33 | 34 | - (`int`) sum of the size of each tensor in the storage 35 | 36 | ## VariableTensor:push(tensor) 37 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/VariableTensor.lua#L43) 38 | 39 | Appends a tensor to the storage. 40 | 41 | Arguments: 42 | 43 | - `tensor ` (`torch.Tensor`): tensor to add to storage. 44 | 45 | Returns: 46 | 47 | - (`VariableTensor`) modified tensor 48 | 49 | ## VariableTensor:shuffle(indices) 50 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/VariableTensor.lua#L62) 51 | 52 | Shuffles the indices. 53 | 54 | Arguments: 55 | 56 | - `indices ` (`torch.Tensor`): tensor that denotes how the new indices should be set. If not given, then a random 57 | tensor will be generated. Optional. 58 | 59 | Returns: 60 | 61 | - (`torch.Tensor`) the `indices` tensor used to shuffle 62 | 63 | ## VariableTensor:get(i) 64 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/VariableTensor.lua#L71) 65 | 66 | Retrieves the tensor at index `i`. 67 | 68 | Arguments: 69 | 70 | - `i ` (`int`): index to query. 71 | 72 | Returns: 73 | 74 | - (`torch.Tensor`) tensor at index 75 | 76 | ## VariableTensor:batch(indices, pad) 77 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/VariableTensor.lua#L78) 78 | 79 | Creates a zero-padded batch from tensors at the indices `indices`. 80 | 81 | Arguments: 82 | 83 | - `indices ` (`table`): starting indices of tensors to pad. 84 | - `pad ` (`int`): number to use to pad shorter tensors. Optional, Default: `0`. 
85 | 86 | 87 | -------------------------------------------------------------------------------- /docs/ml/Vocab.md: -------------------------------------------------------------------------------- 1 | # Vocab 2 | Implementation of vocabulary 3 | 4 | 5 | 6 | 7 | ## Vocab:\_\_init(unk) 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L9) 9 | 10 | Constructor. 11 | 12 | Arguments: 13 | 14 | - `unk ` (`string`): the symbol for the unknown token. Optional, Default: `'UNK'`. 15 | 16 | 17 | ## Vocab:\_\_tostring\_\_() 18 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L22) 19 | 20 | 21 | 22 | Returns: 23 | 24 | - (`string`) string representation 25 | 26 | ## Vocab:contains(word) 27 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L28) 28 | 29 | 30 | 31 | Arguments: 32 | 33 | - `word ` (`string`): word to query. 34 | 35 | Returns: 36 | 37 | - (`boolean`) whether `word` is in the vocabulary 38 | 39 | ## Vocab:count(word) 40 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L34) 41 | 42 | 43 | 44 | Arguments: 45 | 46 | - `word ` (`string`): word to query. 47 | 48 | Returns: 49 | 50 | - (`int`) count for `word` seen during training 51 | 52 | ## Vocab:size() 53 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L40) 54 | 55 | 56 | 57 | Returns: 58 | 59 | - (`int`) how many distinct tokens are in the vocabulary 60 | 61 | ## Vocab:add(word, count) 62 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L48) 63 | 64 | Adds `word` `count` time to the vocabulary. 65 | 66 | Arguments: 67 | 68 | - `word ` (`string`): word to add. 69 | - `count ` (`int`): number of times to add. Optional, Default: `1`. 
70 | 71 | Returns: 72 | 73 | - (`int`) index of `word` 74 | 75 | ## Vocab:indexOf(word, add) 76 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L79) 77 | 78 | 79 | 80 | Arguments: 81 | 82 | - `word ` (`string`): word to query. 83 | - `add ` (`boolean`): whether to add new word to the vocabulary. Optional. 84 | 85 | If the word is not found, then one of the following occurs: 86 | 87 | - if `add` is `true`, then `word` is added to the vocabulary with count 1 and the new index returned 88 | 89 | - otherwise, the index of the unknown token is returned 90 | 91 | Example: 92 | 93 | Suppose we have a vocabulary of words 'unk', 'foo' and 'bar'. 94 | 95 | Returns: 96 | 97 | - (`int`) index of `word`. 98 | 99 | ```lua 100 | vocab:indexOf('foo') returns 2 101 | vocab:indexOf('bar') returns 3 102 | vocab:indexOf('hello') returns 1 corresponding to `unk` because `hello` is not in the vocabulary 103 | vocab:indexOf('hello', true) returns 4 because `hello` is added to the vocabulary 104 | ``` 105 | 106 | ## Vocab:wordAt(index) 107 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L105) 108 | 109 | 110 | 111 | Arguments: 112 | 113 | - `index ` (`int`): the index to query 114 | 115 | If `index` is out of bounds then an error will be raised. 116 | 117 | Example: 118 | 119 | Suppose we have a vocabulary with words 'unk', 'foo', and 'bar'. 120 | 121 | Returns: 122 | 123 | - (`string`) word at index `index` 124 | 125 | ```lua 126 | vocab:wordAt(1) unk 127 | vocab:wordAt(2) foo 128 | vocab:wordAt(4) raises an error because there is no 4th word in the vocabulary 129 | ``` 130 | 131 | ## Vocab:indicesOf(words, add) 132 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L122) 133 | 134 | `indexOf` on a table of words. 135 | 136 | Arguments: 137 | 138 | - `words ` (`table[string]`): words to query. 139 | - `add ` (`boolean`): whether to add new words to the vocabulary. Optional.
140 | 141 | Returns: 142 | 143 | - (`table[int]`) corresponding indices. 144 | 145 | Example: 146 | 147 | Suppose we have a vocabulary with words 'unk', 'foo', and 'bar' 148 | 149 | ```lua 150 | vocab:indicesOf{'foo', 'bar'} {2, 3} 151 | ``` 152 | 153 | ## Vocab:tensorIndicesOf(words, add) 154 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L144) 155 | 156 | `indexOf` on a table of words. 157 | 158 | Arguments: 159 | 160 | - `add ` (`boolean`): whether to add new words to the vocabulary. Optional. 161 | 162 | Returns: 163 | 164 | - (`torch.Tensor`) tensor of corresponding indices 165 | 166 | Example: 167 | 168 | Suppose we have a vocabulary with words 'unk', 'foo', and 'bar' 169 | 170 | {table[string]} words - words to query 171 | 172 | ```lua 173 | vocab:tensorIndicesOf{'foo', 'bar'} torch.Tensor{2, 3} 174 | vocab:tensorIndicesOf{'foo', 'hi'} torch.Tensor{2, 1}, because `hi` is not in the vocabulary 175 | ``` 176 | 177 | ## Vocab:wordsAt(indices) 178 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L165) 179 | 180 | `wordAt` on a table of indices. 181 | 182 | Arguments: 183 | 184 | - `indices ` (`table[int]`): indices to query. 185 | 186 | Returns: 187 | 188 | - (`table[string]`) corresponding words 189 | 190 | Example: 191 | 192 | Suppose we have a vocabulary with words 'unk', 'foo', and 'bar' 193 | 194 | ```lua 195 | vocab:wordsAt{1, 3} {'unk', 'bar'} 196 | vocab:wordsAt{1, 4} raises an error because there is no 4th word 197 | ``` 198 | 199 | ## Vocab:tensorWordsAt(indices) 200 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L182) 201 | 202 | `wordAt` on a tensor of indices. Returns a table of corresponding words. 
203 | 204 | Example: 205 | 206 | Suppose we have a vocabulary with words 'unk', 'foo', and 'bar' 207 | 208 | ```lua 209 | vocab:tensorWordsAt(torch.Tensor{1, 3}) {'unk', 'bar'} 210 | vocab:tensorWordsAt(torch.Tensor{1, 4}) raises an error because there is no 4th word 211 | ``` 212 | 213 | ## Vocab:copyAndPruneRares(cutoff) 214 | [View source](http://github.com/vzhong/torchlib/blob/master/src//ml/Vocab.lua#L201) 215 | 216 | Returns a new vocabulary with words occurring less than `cutoff` times removed. 217 | 218 | Arguments: 219 | 220 | - `cutoff ` (`int`): words with frequency below this number will be removed from the vocabulary. 221 | 222 | Returns: 223 | 224 | - (`Vocab`) modified vocabulary 225 | 226 | Example: 227 | 228 | Suppose we want to forget all words that occurred less than 5 times: 229 | 230 | ```lua 231 | smaller_vocab = orig_vocab:copyAndPruneRares(5) 232 | ``` 233 | 234 | -------------------------------------------------------------------------------- /docs/set/Set.md: -------------------------------------------------------------------------------- 1 | # Set 2 | Implementation of set. 3 | 4 | 5 | 6 | 7 | ## Set:\_\_init(values) 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L9) 9 | 10 | Constructor. 11 | 12 | Arguments: 13 | 14 | - `values ` (`table[any]`): used to initialize the set. Optional. 15 | 16 | 17 | ## Set.keyOf(val) 18 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L18) 19 | 20 | 21 | 22 | Arguments: 23 | 24 | - `val ` (`any`): value to produce a key for. 25 | 26 | Returns: 27 | 28 | - (`torch.pointer`) unique key for the value 29 | 30 | ## Set:size() 31 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L27) 32 | 33 | 34 | 35 | Returns: 36 | 37 | - (`int`) number of values in the set 38 | 39 | ## Set:add(val) 40 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L34) 41 | 42 | Adds a value to the set. 
43 | 44 | Arguments: 45 | 46 | - `val ` (`any`): value to add to the set. 47 | 48 | Returns: 49 | 50 | - (`Set`) modified set 51 | 52 | ## Set:addMany(...) 53 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L46) 54 | 55 | Adds a variable number of values to the set. 56 | 57 | Arguments: 58 | 59 | - `vararg ` (`vararg`): values to add to the set. 60 | 61 | Returns: 62 | 63 | - (`Set`) modified set 64 | 65 | ## Set:copy() 66 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L55) 67 | 68 | 69 | 70 | Returns: 71 | 72 | - (`Set`) copy of the set 73 | 74 | ## Set:contains(val) 75 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L61) 76 | 77 | 78 | 79 | Arguments: 80 | 81 | - `val ` (`any`): value to check for. 82 | 83 | Returns: 84 | 85 | - (`boolean`) whether the set contains `val` 86 | 87 | ## Set:remove(val) 88 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L69) 89 | 90 | 91 | 92 | Arguments: 93 | 94 | - `val ` (`any`): value to remove from the set. 95 | 96 | Returns: 97 | 98 | - (`Set`) modified set 99 | If `val` is not found then an error is raised. 100 | 101 | ## Set:totable() 102 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L78) 103 | 104 | 105 | 106 | Returns: 107 | 108 | - (`table`) the set in table format 109 | 110 | ## Set:equals(another) 111 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L89) 112 | 113 | Compares two sets. 114 | 115 | Arguments: 116 | 117 | - `another ` (`Set`): another set. 118 | 119 | Returns: 120 | 121 | - (`boolean`) whether this set and `another` contain the same values 122 | 123 | ## Set:union(another) 124 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L105) 125 | 126 | Computes the union of two sets. 127 | 128 | Arguments: 129 | 130 | - `another ` (`Set`): another set.
131 | 132 | Returns: 133 | 134 | - (`Set`) a set of values that are in this set or in `another` 135 | 136 | ## Set:intersect(another) 137 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L116) 138 | 139 | Computes the intersection of two sets. 140 | 141 | Arguments: 142 | 143 | - `another ` (`Set`): another set. 144 | 145 | Returns: 146 | 147 | - (`Set`) a set of values that are in this set and in `another` 148 | 149 | ## Set:subtract(another) 150 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L129) 151 | 152 | Subtracts another set from this one. 153 | 154 | Arguments: 155 | 156 | - `another ` (`Set`): another set. 157 | 158 | Returns: 159 | 160 | - (`Set`) a set of values that are in this set but not in `another` 161 | 162 | ## Set:\_\_tostring\_\_() 163 | [View source](http://github.com/vzhong/torchlib/blob/master/src//set/Set.lua#L140) 164 | 165 | 166 | 167 | Returns: 168 | 169 | - (`string`) string representation 170 | 171 | -------------------------------------------------------------------------------- /docs/tree/BinarySearchTree.md: -------------------------------------------------------------------------------- 1 | # BinarySearchTree.Node 2 | A node in the binary search tree. 3 | This is a subclass of `BinaryTree.Node`. 4 | 5 | 6 | 7 | 8 | ## BinarySearchTreeNode:search(key) 9 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L13) 10 | 11 | Searches for a key in the BST. 12 | 13 | Arguments: 14 | 15 | - `key ` (`number`): the key to retrieve. 16 | 17 | Returns: 18 | 19 | - (`BinarySearchTree.Node`) the node with the requested key 20 | 21 | ## BinarySearchTreeNode:min() 22 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L28) 23 | 24 | 25 | 26 | Returns: 27 | 28 | - (`int`) the minimum node of the subtree rooted at this node. 
29 | 30 | ## BinarySearchTreeNode:max() 31 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L37) 32 | 33 | 34 | 35 | Returns: 36 | 37 | - (`int`) the maximum node of the subtree rooted at this node. 38 | 39 | ## BinarySearchTreeNode:successor() 40 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L46) 41 | 42 | 43 | 44 | Returns: 45 | 46 | - (`BinarySearchTree.Node`) the node with the smallest key that is larger than this one. 47 | 48 | ## BinarySearchTreeNode:predecessor() 49 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L61) 50 | 51 | 52 | 53 | Returns: 54 | 55 | - (`BinarySearchTree.Node`) the node with the largest key that is smaller than this one. 56 | 57 | # BinarySearchTree 58 | Binary Search Tree. An implementation of `BinaryTree`. 59 | 60 | Example: 61 | 62 | 63 | 64 | ```lua 65 | local t = BinarySearchTree.new() 66 | t:insert(BinarySearchTreeNode.new(12)) 67 | t:insert(BinarySearchTreeNode.new(5)) 68 | t:insert(BinarySearchTreeNode.new(2)) 69 | t:insert(BinarySearchTreeNode.new(9)) 70 | t:insert(BinarySearchTreeNode.new(18)) 71 | t:insert(BinarySearchTreeNode.new(15)) 72 | t:insert(BinarySearchTreeNode.new(13)) 73 | t:insert(BinarySearchTreeNode.new(17)) 74 | t:insert(BinarySearchTreeNode.new(19)) 75 | print(t) 76 | ``` 77 | 78 | ## BinarySearchTree:insert(node) 79 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L97) 80 | 81 | Inserts a node into the tree. 82 | 83 | Arguments: 84 | 85 | - `node ` (`BinarySearchTree.Node`): node to insert. 86 | 87 | Returns: 88 | 89 | - (`BinarySearchTree`) modified tree 90 | 91 | ## BinarySearchTree:search(key) 92 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L118) 93 | 94 | 95 | 96 | Arguments: 97 | 98 | - `key ` (`number`): key to search for.
99 | 100 | Returns: 101 | 102 | - (`BinarySearchTree.Node`) node with the requested key 103 | 104 | ## BinarySearchTree:min() 105 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L123) 106 | 107 | 108 | 109 | Returns: 110 | 111 | - (`BinarySearchTree.Node`) node with the minimum key 112 | 113 | ## BinarySearchTree:max() 114 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L128) 115 | 116 | 117 | 118 | Returns: 119 | 120 | - (`BinarySearchTree.Node`) node with the maximum key 121 | 122 | ## BinarySearchTree:transplant(old, new) 123 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L136) 124 | 125 | Replaces the subtree rooted at `old` with the one rooted at `new`. 126 | 127 | Arguments: 128 | 129 | - `old ` (`BinarySearchTree.Node`): node to replace. 130 | - `new ` (`BinarySearchTree.Node`): new node to use. 131 | 132 | Returns: 133 | 134 | - (`BinarySearchTree`) modified tree 135 | 136 | ## BinarySearchTree:delete(node) 137 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinarySearchTree.lua#L153) 138 | 139 | Deletes a node from the tree. 140 | 141 | Arguments: 142 | 143 | - `node ` (`BinarySearchTree.Node`): node to delete. 144 | 145 | Returns: 146 | 147 | - (`BinarySearchTree`) modified tree 148 | 149 | -------------------------------------------------------------------------------- /docs/tree/BinaryTree.md: -------------------------------------------------------------------------------- 1 | # BinaryTree.Node 2 | Node in a binary tree. 
3 | This is a subclass of `Tree.Node` 4 | 5 | 6 | 7 | 8 | ## BinaryTreeNode:\_\_init(key, val) 9 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinaryTree.lua#L10) 10 | 11 | Constructor 12 | 13 | 14 | ## BinaryTreeNode:children() 15 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinaryTree.lua#L17) 16 | 17 | 18 | 19 | Returns: 20 | 21 | - (`table`) children of this node 22 | 23 | ## BinaryTreeNode:walkInOrder(callback) 24 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinaryTree.lua#L26) 25 | 26 | Traverses the tree in order. 27 | 28 | Arguments: 29 | 30 | - `callback ` (`function`): function to execute at each node. Optional. 31 | 32 | 33 | # BinaryTree 34 | Implementation of binary tree. 35 | This is a subclass of `Tree`. 36 | 37 | 38 | 39 | 40 | ## BinaryTree:\_\_init() 41 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinaryTree.lua#L43) 42 | 43 | Constructor. 44 | 45 | 46 | ## BinaryTree:walkInOrder(callback) 47 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/BinaryTree.lua#L50) 48 | 49 | Traverses the binary tree starting from the root in order 50 | 51 | Arguments: 52 | 53 | - `callback ` (`function`): function to execute at each node. Optional. 54 | 55 | 56 | -------------------------------------------------------------------------------- /docs/tree/Tree.md: -------------------------------------------------------------------------------- 1 | # Tree 2 | Implementation of tree. 3 | 4 | 5 | 6 | 7 | ## TreeNode:\_\_init(key, val) 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/Tree.lua#L9) 9 | 10 | Constructor. 
11 | 12 | 13 | ## TreeNode:children() 14 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/Tree.lua#L18) 15 | 16 | 17 | 18 | Returns: 19 | 20 | - (`table`) children of this node 21 | 22 | ## TreeNode:\_\_tostring\_\_() 23 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/Tree.lua#L23) 24 | 25 | 26 | 27 | Returns: 28 | 29 | - (`string`) string representation 30 | 31 | ## TreeNode:subtreeToString(prefix, isLeaf) 32 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/Tree.lua#L30) 33 | 34 | 35 | 36 | Arguments: 37 | 38 | - `prefix ` (`string`): string to add before each line. 39 | - `isLeaf ` (`boolean`): whether the subtree is a leaf. 40 | 41 | Returns: 42 | 43 | - (`string`) string representation 44 | 45 | ## Tree:\_\_tostring\_\_() 46 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/Tree.lua#L49) 47 | 48 | 49 | 50 | Returns: 51 | 52 | - (`string`) string representation 53 | 54 | ## Tree:size() 55 | [View source](http://github.com/vzhong/torchlib/blob/master/src//tree/Tree.lua#L58) 56 | 57 | 58 | 59 | Returns: 60 | 61 | - (`int`) number of nodes in the tree 62 | 63 | -------------------------------------------------------------------------------- /docs/util/Download.md: -------------------------------------------------------------------------------- 1 | # Downloader 2 | A download utility with caching support. 3 | 4 | 5 | 6 | 7 | ## Downloader:\_\_init(cache, opt) 8 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/Download.lua#L15) 9 | 10 | Constructor. 11 | 12 | Arguments: 13 | 14 | - `cache ` (`string`): cache directory. Optional, Default: `'/tmp/torchlib'`. 15 | 16 | Options: 17 | 18 | - `verbose`: prints out progress 19 | 20 | ## Downloader:get(to, url, opt) 21 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/Download.lua#L33) 22 | 23 | Retrieves a file from cache, downloading it from `url` if it doesn't exists. 
24 | 25 | Arguments: 26 | 27 | - `to ` (`string`): location to download to, relative to the cache directory. 28 | - `url ` (`string`): url to download from. Optional. 29 | - `opt ` (`table[string:any]`): options. Optional. 30 | 31 | Options: 32 | 33 | - `force`: overwrite the file if one exists. 34 | 35 | -------------------------------------------------------------------------------- /docs/util/global.md: -------------------------------------------------------------------------------- 1 | ## tl.range(from, to, inc) 2 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/global.lua#L5) 3 | 4 | 5 | 6 | Arguments: 7 | 8 | - `from ` (`int`): start index. 9 | - `end ` (`int`): end index. Optional, Default: `end`. 10 | - `inc ` (`int`): value to increment by. Optional, Default: `1`. 11 | 12 | Returns: 13 | 14 | - (`table`) indices from `from` to `to`, incrementing by `inc` 15 | 16 | ## tl.equals(a, b) 17 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/global.lua#L22) 18 | 19 | 20 | 21 | Arguments: 22 | 23 | - `a ` (`table`): first object. 24 | - `b ` (`table`): second object. 25 | 26 | Returns: 27 | 28 | - (`boolean`) whether the two objects are equal to each other 29 | 30 | ## tl.deepcopy(t) 31 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/global.lua#L34) 32 | 33 | 34 | 35 | Arguments: 36 | 37 | - `t ` (`any`): object to copy. 38 | 39 | Returns: 40 | 41 | - (`any`) deep copy 42 | 43 | from https://gist.github.com/MihailJP/3931841 44 | 45 | ## tl.copy(t) 46 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/global.lua#L51) 47 | 48 | 49 | 50 | Arguments: 51 | 52 | - `t ` (`any`): object to copy. 
53 | 54 | Returns: 55 | 56 | - (`any`) shallow copy 57 | 58 | -------------------------------------------------------------------------------- /docs/util/string.md: -------------------------------------------------------------------------------- 1 | ## string.startswith(s, substring) 2 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/string.lua#L4) 3 | 4 | 5 | 6 | Arguments: 7 | 8 | - `s ` (`string`): larger string. 9 | - `substring ` (`string`): smaller string. 10 | 11 | Returns: 12 | 13 | - (`boolean`) whether the larger string starts with the smaller string 14 | 15 | ## string.endswith(s, substring) 16 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/string.lua#L11) 17 | 18 | 19 | 20 | Arguments: 21 | 22 | - `s ` (`string`): larger string. 23 | - `substring ` (`string`): smaller string. 24 | 25 | Returns: 26 | 27 | - (`boolean`) whether the larger string ends with the smaller string 28 | 29 | -------------------------------------------------------------------------------- /docs/util/table.md: -------------------------------------------------------------------------------- 1 | ## table.tostring(t, indent, s) 2 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L5) 3 | 4 | 5 | 6 | Arguments: 7 | 8 | - `t ` (`table`): a table. 9 | - `indent ` (`string`): indentation for nested keys. Optional. 10 | - `s ` (`string`): accumulated string. Optional. 11 | 12 | Returns: 13 | 14 | - (`string`) string representation for the table 15 | 16 | ## table.shuffle(t) 17 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L22) 18 | 19 | 20 | 21 | Arguments: 22 | 23 | - `t ` (`table`): table to shuffle in place. 24 | 25 | Returns: 26 | 27 | - (`table`) shuffled table 28 | 29 | ## table.equals(t1, t2) 30 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L35) 31 | 32 | 33 | 34 | Arguments: 35 | 36 | - `t1 ` (`table[any]`): first table. 
37 | - `t2 ` (`table[any]`): second table. 38 | 39 | Returns: 40 | 41 | - (`boolean`) whether the keys and values of each table are equal 42 | 43 | ## table.valuesEqual(t1, t2) 44 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L52) 45 | 46 | 47 | 48 | Arguments: 49 | 50 | - `t1 ` (`table[any]`): first table. 51 | - `t2 ` (`table[any]`): second table. 52 | 53 | Returns: 54 | 55 | - (`boolean`) whether the values of each table are equal, disregarding order 56 | 57 | ## table.reverse(t) 58 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L68) 59 | 60 | 61 | 62 | Arguments: 63 | 64 | - `t ` (`table`): table to reverse. 65 | 66 | Returns: 67 | 68 | - (`table`) A copy of the table, reversed. 69 | 70 | ## table.contains(t, val) 71 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L79) 72 | 73 | 74 | 75 | Arguments: 76 | 77 | - `t ` (`table`): table to check. 78 | - `val ` (`any`): value to check. 79 | 80 | Returns: 81 | 82 | - (`boolean`) whether the table contains the value 83 | 84 | ## table.flatten(t, tab, prefix) 85 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L93) 86 | 87 | Flattens the table. 88 | 89 | Arguments: 90 | 91 | - `t ` (`table`): the table to modify. 92 | - `tab ` (`table`): where to store the results. If not given, then a new table will be used. Optional. 93 | - `prefix ` (`string`): string to use to join nested keys. Optional, Default: `'__'`. 94 | 95 | Returns: 96 | 97 | - (`table`) flattened table 98 | 99 | ## table.map(t, callback) 100 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L110) 101 | 102 | Applies `callback` to each element in `t` and returns the results in another table. 103 | 104 | Arguments: 105 | 106 | - `t ` (`table`): the table to modify. 107 | - `callback ` (`function`): function to apply.
108 | 109 | Returns: 110 | 111 | - (`table`) modified table 112 | 113 | ## table.select(t, keys, forget\_keys) 114 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L125) 115 | 116 | Selects items from table `t`. 117 | 118 | Arguments: 119 | 120 | - `t ` (`table`): table to select from. 121 | - `keys ` (`table`): table of keys. 122 | - `forget_keys ` (`boolean`): whether to retain the keys. Optional. 123 | 124 | Returns: 125 | 126 | - (`table`) a table of key value pairs where the keys are `keys` and the values are corresponding values from `t`. 127 | 128 | If `forget_keys` is `true`, then the returned table will have integer keys. 129 | 130 | ## table.extend(t, another) 131 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L141) 132 | 133 | Extends the table `t` with another table `another` 134 | 135 | Arguments: 136 | 137 | - `t ` (`table`): first table. 138 | - `another ` (`table`): second table. 139 | 140 | Returns: 141 | 142 | - (`table`) modified first table 143 | 144 | ## table.combinations(input) 145 | [View source](http://github.com/vzhong/torchlib/blob/master/src//util/table.lua#L159) 146 | 147 | Returns all combinations of elements in a table. 148 | 149 | Arguments: 150 | 151 | - `input ` (`table[table[any]]`): a collection of lists to compute the combination for. 
152 | 153 | Returns: 154 | 155 | - (`table[table[any]]`) combinations of the input 156 | 157 | Example: 158 | 159 | ```lua 160 | table.combinations{{1, 2}, {'a', 'b', 'c'}} 161 | 162 | This returns `{{1, 'a'}, {1, 'b'}, {1, 'c'}, {2, 'a'}, {2, 'b'}, {2, 'c'}}` 163 | ``` 164 | 165 | -------------------------------------------------------------------------------- /init.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | 3 | tl = {} 4 | 5 | local Object = torch.class('tl.Object') 6 | function Object:__tostring__() 7 | return torch.type(self) 8 | end 9 | 10 | require('torchlib/src/util/global') 11 | require('torchlib/src/util/table') 12 | require('torchlib/src/util/string') 13 | require('torchlib/src/util/Download') 14 | tl.table = table 15 | tl.string = string 16 | 17 | require('torchlib/src/list/List') 18 | require('torchlib/src/list/ArrayList') 19 | require('torchlib/src/list/LinkedList') 20 | require('torchlib/src/list/Queue') 21 | require('torchlib/src/list/Heap') 22 | require('torchlib/src/list/Stack') 23 | 24 | require('torchlib/src/tree/Tree') 25 | require('torchlib/src/tree/BinaryTree') 26 | require('torchlib/src/tree/BinarySearchTree') 27 | 28 | require('torchlib/src/set/Set') 29 | require('torchlib/src/map/Map') 30 | require('torchlib/src/map/HashMap') 31 | require('torchlib/src/map/Counter') 32 | 33 | require('torchlib/src/graph/Graph') 34 | require('torchlib/src/graph/DirectedGraph') 35 | require('torchlib/src/graph/UndirectedGraph') 36 | 37 | require('torchlib/src/ml/Dataset') 38 | require('torchlib/src/ml/Vocab') 39 | require('torchlib/src/ml/GloveVocab') 40 | require('torchlib/src/ml/Scorer') 41 | require('torchlib/src/ml/VariableTensor') 42 | require('torchlib/src/ml/ProbTable') 43 | require('torchlib/src/ml/Model') 44 | require('torchlib/src/ml/Experiment') 45 | 46 | return tl 47 | -------------------------------------------------------------------------------- /mkdocs.yml: 
-------------------------------------------------------------------------------- 1 | copyright: Copyright 2016 Victor Zhong (victor@victorzhong.com) 2 | site_description: Documentation for torchlib 3 | repo_url: https://github.com/vzhong/torchlib 4 | site_author: Victor Zhong (victor@victorzhong.com) 5 | site_name: torchlib 6 | pages: 7 | - 'Home': 'index.md' 8 | - graph: 9 | - 'DirectedGraph': 'graph/DirectedGraph.md' 10 | - 'Graph': 'graph/Graph.md' 11 | - 'UndirectedGraph': 'graph/UndirectedGraph.md' 12 | - list: 13 | - 'ArrayList': 'list/ArrayList.md' 14 | - 'Heap': 'list/Heap.md' 15 | - 'LinkedList': 'list/LinkedList.md' 16 | - 'List': 'list/List.md' 17 | - 'Queue': 'list/Queue.md' 18 | - 'Stack': 'list/Stack.md' 19 | - map: 20 | - 'Counter': 'map/Counter.md' 21 | - 'HashMap': 'map/HashMap.md' 22 | - 'Map': 'map/Map.md' 23 | - ml: 24 | - 'Dataset': 'ml/Dataset.md' 25 | - 'Experiment': 'ml/Experiment.md' 26 | - 'GloveVocab': 'ml/GloveVocab.md' 27 | - 'Model': 'ml/Model.md' 28 | - 'ProbTable': 'ml/ProbTable.md' 29 | - 'Scorer': 'ml/Scorer.md' 30 | - 'VariableTensor': 'ml/VariableTensor.md' 31 | - 'Vocab': 'ml/Vocab.md' 32 | - set: 33 | - 'Set': 'set/Set.md' 34 | - tree: 35 | - 'BinarySearchTree': 'tree/BinarySearchTree.md' 36 | - 'BinaryTree': 'tree/BinaryTree.md' 37 | - 'Tree': 'tree/Tree.md' 38 | - util: 39 | - 'Download': 'util/Download.md' 40 | - 'global': 'util/global.md' 41 | - 'string': 'util/string.md' 42 | - 'table': 'util/table.md' 43 | -------------------------------------------------------------------------------- /src/graph/DirectedGraph.lua: -------------------------------------------------------------------------------- 1 | --- @module DirectedGraph 2 | -- A directed graph implementation. 3 | -- This is a subclass of `Graph`. 4 | local torch = require 'torch' 5 | local DirectedGraph = torch.class('tl.DirectedGraph', 'tl.Graph') 6 | local Set = tl.Set 7 | local Graph = tl.Graph 8 | 9 | --- Connects two nodes. 
-- @arg {Graph.Node} nodeA - starting node
-- @arg {Graph.Node} nodeB - ending node
--
-- Both nodes must already be in the graph. Only the edge `nodeA -> nodeB` is recorded.
function DirectedGraph:connect(nodeA, nodeB)
  self:assertValidNode(nodeA)
  self:assertValidNode(nodeB)
  self._nodeMap:get(nodeA):add(nodeB)
end

--- Returns nodes in this graph in topologically sorted order
-- @returns {table}
--
-- Runs a depth first search over every node and prepends each node to the result as it
-- finishes, so the node that finishes last ends up first.
function DirectedGraph:topologicalSort()
  local ordered = {}
  local function callback(node)
    table.insert(ordered, 1, node)
  end
  self:depthFirstSearch(self:nodeSet():totable(), {finish=callback})
  return ordered
end

--- Returns whether the graph has a cycle
-- @returns {boolean}
--
-- A node is `VISITED` while it is on the current DFS path and `FINISHED` once all of its
-- connections have been explored. Reaching a `VISITED` node again means an edge points
-- back to an ancestor, i.e. there is a cycle.
--
-- Note: node states are reset at the start of the call; when a cycle is found the search
-- stops early and the nodes on the current path are left in the `VISITED` state.
function DirectedGraph:hasCycle()
  self:resetState()
  local nodes = self:nodeSet():totable()

  local function DFSVisit(graph, node)
    node.state = Graph.state.VISITED
    local conns = self:connectionsOf(node)
    for i = 1, #conns do
      local conn = conns[i]
      if conn.state == Graph.state.VISITED then
        return true -- looped back to ancestor
      elseif conn.state == Graph.state.UNDISCOVERED then
        -- propagate a cycle found deeper in the search instead of dropping the result
        if DFSVisit(graph, conn) then
          return true
        end
      end
    end
    node.state = Graph.state.FINISHED
    return false
  end

  for i = 1, #nodes do
    local node = nodes[i]
    if node.state == Graph.state.UNDISCOVERED then
      if DFSVisit(self, node) then
        return true
      end
    end
  end
  return false
end

--- Returns a transpose of this graph (eg.
-- with the edges reversed)
-- @returns {DirectedGraph}
--
-- The returned graph is built from a copy of this graph's node map, so it is keyed by the
-- nodes returned from `g:nodeSet()` of that copy.
function DirectedGraph:transpose()
  local g = DirectedGraph.new()
  g._nodeMap = self._nodeMap:copy()
  local nodes = g:nodeSet():totable()
  -- clear out the connections first
  -- (presumably `add` replaces the copied connection set with an empty one; `HashMap:add`
  -- is defined elsewhere - TODO confirm)
  for i = 1, #nodes do
    local node = nodes[i]
    g._nodeMap:add(node, Set())
  end
  -- add in transpose connections: every edge node -> conn in this graph becomes
  -- conn -> node in `g`. This must run after all sets were cleared above.
  for i = 1, #nodes do
    local node = nodes[i]
    local conns = self:connectionsOf(node)
    for j = 1, #conns do
      local conn = conns[j]
      g:connect(conn, node)
    end
  end
  return g
end

--- Returns strongly connected components.
-- Each strongly connected component is itself a table.
-- @returns {table[table]} a table of strongly connected components.
--
-- NOTE(review): as written, the second search runs on this graph (not on `transpose()`),
-- and the `discover` callback appends every discovered node to one flat table, so the
-- return value is a list of nodes rather than a table of tables - confirm the intended
-- contract against the tests before relying on the description above.
-- NOTE(review): `topologicalSort` prepends nodes as they finish, so its first element is
-- the node that finished last; the name `firstToLastFinish` may be misleading - confirm.
function DirectedGraph:stronglyConnectedComponents()
  local firstToLastFinish = self:topologicalSort()
  local roots = {}
  local discoverCallback = function(node)
    table.insert(roots, node)
  end
  self:depthFirstSearch(table.reverse(firstToLastFinish), {discover=discoverCallback})
  return roots
end

return DirectedGraph

--------------------------------------------------------------------------------
/src/graph/Graph.lua:
--------------------------------------------------------------------------------
--- @module Graph
-- Abstract graph implementation.
--
-- A `Graph` consists of `GraphNode`s. Each `GraphNode` can be in three states:
-- - `UNDISCOVERED`
-- - `VISITED`
-- - `FINISHED`

local torch = require 'torch'
local HashMap = tl.HashMap
local Set = tl.Set
local Queue = tl.Queue

-- `parent` is the parent class table of `tl.Graph` (i.e. `tl.Object`).
local Graph, parent = torch.class('tl.Graph', 'tl.Object')
local GraphNode = torch.class('tl.Graph.Node', 'tl.Object')

-- Traversal states assigned to `node.state` by the search routines below.
Graph.state = {UNDISCOVERED = 1, VISITED = 2, FINISHED = 3}

--[[ Constructor for a node in the graph.
Parameter:
- `val`: the value for this node
]]
--- Constructor.
-- @arg {any} val - value for the new node.
function GraphNode:__init(val)
  self.val = val
end

--- @returns {string} string representation
--
-- The value is passed through `tostring` so that nodes holding tables, booleans or `nil`
-- can be printed as well (plain `..` concatenation raises an error for those).
function GraphNode:__tostring__()
  return parent.__tostring__(self) .. '(' .. tostring(self.val) .. ')'
end

--- Constructor.
--
-- Nodes are stored as keys of a `HashMap`; each key maps to the `Set` of nodes it is
-- connected to.
function Graph:__init()
  self._nodeMap = HashMap.new()
end

--- @returns {int} number of nodes in the graph.
function Graph:size()
  return self._nodeMap:size()
end

--- Verifies that the node is in the graph
-- @arg {Graph.Node} node - the node to verify.
--
-- Raises an error if the node is not a key of the node map.
function Graph:assertValidNode(node)
  assert(self._nodeMap:contains(node), 'Error: node ' .. tostring(node.val) .. ' is not in graph')
end

--- Adds a node with given value to the graph.
-- @arg {any} val - value for the new node.
-- @returns {Graph.Node}
--
-- The new node starts with an empty set of connections.
function Graph:addNode(val)
  local node = GraphNode.new(val)
  self._nodeMap:add(node, Set())
  return node
end

--- Returns neighbours of a given node.
-- @arg {Graph.Node} node - the node to find neighbours for.
-- @returns {table(Graph.Node)}
--
-- Raises an error if the node is not in the graph.
function Graph:connectionsOf(node)
  self:assertValidNode(node)
  return self._nodeMap:get(node):totable()
end

--- Returns a set of nodes in the graph.
-- @returns {Set(Graph.Node)}
function Graph:nodeSet()
  return self._nodeMap:keys()
end

--- Initializes all nodes to `Graph.state.UNDISCOVERED`.
-- @returns {Graph}
--
-- The graph will be returned
--
-- Besides the state, each node's `timestamp` is set to `math.huge` and its `parent` is cleared.
function Graph:resetState()
  local nodes = self:nodeSet():totable()
  for i = 1, #nodes do
    local node = nodes[i]
    node.state = Graph.state.UNDISCOVERED
    node.timestamp = math.huge
    node.parent = nil
  end
  return self
end

--- Performs breadth first search.
--
-- @arg {Graph.Node} source - the source node to start BFS
-- @arg {table[string:function]=} callbacks - a map with optional callbacks
--
-- Available callbacks:
--
-- - `discover = function(Graph.Node)`: called when a node is initially encountered
--
-- - `finish = function(Graph.Node)`: called when a node has been fully explored (eg. its connected nodes have all been visited)
--
-- After the search, every reached node has `timestamp` set to its distance (in edges) from
-- `source` and `parent` set to the node it was discovered from. The `callbacks` table
-- passed in by the caller is not modified.
function Graph:breadthFirstSearch(source, callbacks)
  callbacks = callbacks or {}
  -- resolve the callbacks into locals instead of writing defaults into the caller's table
  local discover = callbacks.discover or function(node) end
  local finish = callbacks.finish or function(node) end
  self:resetState()
  source.state = Graph.state.VISITED
  discover(source)
  source.timestamp = 0

  local q = Queue.new()
  q:enqueue(source)

  while not q:isEmpty() do
    local node = q:dequeue()
    local conns = self:connectionsOf(node)
    for i = 1, #conns do
      local conn = conns[i]
      if conn.state == Graph.state.UNDISCOVERED then
        conn.state = Graph.state.VISITED
        discover(conn)
        conn.timestamp = node.timestamp + 1
        conn.parent = node
        q:enqueue(conn)
      end
    end
    node.state = Graph.state.FINISHED
    finish(node)
  end
end

--- Returns the shortest path from source to destination
--
-- @arg {Graph.Node} source - starting node
-- @arg {Graph.Node} destination - end node
-- @arg {boolean=} skipBFS - whether BFS has already been performed
--
-- Note: This function relies
-- on the results from a BFS call. By default, a BFS is performed before
-- retrieving the shortest path. Alternatively, if the caller has already performed BFS, then
-- this BFS can be skipped by passing in `skipBFS = true`.
--
-- The returned table lists the nodes from `source` to `destination`, both included.
-- An error is raised when the chain of `parent` links ends before reaching `source`.
function Graph:shortestPath(source, destination, skipBFS)
  if skipBFS ~= true then
    self:breadthFirstSearch(source)
  end

  -- Follow the `parent` links from the destination back to the source, collecting the
  -- nodes in reverse order.
  local backwards = {}
  local current = destination
  while not (source == current) do
    if current.parent == nil then
      error('Error: no path from ' .. tostring(source) .. ' to ' .. tostring(current))
    end
    backwards[#backwards + 1] = current
    current = current.parent
  end
  backwards[#backwards + 1] = current

  -- Flip the collected nodes so the path runs from source to destination.
  local path = {}
  for i = #backwards, 1, -1 do
    path[#path + 1] = backwards[i]
  end
  return path
end

--- Performs depth first search.
--
-- @arg {table[Graph.Node]} nodes - the table of nodes on which to perform DFS. If not set, then all nodes in the graph are used
-- @arg {table[string:function]=} callbacks - a map with optional callbacks
--
-- Available callbacks:
--
-- - `discover = function(Graph.Node)`: called when a node is initially encountered
--
-- - `finish = function(Graph.Node)`: called when a node has been fully explored (eg.
-- its connected nodes have all been visited)
--
-- Every node gets a discovery `timestamp` and a `finishTime` from one shared counter, and
-- `parent` is set to the node it was discovered from. Start nodes are tried in the order
-- given by `nodes`. The `callbacks` table passed in by the caller is not modified.
function Graph:depthFirstSearch(nodes, callbacks)
  callbacks = callbacks or {}
  -- resolve the callbacks into locals instead of writing defaults into the caller's table
  local discover = callbacks.discover or function(node) end
  local finish = callbacks.finish or function(node) end
  self:resetState()
  local timestamp = 0
  nodes = nodes or self:nodeSet():totable()

  local function DFSVisit(graph, node)
    timestamp = timestamp + 1
    node.timestamp = timestamp
    node.state = Graph.state.VISITED
    discover(node)
    local conns = self:connectionsOf(node)
    for i = 1, #conns do
      local conn = conns[i]
      if conn.state == Graph.state.UNDISCOVERED then
        conn.parent = node
        DFSVisit(graph, conn)
      end
    end
    node.state = Graph.state.FINISHED
    timestamp = timestamp + 1
    node.finishTime = timestamp
    finish(node)
  end

  for i = 1, #nodes do
    local node = nodes[i]
    if node.state == Graph.state.UNDISCOVERED then
      DFSVisit(self, node)
    end
  end
end

return Graph

--------------------------------------------------------------------------------
/src/graph/UndirectedGraph.lua:
--------------------------------------------------------------------------------
--- @module UndirectedGraph
-- Undirected graph implementation
-- This is a subclass of `Graph`.

local torch = require 'torch'
local UndirectedGraph = torch.class('tl.UndirectedGraph', 'tl.Graph')

--- Connects two nodes.
-- @arg {Graph.Node} nodeA - starting node
-- @arg {Graph.Node} nodeB - ending node
--
-- The connection is recorded in both directions.
function UndirectedGraph:connect(nodeA, nodeB)
  self:assertValidNode(nodeA)
  self:assertValidNode(nodeB)
  self._nodeMap:get(nodeA):add(nodeB)
  self._nodeMap:get(nodeB):add(nodeA)
end

return UndirectedGraph

--------------------------------------------------------------------------------
/src/list/ArrayList.lua:
--------------------------------------------------------------------------------
--- @module ArrayList
-- Array list implementation.
-- This is a subclass of `List`.

local ArrayList = torch.class('tl.ArrayList', 'tl.List')

-- Asserts that `index` addresses an element that already exists.
-- `assertValidIndex` alone also accepts `size() + 1` (the append position),
-- which must not be used to overwrite, remove or swap an element.
local function assertElementIndex(list, index)
  list:assertValidIndex(index)
  assert(index <= list:size(), 'index ' .. index .. ' is out of bounds for array of size ' .. list:size())
end

--- Constructor.
-- @arg {table[any]=} values - used to initialize the list (the table is copied)
function ArrayList:__init(values)
  values = values or {}
  self._arr = tl.copy(values)
  self._size = #self._arr
end

--- Adds element to list.
-- @arg {any} val - value to add
-- @arg {int=end} index - index to add value at; later elements shift right
-- @returns {ArrayList} modified list
function ArrayList:add(val, index)
  if index == nil then
    table.insert(self._arr, val)
  else
    self:assertValidIndex(index)
    table.insert(self._arr, index, val)
  end
  self._size = self._size + 1
  return self
end

--- @arg {int} index - index to retrieve value for
-- @returns {any} value at index
function ArrayList:get(index)
  self:assertValidIndex(index)
  return self._arr[index]
end

--- Sets the value at index.
-- @arg {int} index - index of an existing element
-- @arg {any} val - value to set
-- @returns {ArrayList} modified list
function ArrayList:set(index, val)
  assertElementIndex(self, index)
  self._arr[index] = val
  return self
end

--- Removes the element at index; later elements shift left by 1.
-- @arg {int} index - index of an existing element
-- @returns {any} removed value
function ArrayList:remove(index)
  -- Reject `size() + 1`: nothing would be removed but the size would shrink.
  assertElementIndex(self, index)
  self._size = self._size - 1
  return table.remove(self._arr, index)
end

--- Compares two lists position by position.
-- @arg {List} another - another list to compare to
-- @returns {boolean} whether both lists hold equal values at every index
function ArrayList:equals(another)
  if self:size() ~= another:size() then return false end
  for i = 1, self:size() do
    if self:get(i) ~= another:get(i) then return false end
  end
  return true
end

--- Swaps value at two indices.
-- @arg {int} i - first index
-- @arg {int} j - second index
-- @returns {ArrayList} modified list
function ArrayList:swap(i, j)
  assertElementIndex(self, i)
  assertElementIndex(self, j)
  -- Multiple assignment avoids the temporary, which used to leak as a global.
  self._arr[i], self._arr[j] = self._arr[j], self._arr[i]
  return self
end

--- Returns the list in table form.
-- @returns {table[any]} a new table containing the values in the list
function ArrayList:totable()
  local tab = {}
  for i = 1, self:size() do
    table.insert(tab, self._arr[i])
  end
  return tab
end

return ArrayList

--------------------------------------------------------------------------------
/src/list/Heap.lua:
--------------------------------------------------------------------------------
--- @module Heap
-- Max heap implementation.
-- This is a subclass of `ArrayList`.

local Heap, parent = torch.class('tl.Heap', 'tl.ArrayList')
local Object = tl.Object

--- @returns {int} parent index of `i`
-- @arg {int} i - index to compute parent for
function Heap.parent(i)
  return math.floor(i/2)
end

--- @returns {int} left child index of `i`
-- @arg {int} i - index to compute left child for
function Heap.left(i)
  return 2 * i
end

--- @returns {int} right child index of `i`
-- @arg {int} i - index to compute right child for
function Heap.right(i)
  return 2 * i + 1
end

--- Restores max heap condition at the `i`th index.
--
-- @arg {int} i - index at which to restore max heap condition
-- @arg {int=size} effectiveSize - effective size of the heap (e.g. number of valid elements)
-- @returns {Heap} modified heap
--
-- Swaps down the node at `i` until the max heap condition is restored at `a[i]`.
-- Note: this function assumes that the binary trees rooted at left and right are max heaps but
-- `a[i]` may violate the max-heap condition.
function Heap:maxHeapify(i, effectiveSize)
  effectiveSize = effectiveSize or self:size()
  local arr = self._arr
  -- Each entry is a packed {priority, value} pair; entry[1] is the priority.
  while true do
    local largest = i
    local l, r = Heap.left(i), Heap.right(i)
    if l <= effectiveSize and arr[l][1] > arr[largest][1] then
      largest = l
    end
    if r <= effectiveSize and arr[r][1] > arr[largest][1] then
      largest = r
    end
    if largest == i then break end
    self:swap(largest, i)
    -- Continue sifting down from the child that received the old parent.
    i = largest
  end
  return self
end

--- Sorts the heap using heap sort.
-- @returns {Heap} sorted heap
--
-- The entries end up in ascending priority order.
function Heap:sort()
  local effectiveSize = self:size()
  for i = self:size(), 2, -1 do
    self:swap(1, i) --move the largest to the end
    effectiveSize = effectiveSize - 1
    self:maxHeapify(1, effectiveSize) --swap the new head down
  end
  return self
end

--- Adds an element to the heap while keeping max heap property.
-- @arg {number} key - priority to add with
-- @arg {any} val - element to add to heap; defaults to `key`
-- @returns {Heap} modified heap
function Heap:push(key, val)
  if val == nil then val = key end
  -- Append as the last leaf and sift it up. Inserting at the front would
  -- shift every element by one slot and break the existing sub-heaps, which
  -- `maxHeapify` relies on.
  self:add(table.pack(key, val))
  local i = self:size()
  while i > 1 do
    local p = Heap.parent(i)
    if self._arr[p][1] >= self._arr[i][1] then break end
    self:swap(p, i)
    i = p
  end
  return self
end

--- Removes and returns the max priority element from the heap.
-- @returns {any} removed element
--
-- Raises an error if the heap is empty.
function Heap:pop()
  assert(not self:isEmpty(), 'Error: cannot pop from empty heap')
  -- Move the root to the end so it can be removed without shifting elements.
  self:swap(1, self:size())
  local entry = self:remove(self:size())
  if self:size() > 1 then
    self:maxHeapify(1)
  end
  return entry[2]
end

--- @return {any} max priority element from the heap
--
-- Note: the element is not removed.
function Heap:peek()
  assert(not self:isEmpty(), 'Error: cannot peek from empty heap')
  local plargest, vlargest = table.unpack(self:get(1))
  return vlargest
end

--- @returns {string} string representation showing at most five `value(priority)` entries
function Heap:__tostring__()
  local s = Object.__tostring__(self) .. '['
  local max = 5
  for i = 1, math.min(self:size(), max) do
    local p, v = table.unpack(self:get(i))
    s = s .. tostring(v) .. '(' .. p .. ')'
    if i == max then
      s = s .. ', ...'
    elseif i ~= self:size() then
      s = s .. ', '
    end
  end
  s = s .. ']'
  return s
end

return Heap

--------------------------------------------------------------------------------
/src/list/LinkedList.lua:
--------------------------------------------------------------------------------
--- @module LinkedList
-- Linked list implementation.
-- This is a subclass of `List`.

local torch = require 'torch'
local LinkedList = torch.class('tl.LinkedList', 'tl.List')
LinkedList.Node = torch.class('tl.LinkedListNode')

--- Node constructor.
-- @arg {any} val - value stored in the node
function LinkedList.Node:__init(val)
  self.val = val
  self.next = nil
end

--- @returns {string} string representation
function LinkedList.Node:__tostring__()
  -- tostring() keeps this working for values that cannot be concatenated
  -- directly (tables, booleans, nil).
  return 'LinkedListNode(' .. tostring(self.val) .. ')'
end

-- Asserts that `index` addresses an existing element. `assertValidIndex`
-- alone also accepts `size() + 1`, which is only meaningful for insertion.
local function assertElementIndex(list, index)
  list:assertValidIndex(index)
  assert(index <= list:size(), 'index ' .. index .. ' is out of bounds for array of size ' .. list:size())
end

-- Returns the node immediately before position `index` (the sentinel for
-- index 1).
local function nodeBefore(list, index)
  local prev = list._sentinel
  for _ = 1, index - 1 do
    prev = prev.next
  end
  return prev
end

--- Constructor.
-- @arg {table[any]=} values - used to initialize the list
function LinkedList:__init(values)
  -- The sentinel sits before the first element so that insertion and removal
  -- at the head need no special casing.
  self._sentinel = LinkedList.Node.new()
  self._tail = self._sentinel
  self._size = 0
  values = values or {}
  for _, v in ipairs(values) do
    self:add(v)
  end
end

--- @returns {int} size of the list
function LinkedList:size()
  return self._size
end

--- @returns {LinkedList.Node} head of the linked list
function LinkedList:head()
  return self._sentinel.next
end

--- Adds element to list.
-- @arg {any} val - value to add
-- @arg {int=end} index - index to add value at
-- @returns {LinkedList} modified list
function LinkedList:add(val, index)
  local node = LinkedList.Node.new(val)
  if index == nil then
    self._tail.next = node
    self._tail = node
  else
    self:assertValidIndex(index)
    local prev = nodeBefore(self, index)
    node.next = prev.next
    prev.next = node
    -- Inserting at `size() + 1` appends, so the tail pointer must follow;
    -- otherwise the next plain add() would overwrite this node.
    if node.next == nil then self._tail = node end
  end
  self._size = self._size + 1
  return self
end

--- @arg {int} index - index to retrieve value for
-- @returns {any} value at index
function LinkedList:get(index)
  assertElementIndex(self, index)
  return nodeBefore(self, index).next.val
end

--- Sets the value at index.
-- @arg {int} index - index to set value for
-- @arg {any} val - value to set
-- @returns {LinkedList} modified list
function LinkedList:set(index, val)
  assertElementIndex(self, index)
  nodeBefore(self, index).next.val = val
  return self
end

--- Removes the element at index.
-- @arg {int} index - index to remove at
-- @returns {any} removed value
function LinkedList:remove(index)
  assertElementIndex(self, index)
  local prev = nodeBefore(self, index)
  local curr = prev.next
  prev.next = curr.next
  if curr == self._tail then self._tail = prev end
  self._size = self._size - 1
  return curr.val
end

--- Swaps the nodes at two indices.
-- @arg {int} i - first index
-- @arg {int} j - second index
-- @returns {LinkedList} modified list
function LinkedList:swap(i, j)
  assertElementIndex(self, i)
  assertElementIndex(self, j)
  if i == j then return self end

  local prevI = nodeBefore(self, i)
  local prevJ = nodeBefore(self, j)
  local currI, currJ = prevI.next, prevJ.next

  -- Relink predecessors first, then exchange the successors. This order also
  -- handles adjacent nodes, where one node is the other's predecessor.
  prevI.next = currJ
  prevJ.next = currI
  local temp = currI.next
  currI.next = currJ.next
  currJ.next = temp

  -- Whichever node now ends the chain is the new tail. Without this, swapping
  -- the last element left `_tail` pointing into the middle of the list and
  -- the next add() dropped every node after it.
  if currI.next == nil then
    self._tail = currI
  elseif currJ.next == nil then
    self._tail = currJ
  end
  return self
end

--- Compares two linked lists.
-- @arg {LinkedList} another - another list to compare to
-- @returns {boolean} whether both lists hold equal values at every position
function LinkedList:equals(another)
  if self:size() ~= another:size() then return false end
  local curr = self:head()
  local currAnother = another:head()
  while curr ~= nil do
    if curr.val ~= currAnother.val then return false end
    curr = curr.next
    currAnother = currAnother.next
  end
  return true
end

--- Returns the list in table form.
-- @returns {table[any]} a table containing the values in the list
function LinkedList:totable()
  local tab = {}
  local curr = self:head()
  while curr do
    table.insert(tab, curr.val)
    curr = curr.next
  end
  return tab
end

return LinkedList

--------------------------------------------------------------------------------
/src/list/List.lua:
--------------------------------------------------------------------------------
--- @module List
-- Abstract list implementation.

local torch = require 'torch'
local List, parent = torch.class('tl.List', 'tl.Object')

--- Constructor.
-- @arg {table[any]=} values - used to initialize the list
function List:__init(values)
  error('not implemented')
end

--- Adds element to list.
-- @arg {any} val - value to add
-- @arg {int=end} index - index to add value at
-- @returns {List} - modified list
function List:add(val, index)
  error('not implemented')
end

--- @arg {int} index - index to retrieve value for
--
-- Asserts error if `index` is out of bounds.
-- @returns {any} - value at index
function List:get(index)
  error('not implemented')
end

--- Sets the value at index.
-- @arg {int} index - index to set value for
-- @arg {any} val - value to set
-- @returns {List} - modified list
--
-- Asserts error if `index` is out of bounds.
function List:set(index, val)
  error('not implemented')
end

--- @arg {int} index - index to remove at
-- @returns {any} - value at index
--
-- Elements after `index` will be shifted to the left by 1.
-- Asserts error if `index` is out of bounds.
function List:remove(index)
  error('not implemented')
end

--- Compares two lists.
-- @arg {List} another - another list to compare to
-- @returns {boolean} whether this list is equal to `another`
--
-- Lists are considered equal if their values match at every position.
function List:equals(another)
  error('not implemented')
end

--- Swaps value at two indices.
-- @arg {int} i - first index
-- @arg {int} j - second index
-- @returns {List} - modified list
function List:swap(i, j)
  error('not implemented')
end

--- Returns the list in table form.
-- @returns {table[any]} a table containing the values in the list.
function List:totable()
  error('not implemented')
end

--- Asserts that index is inside the list.
-- @arg {int} index - index to check
--
-- Accepts `1 .. size() + 1`: the position one past the end is allowed so that
-- `add(val, index)` can append.
function List:assertValidIndex(index)
  assert(index > 0 and index <= self:size()+1, 'index ' .. index .. ' is out of bounds for array of size ' .. self:size())
end

--- @returns {int} size of the list
function List:size()
  return self._size
end

--- Adds items to the list.
-- @arg {vararg[any]} vararg - values to add to the list
-- @returns {List} modified list
--
-- Values are added in order up to the first `nil` argument.
function List:addMany(...)
  local args = table.pack(...)
  for _, v in ipairs(args) do
    self:add(v)
  end
  return self
end

--- Returns whether the list contains a value.
-- @arg {any} val - value to check.
-- @returns {boolean} whether `val` is in the list
function List:contains(val)
  for i = 1, self:size() do
    if self:get(i) == val then
      return true
    end
  end
  return false
end

--- @returns {List} a copy of this list
function List:copy()
  return self.new(self:totable())
end

--- @returns {boolean} whether the list is empty
function List:isEmpty()
  return self:size() == 0
end

--- Returns a copy of a segment of this list.
--
-- @arg {int} start - start of the segment
-- @arg {int=end} finish - end of the segment (inclusive)
-- @returns {List} new list holding the elements `start .. finish`
function List:sublist(start, finish)
  finish = finish or self:size()
  -- Validate before allocating the result.
  self:assertValidIndex(start)
  self:assertValidIndex(finish)
  -- `assertValidIndex` tolerates `size() + 1`, which is not a readable
  -- position; reading it would copy a non-existent element.
  assert(finish <= self:size(), 'index ' .. finish .. ' is out of bounds for array of size ' .. self:size())
  local sub = self.new()
  for i = start, finish do sub:add(self:get(i)) end
  return sub
end

--- Sorts the list in place.
-- @arg {int=1} start - start index of the sort
-- @arg {int=end} finish - end index of the sort
--
-- Quicksort with a randomly chosen pivot; elements are compared with `<`.
function List:sort(start, finish)
  -- Partitions l[lo..hi] around a random pivot and returns the pivot's final
  -- index: smaller elements end up before it, the rest after it.
  local function partition(l, lo, hi)
    local pivotIndex = math.random(lo, hi)
    local pivot = l:get(pivotIndex)
    l:swap(pivotIndex, hi)
    local write = lo
    for i = lo, hi-1 do
      if l:get(i) < pivot then
        l:swap(i, write)
        write = write + 1
      end
    end
    l:swap(write, hi)
    return write
  end

  start = start or 1
  finish = finish or self:size()
  if start < finish then
    local pivot = partition(self, start, finish)
    self:sort(start, pivot-1)
    self:sort(pivot+1, finish)
  end
end

--- @returns {string} string representation
function List:__tostring__()
  local s = parent.__tostring__(self) .. '['
  local max = 5
  for i = 1, math.min(self:size(), max) do
    s = s .. tostring(self:get(i))
    if i == max then
      s = s .. ', ...'
    elseif i ~= self:size() then
      s = s .. ', '
    end
  end
  s = s .. ']'
  return s
end

return List

--------------------------------------------------------------------------------
/src/list/Queue.lua:
--------------------------------------------------------------------------------
--- @module Queue
-- Queue implementation.
-- This is a subclass of `LinkedList`.

local Queue, parent = torch.class('tl.Queue', 'tl.LinkedList')

--- Constructor.
-- @arg {table[any]=} values - used to initialize the queue, first value first
function Queue:__init(values)
  -- Call the parent constructor on this instance. `parent:__init(values)`
  -- would pass the parent class table as `self` and store the sentinel, tail
  -- and size on the class, shared by every queue.
  parent.__init(self, values)
end

--- Adds a value to the queue.
-- @arg {any} val - value to add
-- @returns {Queue} modified queue
function Queue:enqueue(val)
  self:add(val)
  return self
end

--- Returns and removes the first value in the queue.
-- @returns {any} removed value
--
-- Raises an error if the queue is empty.
function Queue:dequeue()
  local count = self:size()
  assert(count > 0, 'cannot dequeue from empty queue')
  return self:remove(1)
end

return Queue

--------------------------------------------------------------------------------
/src/list/Stack.lua:
--------------------------------------------------------------------------------
--- @module Stack
-- Stack implementation.
-- This is a subclass of `LinkedList`.

local Stack = torch.class('tl.Stack', 'tl.LinkedList')

--- Adds a value to the stack.
-- @arg {any} val - value to add
-- @returns {Stack} modified stack
function Stack:push(val)
  self:add(val)
  return self
end

--- Returns and removes the value at the top of the stack.
-- @returns {any} removed value
--
-- The top of the stack is the last element of the underlying list.
-- Raises an error if the stack is empty.
function Stack:pop()
  assert(self:size() > 0, 'cannot dequeue from empty stack')
  local top = self:size()
  return self:remove(top)
end

return Stack

--------------------------------------------------------------------------------
/src/map/Counter.lua:
--------------------------------------------------------------------------------
--- @module Counter
-- Implementation of a counter.

local torch = require 'torch'
local Counter = torch.class('tl.Counter')

--- Constructor.
function Counter:__init()
  self.counts = {}
end

--- Increments the count for a key.
-- @arg {any} key - key to increment count for
-- @arg {int=1} count - how much to increment count by
-- @returns {int} the new count
function Counter:add(key, count)
  local updated = (self.counts[key] or 0) + (count or 1)
  self.counts[key] = updated
  return updated
end

--- @arg {any} key - key to return count for.
-- @returns {int} the count for the key
--
-- If `key` has not been added to the counter, then returns 0.
function Counter:get(key)
  return self.counts[key] or 0
end

--- Clears the counter.
-- @returns {Counter} the modified counter
function Counter:reset()
  self.counts = {}
  return self
end

return Counter

--------------------------------------------------------------------------------
/src/map/HashMap.lua:
--------------------------------------------------------------------------------
--- @module HashMap
-- Implementation of hash map.
-- This is a subclass of `Map`

local torch = require 'torch'
local HashMap, parent = torch.class('tl.HashMap', 'tl.Map')
local Set = tl.Set

--- Constructor.
-- @arg {table[any:any]=} key_values - used to initialize the map
function HashMap:__init(key_values)
  self._map = {}
  self._size = 0
  key_values = key_values or {}
  self:addMany(key_values)
end

--- Adds an entry to the map, replacing any existing value for `key`.
-- @arg {any} key - key to add
-- @arg {any} val - value to add
-- @returns {HashMap} modified map
function HashMap:add(key, val)
  if not self:contains(key) then
    self._size = self._size + 1
  end
  self._map[key] = val
  return self
end

--- Adds many entries to the map.
-- @arg {table[any:any]} tab - a map of key value pairs to add
-- @returns {HashMap} modified map
function HashMap:addMany(tab)
  for k, v in pairs(tab) do
    self:add(k, v)
  end
  return self
end

--- @returns {HashMap} copy of this map
function HashMap:copy()
  return HashMap.new():addMany(self:totable())
end

--- @returns {boolean} whether the map contains the key
-- @arg {any} key - key to check
function HashMap:contains(key)
  return self._map[key] ~= nil
end

--- Retrieves the value for a key.
-- @arg {any} key - key to retrieve
-- @arg {boolean=} returnNilIfMissing - whether to tolerate missing keys
-- @returns {any} value corresponding to the key
--
-- Raises an error if `key` is not found, unless `returnNilIfMissing` is
-- truthy, in which case `nil` is returned.
function HashMap:get(key, returnNilIfMissing)
  if self:contains(key) then
    return self._map[key]
  end
  -- Test truthiness rather than `~= nil`, so an explicit `false` still
  -- raises as documented.
  if returnNilIfMissing then
    return nil
  end
  error('Error: key ' .. tostring(key) .. ' not found in HashMap')
end

--- Removes a key value pair.
-- @arg {any} key - key to remove
-- @returns {any} the removed value
--
-- Asserts error if `key` is not in the map.
function HashMap:remove(key)
  assert(self:contains(key), 'Error: key ' .. tostring(key) .. ' not found in HashMap')
  local val = self:get(key)
  self._map[key] = nil
  self._size = self._size - 1
  return val
end

--- @returns {Set} the keys in the map
function HashMap:keys()
  local keys = Set()
  for k in pairs(self._map) do
    keys:add(k)
  end
  return keys
end

--- @returns {table[any:any]} a copy of the map in table form
function HashMap:totable()
  local tab = {}
  for k, v in pairs(self._map) do
    tab[k] = v
  end
  return tab
end

--- @returns {string} string representation showing at most five entries
function HashMap:__tostring__()
  local s = parent.__tostring__(self) .. '{'
  local max = 5
  local keys = self:keys():totable()

  for i = 1, math.min(self:size(), max) do
    local key = keys[i]
    s = s .. tostring(key) .. ' -> ' .. tostring(self:get(key))
    if i ~= self:size() then
      s = s .. ', '
    end
  end
  if self:size() > max then s = s .. '...' end
  s = s .. '}'
  return s
end

--- @returns {boolean} whether this map equals `another`
-- @arg {HashMap} another - another map to compare to
--
-- Maps are considered equal if they have the same size and every key of this
-- map holds an equal value in `another`.
function HashMap:equals(another)
  if self:size() ~= another:size() then return false end
  for k, v in pairs(self._map) do
    if v ~= another._map[k] then return false end
  end
  return true
end

return HashMap

--------------------------------------------------------------------------------
/src/map/Map.lua:
--------------------------------------------------------------------------------
--- @module Map
-- Abstract map implementation.

local torch = require 'torch'
local Map, parent = torch.class('tl.Map', 'tl.Object')

--- Constructor.
-- @arg {table[any:any]=} key_values - used to initialize the map
function Map:__init(key_values)
  error('not implemented')
end

--- Adds an entry to the map.
-- @arg {any} key - key to add
-- @arg {any} val - value to add
function Map:add(key, val)
  error('not implemented')
end

--- Adds many entries to the map.
-- @arg {table[any:any]=} tab - a map of key value pairs to add
function Map:addMany(tab)
  error('not implemented')
end

--- @returns {Map} copy of this map
function Map:copy()
  error('not implemented')
end

--- @returns {boolean} whether the map contains the key
-- @arg {any} key - key to check
function Map:contains(key)
  error('not implemented')
end

--- Retrieves the value for a key.
-- @arg {any} key - key to retrieve
-- @arg {boolean=} returnNilIfMissing - whether to tolerate missing keys
-- @returns {any} value corresponding to the key
--
-- By default, asserts error if `key` is not found. If `returnNilIfMissing` is true,
-- then a `nil` will be returned if `key` is not found.
function Map:get(key, returnNilIfMissing)
  error('not implemented')
end

--- Removes a key value pair.
-- @arg {any} key - key to remove
-- @returns {any} the removed value
--
-- Asserts error if `key` is not in the map.
function Map:remove(key)
  error('not implemented')
end

--- @returns {table[any]} a table of the keys in the map
function Map:keys()
  error('not implemented')
end

--- @returns {table[any:any]} the map in table form
function Map:totable()
  error('not implemented')
end

--- @returns {boolean} whether this map equals `another`.
-- @arg {Map} another - another map to compare to
--
-- Maps are considered equal if all keys and corresponding values match.
function Map:equals(another)
  error('not implemented')
end

--- @returns {int} number of key value pairs in the map
function Map:size()
  return self._size
end

return Map

--------------------------------------------------------------------------------
/src/ml/GloveVocab.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'
local Downloader = tl.Downloader

--- @module GloveVocab
-- Vocab object prepopulated with Glove embeddings by Pennington, Socher, and Manning.
-- This is a subclass of `Vocab`.
-- For details, see:
--
-- http://nlp.stanford.edu/projects/glove/.
--
-- This only supports the 50-d wikipedia/Giga-word version.
-- The download is from:
--
-- https://dl.dropboxusercontent.com/u/9015381/datasets/torchnlp/glove.6B.50d.t7
local GloveVocab, parent = torch.class('tl.GloveVocab', 'tl.Vocab')

-- Name and URL handed to the downloader for the pretrained vectors.
local GLOVE_FILE = 'vocab-glove.6B.50d.t7'
local GLOVE_URL = 'https://dl.dropboxusercontent.com/u/9015381/datasets/torchnlp/glove.6B.50d.t7'

-- Fetches the GloVe archive through the downloader and loads it. The code
-- below reads `words` (list of strings) and `vecs` (one row per word) from it.
local function loadGlove()
  local fname = Downloader():get(GLOVE_FILE, GLOVE_URL)
  return torch.load(fname)
end

--- Retrieves the word list and populates the vocabulary.
function GloveVocab:load_words()
  local bin = loadGlove()
  -- add the pretrained vocabulary
  self:indicesOf(bin.words, true)
end

--- @returns {torch.Tensor} pretrained embeddings for words in the vocabulary
--
-- Rows for words without a pretrained vector keep their random
-- initialisation, uniform in [-0.08, 0.08].
function GloveVocab:embeddings()
  local t = torch.Tensor(self:size(), 50):uniform(-0.08, 0.08)
  local bin = loadGlove()
  local w2i = {}
  for i, word in ipairs(bin.words) do
    w2i[word] = i
  end
  -- `word2index` is keyed by word, so it must be walked with pairs();
  -- ipairs() visits no string keys and would copy no vector at all.
  -- NOTE(review): assumes `word2index` maps word -> row index, as its name
  -- says; confirm against Vocab.
  for word, index in pairs(self.word2index) do
    local row = w2i[word]
    if row then t[index] = bin.vecs[row] end
  end
  return t
end

--- @returns {string} string representation
function GloveVocab:__tostring__()
  return "GloveVocab("..self:size()..' words, unk='..self.unk..")"
end

return GloveVocab

--------------------------------------------------------------------------------
/src/ml/ProbTable.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'

--- @module ProbTable
-- Implementation of probability table using Torch tensor
local ProbTable = torch.class('tl.ProbTable')

--- Constructor.
-- @arg {torch.tensor} P - probability Tensor, the `i`th dimension corresponds to the `i`th variable.
-- @arg {table[string]=} names - A table of names for the variables. By default these will be assigned using indices.
--
-- Example:
--
-- @code {lua}
-- local t = ProbTable(torch.Tensor{{0.2, 0.8}, {0.4, 0.6}, {0.1, 0.9}}, {'a', 'b'})
-- t:query{a=1, b=2} -- 0.8
-- t:query{a=2} -- Tensor{0.4, 0.6}
function ProbTable:__init(P, names)
  -- Default names are the dimension indices 1..nDimension.
  if not names then names = torch.range(1, P:nDimension()):totable() end
  self.names = {}
  self.name2index = {}
  if type(names) == 'string' then
    -- A single string names the only variable.
    self.names = {names}
    self.name2index = {}
    self.name2index[names] = 1
  else
    for _, name in ipairs(names) do
      table.insert(self.names, name)
      self.name2index[name] = #self.names
    end
  end
  self.P = P
end

--- @returns {int} number of variables in the table
function ProbTable:size()
  return self.P:nDimension()
end

--- @returns {torch.Tensor} probabilities for the assignments in `dict`.
-- @arg {table[string=int]} dict - an assignment to consider
--
-- Example:
--
-- @code {lua}
-- local t = ProbTable(torch.Tensor{{0.2, 0.8}, {0.4, 0.6}, {0.1, 0.9}}, {'a', 'b'})
-- t:query{a=1, b=2}
-- t:query{a=2}
--
-- The first query is `0.8`. The second query is `Tensor{0.4, 0.6}`
function ProbTable:query(dict)
  -- Validate every assignment: the name must exist and the value must be a
  -- valid index along that variable's dimension.
  for name, value in pairs(dict) do
    local i = assert(self.name2index[name], name .. ' is not a valid name')
    assert(value > 0 and value <= self.P:size(i), value .. ' is out of range')
  end
  -- Build one index entry per variable; unassigned variables get `{}`.
  local ind = {}
  for i, name in ipairs(self.names) do
    table.insert(ind, dict[name] or {})
  end
  return self.P[ind]
end

--- @returns {ProbTable} a copy
function ProbTable:clone()
  local names = tl.copy(self.names)
  local P = self.P:clone()
  return ProbTable.new(P, names)
end

--- @returns {string} string representation
--
-- One tab-separated header row with the variable names, then one row per
-- combination of indices followed by its entry in `P`.
function ProbTable:__tostring__()
  local s = ''
  local divider = ''
  for i, name in ipairs(self.names) do
    s = s .. name .. '\t'
    divider = divider .. '-' .. '\t'
  end
  s = s .. 'P\n' .. divider .. '-\n'
  -- Replace each dimension size d with the list 1..d, then enumerate every
  -- combination of indices.
  local dims = self.P:size():totable()
  for i, d in ipairs(dims) do
    dims[i] = torch.range(1, d):totable()
  end
  for _, ind in ipairs(table.combinations(dims)) do
    for _, i in ipairs(ind) do
      s = s .. i .. '\t'
    end
    s = s .. self.P[ind] .. '\n'
  end
  return s
end

--- Returns a new table that is the product of two tables.
-- @arg {ProbTable} B - another table
-- @returns {ProbTable} product of this and another table
function ProbTable:mul(B)
  -- allocate new P and name for the new product ProbTable
  local P = self.P:clone()
  local names = tl.copy(self.names)
  local name2index = tl.copy(self.name2index)

  -- the idea is that we will extend the new name order such that
  -- the beginning names are in the exact same order as B.names.
  -- this way B.P[ind] can be multiplied with P[ind] directly.
  -- we also do this because repeatTensor repeats from the beginning dimensions.
  local write = 1 -- This keeps track of the index of the first non-B name
  for i, name in ipairs(B.names) do
    if name2index[name] then
      -- This name is in both A and B, so we swap it to beginning
      -- swap P
      local old_i = name2index[name]
      P = P:transpose(old_i, write)
      -- swap name
      local old_write_name = names[write]
      names[write] = name
      names[old_i] = old_write_name
      -- swap name2index
      name2index[old_write_name] = old_i
      name2index[name] = write
    else
      -- Otherwise this name is in B only, we simply insert it into the correct spot
      table.insert(names, write, name)
      -- The insertion shifted later names, so rebuild the whole index map.
      -- (The inner `i` and `name` shadow the outer loop variables only
      -- within this one-line loop.)
      for i, name in ipairs(names) do name2index[name] = i end
      -- Repeat P along a new leading dimension sized like B's dimension `i`,
      -- then move that dimension to position `write`.
      local sizes = torch.ones(P:nDimension() + 1)
      sizes[1] = B.P:size(i)
      P = P:repeatTensor(table.unpack(sizes:totable())):transpose(1, write)
    end
    write = write + 1
  end
  -- Multiply each slice of P addressed by B's indices by the matching entry of B.
  local dims = B.P:size():totable()
  for i, d in ipairs(dims) do dims[i] = torch.range(1, d):totable() end
  for _, ind in ipairs(table.combinations(dims)) do
    if type(P[ind]) == 'number' then
      -- Fully indexed: a scalar, which has to be written back explicitly.
      P[ind] = P[ind] * B.P[ind]
    else
      -- Partially indexed: a tensor, scaled in place.
      P[ind]:mul(B.P[ind])
    end
  end
  return ProbTable.new(P, names)
end

--- Marginalizes this probability table in place.
-- @arg {string} name - the variable to marginalize
-- @returns {ProbTable} this probability table with the variable `name` marginalized out
function ProbTable:marginalize(name)
  local dim = assert(self.name2index[name], tostring(name) .. ' is not a valid name')
  self.P = self.P:sum(dim):squeeze(dim)
  -- Keep `P` a tensor even when the result came back as a plain number.
  if type(self.P) == 'number' then self.P = torch.Tensor{self.P} end
  self.name2index[name] = nil
  -- Shift the names after `dim` one slot to the left and update their indices.
  for i = dim, #self.names do
    self.names[i] = self.names[i+1]
    if self.names[i+1] then
      self.name2index[self.names[i+1]] = i
    end
  end
  return self
end

--- Marginalizes this probability table in place to calculate a marginal.
-- @arg {string} name - the variable to calculate
-- @returns {ProbTable} this probability table marginalizing all variables except `name`
function ProbTable:marginal(name)
  assert(self.name2index[name], 'Table does not contain variable with name '..name)
  -- Remove one other variable per pass; the scan restarts each time because
  -- `marginalize` renumbers the remaining names.
  while self:size() > 1 do
    for i = 1, self:size() do
      if self.names[i] ~= name then
        self:marginalize(self.names[i])
        break
      end
    end
  end
  return self
end

--- Normalizes this table by dividing by the sum of all probabilities.
-- @returns {ProbTable} normalized table
function ProbTable:normalize()
  self.P:div(self.P:sum())
  return self
end

return ProbTable

--------------------------------------------------------------------------------
/src/ml/Scorer.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'

--- @module Scorer
-- Implementation of a scorer to calculate precision/recall/f1.
local Scorer = torch.class('tl.Scorer', 'tl.Object')

--- Constructor.
--
-- @arg {string=} gold_log - if given, gold labels will be written to this file
-- @arg {string=} pred_log - if given, predicted labels will be written to this file
--
-- Logging is only enabled when both paths are given. An error is raised if either file cannot be opened.
function Scorer:__init(gold_log, pred_log)
  if gold_log and pred_log then
    -- Remember the paths so that `reset` can reopen the very same files.
    self.log_paths = {gold = gold_log, pred = pred_log}
    self.logs = {
      gold = assert(io.open(gold_log, 'w')),
      pred = assert(io.open(pred_log, 'w')),
    }
  end
  self.class2ind = {}
  self.ind2class = {}
  self.pred = {}
  self.gold = {}
end

--- Adds a prediction/ground truth pair to the scorer.
-- @arg {string} gold - ground truth label
-- @arg {string} pred - corresponding predicted label
-- @arg {string=} id - corresponding identifier for this example
--
-- If the scorer was given the gold log and the pred log, then the pair will be written to their respective log files.
-- When `id` is omitted, an empty identifier is written.
function Scorer:add_pred(gold, pred, id)
  if self.logs then
    -- `id` is optional, so do not concatenate it raw (nil would raise).
    local tag = id ~= nil and tostring(id) or ''
    self.logs.gold:write(tag..'\t'..gold..'\n')
    self.logs.pred:write(tag..'\t'..pred..'\n')
  end
  -- Assign the next free index to any class seen for the first time.
  if not self.class2ind[gold] then
    table.insert(self.ind2class, gold)
    self.class2ind[gold] = #self.ind2class
  end
  if not self.class2ind[pred] then
    table.insert(self.ind2class, pred)
    self.class2ind[pred] = #self.ind2class
  end
  table.insert(self.gold, self.class2ind[gold])
  table.insert(self.pred, self.class2ind[pred])
end

--- Removes all remembered statistics from the scorer.
--
-- If logging is enabled, both log files are closed and reopened (truncated) at their original paths.
function Scorer:reset()
  if self.logs then
    for name, f in pairs(self.logs) do
      f:close()
      -- `name` is only the table key ('gold' / 'pred'), not a file name: reopen the path
      -- recorded by the constructor. The `name .. '.log'` fallback covers instances that
      -- were created before `log_paths` was recorded.
      local log_path = self.log_paths and self.log_paths[name] or (name .. '.log')
      self.logs[name] = assert(io.open(log_path, 'w'))
    end
  end
  self.class2ind, self.ind2class, self.pred, self.gold = {}, {}, {}, {}
end

--- Computes the precision/recall/f1 statistics for the current batch of elements.
-- @arg {string=} ignore - if given, `ignore` is taken to be the "negative" class and its statistics will be withheld
-- from the computation.
-- @returns {table, table, table} micro, macro, and class scores
--
-- Example:
--
-- @code
-- local s = Scorer()
-- s:add_pred('a', 'b', 1)
-- s:add_pred('b', 'b', 2)
-- s:add_pred('c', 'a', 3)
-- local micro, macro, class_stats = s:precision_recall_f1()
--
-- @description
-- Returns the following
--
-- - `micro`: precision/recall/f1 computed from the counts summed over the non-ignored classes
--
-- - `macro`: precision and recall averaged over the non-ignored classes, and the f1 of those two averages
--
-- - `class_stats`: the precision/recall/f1 for each class (including the ignored one)
function Scorer:precision_recall_f1(ignore)
  if ignore ~= nil then
    assert(self.class2ind[ignore], 'ignore '..ignore..' is not a valid class')
  end

  -- Builds a precision/recall/f1 table from raw counts. Everything is 0 when no
  -- example is both relevant and retrieved, which also avoids dividing by zero.
  local function prf(hits, n_retrieved, n_relevant)
    if hits == 0 then
      return {precision = 0, recall = 0, f1 = 0}
    end
    local p = hits / n_retrieved
    local r = hits / n_relevant
    return {precision = p, recall = r, f1 = 2 * p * r / (p + r)}
  end

  local pred_t = torch.Tensor(self.pred)
  local gold_t = torch.Tensor(self.gold)

  -- Per-class scores; the raw counts live in a side table and are not returned.
  local all_stats, counts = {}, {}
  for class, ind in pairs(self.class2ind) do
    local is_relevant = gold_t:eq(ind)
    local is_retrieved = pred_t:eq(ind)
    local c = {
      relevant = is_relevant:sum(),
      retrieved = is_retrieved:sum(),
      relevant_and_retrieved = torch.cmul(is_relevant, is_retrieved):sum(),
    }
    counts[class] = c
    all_stats[class] = prf(c.relevant_and_retrieved, c.retrieved, c.relevant)
  end

  local n_classes = #self.ind2class
  if ignore then n_classes = n_classes - 1 end

  -- Aggregate over every class except the ignored one.
  local macro = {precision = 0, recall = 0}
  local total = {relevant = 0, retrieved = 0, relevant_and_retrieved = 0}
  for class, s in pairs(all_stats) do
    if class ~= ignore then
      macro.precision = macro.precision + s.precision / n_classes
      macro.recall = macro.recall + s.recall / n_classes
      local c = counts[class]
      total.relevant = total.relevant + c.relevant
      total.retrieved = total.retrieved + c.retrieved
      total.relevant_and_retrieved = total.relevant_and_retrieved + c.relevant_and_retrieved
    end
  end

  if macro.precision == 0 and macro.recall == 0 then
    macro.f1 = 0
  else
    macro.f1 = 2 * macro.precision * macro.recall / (macro.precision + macro.recall)
  end

  local micro = prf(total.relevant_and_retrieved, total.retrieved, total.relevant)

  return micro, macro, all_stats
end

return Scorer
--------------------------------------------------------------------------------
-- /src/ml/VariableTensor.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'
local Dataset = tl.Dataset

--- @module VariableTensor
-- Implementation of a variable tensor class to efficiently store tensors of varying lengths.
local VariableTensor = torch.class('tl.VariableTensor', 'tl.Object')

--- Constructor.
-- @arg {table=} opt - optional settings
--
-- Options:
--
-- - `preinit_size`: how many index rows to preallocate (default 1)
--
-- - `preinit_store_size`: how many storage elements to preallocate (default 1)
function VariableTensor:__init(opt)
  opt = opt or {}
  local n_rows = opt.preinit_size or 1
  local n_store = opt.preinit_store_size or 1

  self.indices = torch.LongTensor(n_rows, 2) -- one (start, end) row per stored tensor
  self.store = torch.Tensor(n_store)
  self.write_head = 1        -- next free row in `indices`
  self.store_write_head = 1  -- next free element in `store`
end

--- Moves the storage to cuda
-- @returns {VariableTensor} modified tensor
function VariableTensor:cuda()
  self.store = self.store:cuda()
  return self
end

--- @returns {int} number of tensors that have been pushed
function VariableTensor:size()
  return self.write_head - 1
end

--- Appends a tensor to the storage.
-- @arg {torch.Tensor} tensor - tensor to add to storage
-- @returns {VariableTensor} modified tensor
function VariableTensor:push(tensor)
  local first = self.store_write_head
  local last = first + tensor:nElement() - 1

  -- Double the element store until the new tensor fits.
  while last > self.store:size(1) do
    self.store:resize(self.store:size(1) * 2)
  end
  -- Double the index table until there is a free row.
  while self.write_head > self.indices:size(1) do
    self.indices:resize(self.indices:size(1) * 2, 2)
  end

  self.store[{{first, last}}] = tensor
  local row = self.indices[self.write_head]
  row[1] = first
  row[2] = last

  self.store_write_head = last + 1
  self.write_head = self.write_head + 1
  return self
end

--- Shuffles the indices.
-- @arg {torch.Tensor=} indices - tensor that denotes how the new indices should be set.
-- If not given, then a random tensor will be generated
-- @returns {torch.Tensor} the `indices` tensor used to shuffle
function VariableTensor:shuffle(indices)
  local n = self:size()
  indices = indices or torch.randperm(n):long()
  -- Only the first `n` rows are in use; reorder just those.
  local used_rows = self.indices[{{1, n}}]
  self.indices[{{1, n}}] = used_rows:index(1, indices)
  return indices
end

--- Retrieves the tensor at index `i`.
-- @arg {int} i - index to query
-- @returns {torch.Tensor} tensor at index
function VariableTensor:get(i)
  local span = self.indices[i]
  return self.store[{{span[1], span[2]}}]
end

--- Creates a padded batch from the tensors at the given positions.
-- @arg {table} indices - positions of the tensors to collect
-- @arg {int=} pad - padding value, forwarded to `Dataset.pad`
-- @returns whatever `Dataset.pad` returns for the collected tensors
function VariableTensor:batch(indices, pad)
  local tensors = {}
  for _, i in ipairs(indices) do
    tensors[#tensors + 1] = self:get(i)
  end
  return Dataset.pad(tensors, pad)
end

return VariableTensor
--------------------------------------------------------------------------------
-- /src/ml/Vocab.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'

--- @module Vocab
-- Implementation of vocabulary
local Vocab, parent = torch.class("tl.Vocab", 'tl.Object')

--- Constructor.
-- @arg {string='UNK'} unk - the symbol for the unknown token.
--
-- The unknown token is always the first entry of the vocabulary and starts with a count of 0.
function Vocab:__init(unk)
  self.unk = unk or 'UNK'
  self.index2word = {self.unk}
  self.word2index = {[self.unk] = 1}
  self.counter = {[self.unk] = 0}
end

--- @returns {string} string representation
function Vocab:__tostring__()
  local head = parent.__tostring__(self)
  return head.."("..self:size()..' words, unk='..self.unk..")"
end

--- @returns {boolean} whether `word` is in the vocabulary
-- @arg {string} word - word to query
function Vocab:contains(word)
  local index = self.word2index[word]
  return index ~= nil
end

--- @returns {int} count recorded for `word`
-- @arg {string} word - word to query
--
-- An error is raised if `word` is not in the vocabulary.
function Vocab:count(word)
  assert(self:contains(word), 'Error: attempted to get count of word ' .. word .. ' which is not in the vocabulary')
  return self.counter[word]
end

--- @returns {int} how many distinct tokens are in the vocabulary
function Vocab:size()
  return #self.index2word
end

--- Adds `word` `count` times to the vocabulary.
-- @arg {string} word - word to add
-- @arg {int=1} count - number of times to add
-- @returns {int} index of `word`
function Vocab:add(word, count)
  count = count or 1
  if not self:contains(word) then
    -- New word: record its count and give it the next free index.
    self.counter[word] = count
    table.insert(self.index2word, word)
    self.word2index[word] = self:size()
  else
    self.counter[word] = self:count(word) + count
  end
  return self.word2index[word]
end

--- @returns {int} index of `word`.
-- @arg {string} word - word to query
-- @arg {boolean=} add - whether to add new word to the vocabulary
--
-- Every lookup also increments the count of the word that is returned.
--
-- If the word is not found, then one of the following occurs:
--
-- - if `add` is `true`, then `word` is added to the vocabulary with count 1 and the new index returned
--
-- - otherwise, the index of the unknown token is returned
--
-- Example:
--
-- Suppose we have a vocabulary of words 'unk', 'foo' and 'bar'
--
-- @code {lua}
-- vocab:indexOf('foo') -- returns 2
-- vocab:indexOf('bar') -- returns 3
-- vocab:indexOf('hello') -- returns 1 corresponding to `unk` because `hello` is not in the vocabulary
-- vocab:indexOf('hello', true) -- returns 4 because `hello` is added to the vocabulary
function Vocab:indexOf(word, add)
  add = add or false
  if add then
    return self:add(word, 1)
  end
  if not self:contains(word) then
    -- NOTE(review): the recursive lookup of `unk` below increments its count a second
    -- time, so each unknown word adds 2 to the unk count -- confirm this is intended.
    self.counter[self.unk] = self:count(self.unk) + 1
    return self:indexOf(self.unk)
  end
  self.counter[word] = self:count(word) + 1
  return self.word2index[word]
end

--- @returns {string} word at index `index`
-- @arg {int} index - the index to query
--
-- If `index` is out of bounds (smaller than 1 or larger than the vocabulary size) then an error will be raised.
--
-- Example:
--
-- Suppose we have a vocabulary with words 'unk', 'foo', and 'bar'
--
-- @code {lua}
-- vocab:wordAt(1) -- unk
-- vocab:wordAt(2) -- foo
-- vocab:wordAt(4) -- raises an error because there is no 4th word in the vocabulary
-- vocab:wordAt(0) -- raises an error because indices start at 1
function Vocab:wordAt(index)
  -- Lower bound: without this check index 0 or a negative index silently returned nil.
  assert(type(index) == 'number' and index >= 1,
    'Error: attempted to get word at index ' .. tostring(index) .. ' which is not a positive number')
  assert(index <= self:size(), 'Error: attempted to get word at index ' .. index .. ' which exceeds the vocab size')
  return self.index2word[index]
end

--- `indexOf` on a table of words.
--
-- @arg {table[string]} words - words to query
-- @arg {boolean=} add - whether to add new words to the vocabulary
-- @returns {table[int]} corresponding indices.
--
-- Example:
--
-- Suppose we have a vocabulary with words 'unk', 'foo', and 'bar'
--
-- @code {lua}
-- vocab:indicesOf{'foo', 'bar'} -- {2, 3}
function Vocab:indicesOf(words, add)
  add = add or false
  -- Must be local: a bare assignment here would create/overwrite a global named `indices`.
  local indices = {}
  for _, word in ipairs(words) do
    table.insert(indices, self:indexOf(word, add))
  end
  return indices
end

--- `indexOf` on a table of words.
--
-- @arg {table[string]} words - words to query
-- @arg {boolean=} add - whether to add new words to the vocabulary
-- @returns {torch.Tensor} tensor of corresponding indices
--
-- Example:
--
-- Suppose we have a vocabulary with words 'unk', 'foo', and 'bar'
--
-- @code {lua}
-- vocab:tensorIndicesOf{'foo', 'bar'} -- torch.Tensor{2, 3}
-- vocab:tensorIndicesOf{'foo', 'hi'} -- torch.Tensor{2, 1}, because `hi` is not in the vocabulary
function Vocab:tensorIndicesOf(words, add)
  add = add or false
  local indices = torch.Tensor(#words)
  for i, word in ipairs(words) do
    indices[i] = self:indexOf(word, add)
  end
  return indices
end

--- `wordAt` on a table of indices.
--
-- @arg {table[int]} indices - indices to query
-- @returns {table[string]} corresponding words
--
-- Example:
--
-- Suppose we have a vocabulary with words 'unk', 'foo', and 'bar'
--
-- @code {lua}
-- vocab:wordsAt{1, 3} -- {'unk', 'bar'}
-- vocab:wordsAt{1, 4} -- raises an error because there is no 4th word
function Vocab:wordsAt(indices)
  local result = {}
  for _, index in ipairs(indices) do
    result[#result + 1] = self:wordAt(index)
  end
  return result
end

--- `wordAt` on a tensor of indices.
--
-- @arg {torch.Tensor} indices - tensor of indices to query, read along its first dimension
-- @returns {table[string]} corresponding words
--
-- Example:
--
-- Suppose we have a vocabulary with words 'unk', 'foo', and 'bar'
--
-- @code {lua}
-- vocab:tensorWordsAt(torch.Tensor{1, 3}) -- {'unk', 'bar'}
-- vocab:tensorWordsAt(torch.Tensor{1, 4}) -- raises an error because there is no 4th word
function Vocab:tensorWordsAt(indices)
  local result = {}
  local n = indices:size(1)
  for pos = 1, n do
    result[#result + 1] = self:wordAt(indices[pos])
  end
  return result
end

--- Returns a new vocabulary with words occurring less than `cutoff` times removed.
--
-- @arg {int} cutoff - words with frequency below this number will be left out of the copy
-- @returns {Vocab} a new, pruned vocabulary; this vocabulary is not modified
--
-- The unknown token is always kept, whatever its count.
--
-- Example:
--
-- Suppose we want to forget all words that occurred less than 5 times:
--
-- @code {lua}
-- smaller_vocab = orig_vocab:copyAndPruneRares(5)
function Vocab:copyAndPruneRares(cutoff)
  local pruned = self.new(self.unk)
  for _, word in ipairs(self.index2word) do
    local freq = self:count(word)
    local keep = freq >= cutoff or word == self.unk
    if keep then
      pruned:add(word, freq)
    end
  end
  return pruned
end

return Vocab
--------------------------------------------------------------------------------
-- /src/set/Set.lua:
--------------------------------------------------------------------------------
--- @module Set
-- Implementation of set.

local torch = require 'torch'
local Set, parent = torch.class('tl.Set', 'tl.Object')

--- Constructor.
-- @arg {table[any]=} values - used to initialize the set
function Set:__init(values)
  self._map = {}
  self._size = 0
  for _, v in ipairs(values or {}) do
    self:add(v)
  end
end

--- @arg {any} val - value to produce a key for
-- @returns {any} the value itself for numbers, strings and nil; `torch.pointer(val)` for everything else
function Set.keyOf(val)
  local kind = torch.type(val)
  if kind == 'number' or kind == 'nil' or kind == 'string' then
    return val
  end
  return torch.pointer(val)
end

--- @returns {int} number of values in the set
function Set:size()
  return self._size
end

--- Adds a value to the set.
-- @arg {any} val - value to add to the set
-- @returns {Set} modified set
function Set:add(val)
  local is_new = not self:contains(val)
  if is_new then
    self._size = self._size + 1
  end
  self._map[Set.keyOf(val)] = val
  return self
end

--- Adds a variable number of values to the set.
-- @arg {vararg} vararg - values to add to the set
-- @returns {Set} modified set
function Set:addMany(...)
  -- ipairs stops at the first nil argument.
  for _, val in ipairs({...}) do
    self:add(val)
  end
  return self
end

--- @returns {Set} copy of the set
function Set:copy()
  local values = self:totable()
  return Set.new(values)
end

--- @returns {boolean} whether the set contains `val`
-- @arg {any} val - value to check for
function Set:contains(val)
  return self._map[Set.keyOf(val)] ~= nil
end

--- @arg {any} val - value to remove from the set.
-- @returns {Set} modified set
-- If `val` is not found then an error is raised.
function Set:remove(val)
  assert(self:contains(val), 'Error: value ' .. tostring(val) .. ' not found in Set')
  self._map[Set.keyOf(val)] = nil
  self._size = self._size - 1
  return self
end

--- @returns {table} the values of the set as an array, in unspecified order
function Set:totable()
  local values = {}
  for _, v in pairs(self._map) do
    values[#values + 1] = v
  end
  return values
end

--- Compares two sets.
-- @arg {Set} another - another set
-- @returns {boolean} whether this set and `another` contain the same values
function Set:equals(another)
  if self:size() ~= another:size() then
    return false
  end
  -- Same size, so it is enough that every value of `another` is in this set.
  for _, v in ipairs(another:totable()) do
    if not self:contains(v) then
      return false
    end
  end
  return true
end

--- Computes the union of two sets.
-- @arg {Set} another - another set
-- @returns {Set} a new set of values that are in this set or in `another`
function Set:union(another)
  local s = self:copy()
  for _, v in ipairs(another:totable()) do
    s:add(v)
  end
  return s
end

--- Computes the intersection of two sets.
-- @arg {Set} another - another set
-- @returns {Set} a new set of values that are in this set and in `another`
function Set:intersect(another)
  local s = self:copy()
  for _, v in ipairs(self:totable()) do
    if not another:contains(v) then
      s:remove(v)
    end
  end
  return s
end

--- Subtracts another set from this one.
-- @arg {Set} another - another set
-- @returns {Set} a new set of values that are in this set but not in `another`
function Set:subtract(another)
  local s = self:copy()
  for _, v in ipairs(self:totable()) do
    if another:contains(v) then
      s:remove(v)
    end
  end
  return s
end

--- @returns {string} string representation
--
-- At most 5 values are printed; a larger set is abbreviated with a trailing `...`.
function Set:__tostring__()
  local max = 5
  local values = self:totable()
  local shown = {}
  -- Everything here is local; a bare `key = ...` would leak a global on every call.
  for i = 1, math.min(self:size(), max) do
    shown[i] = tostring(values[i])
  end
  local s = parent.__tostring__(self) .. '(' .. table.concat(shown, ', ')
  if self:size() > max then s = s .. ', ...' end
  return s .. ')'
end

return Set
--------------------------------------------------------------------------------
-- /src/tree/BinarySearchTree.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'

--- @module BinarySearchTree.Node
-- A node in the binary search tree.
-- This is a subclass of `BinaryTree.Node`.
local BinarySearchTree, parent = torch.class('tl.BinarySearchTree', 'tl.BinaryTree')
local BinarySearchTreeNode, parent = torch.class('tl.BinarySearchTree.Node', 'tl.BinaryTree.Node')

--- Searches for a key in the subtree rooted at this node.
--
-- @arg {number} key - the key to retrieve
-- @returns {BinarySearchTree.Node} the node with the requested key, or `nil` if there is none
function BinarySearchTreeNode:search(key)
  local node = self
  while node ~= nil and node.key ~= key do
    if key < node.key then
      node = node.left
    else
      node = node.right
    end
  end
  return node
end

--- @returns {BinarySearchTree.Node} the node with the minimum key in the subtree rooted at this node.
function BinarySearchTreeNode:min()
  local node = self
  while node.left ~= nil do
    node = node.left
  end
  return node
end

--- @returns {BinarySearchTree.Node} the node with the maximum key in the subtree rooted at this node.
function BinarySearchTreeNode:max()
  local node = self
  while node.right ~= nil do
    node = node.right
  end
  return node
end

--- @returns {BinarySearchTree.Node} the node with the smallest key that is larger than this one, or `nil` if there is none.
function BinarySearchTreeNode:successor()
  if self.right ~= nil then
    return self.right:min()
  end
  -- No right subtree: climb until we arrive at an ancestor from its left side.
  local child, ancestor = self, self.parent
  while ancestor ~= nil and ancestor.right == child do
    child, ancestor = ancestor, ancestor.parent
  end
  return ancestor
end

--- @returns {BinarySearchTree.Node} the node with the largest key that is smaller than this one.
function BinarySearchTreeNode:predecessor()
  if self.left ~= nil then
    return self.left:max()
  end
  -- No left subtree: climb until we arrive at an ancestor from its right side.
  local child, ancestor = self, self.parent
  while ancestor ~= nil and ancestor.left == child do
    child, ancestor = ancestor, ancestor.parent
  end
  return ancestor
end

--- @module BinarySearchTree
-- Binary Search Tree. An implementation of `BinaryTree`.
--
-- Example:
--
-- @code {lua}
-- local t = BinarySearchTree.new()
-- t:insert(BinarySearchTreeNode.new(12))
-- t:insert(BinarySearchTreeNode.new(5))
-- t:insert(BinarySearchTreeNode.new(2))
-- t:insert(BinarySearchTreeNode.new(9))
-- t:insert(BinarySearchTreeNode.new(18))
-- t:insert(BinarySearchTreeNode.new(15))
-- t:insert(BinarySearchTreeNode.new(13))
-- t:insert(BinarySearchTreeNode.new(17))
-- t:insert(BinarySearchTreeNode.new(19))
-- print(t)


--- Inserts a node into the tree.
-- @arg {BinarySearchTree.Node} node - node to insert
-- @returns {BinarySearchTree} modified tree
--
-- A key equal to an existing key is placed in that node's right subtree.
function BinarySearchTree:insert(node)
  -- Walk down from the root; `anchor` trails one step behind `cursor`.
  local anchor, cursor = nil, self.root
  while cursor ~= nil do
    anchor = cursor
    if node.key < cursor.key then
      cursor = cursor.left
    else
      cursor = cursor.right
    end
  end

  node.parent = anchor
  if anchor == nil then
    self.root = node -- tree was empty
  elseif node.key < anchor.key then
    anchor.left = node
  else
    anchor.right = node
  end

  self._size = self._size + 1
  return self
end

--- @arg {number} key - key to search for.
-- @returns {BinarySearchTree.Node} node with the requested key, or `nil` if the key is absent or the tree is empty
function BinarySearchTree:search(key)
  -- An empty tree has no root to delegate to.
  if self.root == nil then return nil end
  return self.root:search(key)
end

--- @returns {BinarySearchTree.Node} node with the minimum key, or `nil` if the tree is empty
function BinarySearchTree:min()
  if self.root == nil then return nil end
  return self.root:min()
end

--- @returns {BinarySearchTree.Node} node with the maximum key, or `nil` if the tree is empty
function BinarySearchTree:max()
  if self.root == nil then return nil end
  return self.root:max()
end

--- Replaces the subtree rooted at `old` with the one rooted at `new`.
-- @arg {BinarySearchTree.Node} old - node to replace
-- @arg {BinarySearchTree.Node} new - new node to use, may be `nil` to simply detach `old`
-- @returns {BinarySearchTree} modified tree
--
-- Only the link from `old`'s parent (or the root pointer) and `new.parent` are updated;
-- the children of `old` and `new` are left untouched.
function BinarySearchTree:transplant(old, new)
  if old == self.root then
    self.root = new
  elseif old == old.parent.left then
    old.parent.left = new
  else
    old.parent.right = new
  end
  if new ~= nil then
    new.parent = old.parent
  end
  return self
end

--- Deletes a node from the tree.
-- @arg {BinarySearchTree.Node} node - node to delete
-- @returns {BinarySearchTree} modified tree
function BinarySearchTree:delete(node)
  if node.left == nil then
    -- At most a right child: splice it into node's place.
    -- Must be called on `self`, not on the class, so the root check sees this tree's root.
    self:transplant(node, node.right)
  elseif node.right == nil then
    -- Only a left child.
    self:transplant(node, node.left)
  else
    -- Two children: the successor (minimum of the right subtree) takes node's place.
    local rightSubtreeMin = node.right:min()
    if rightSubtreeMin.parent ~= node then
      -- has two children and successor is not its right child
      self:transplant(rightSubtreeMin, rightSubtreeMin.right)
      rightSubtreeMin.right = node.right
      rightSubtreeMin.right.parent = rightSubtreeMin
    end
    self:transplant(node, rightSubtreeMin)
    rightSubtreeMin.left = node.left
    rightSubtreeMin.left.parent = rightSubtreeMin
  end
  self._size = self._size - 1
  return self
end

return BinarySearchTree
--------------------------------------------------------------------------------
-- /src/tree/BinaryTree.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'

--- @module BinaryTree.Node
-- Node in a binary tree.
-- This is a subclass of `Tree.Node`
local BinaryTree = torch.class('tl.BinaryTree', 'tl.Tree')
local BinaryTreeNode, parent = torch.class('tl.BinaryTree.Node', 'tl.Tree.Node')

--- Constructor
-- @arg {any} key - key of the node
-- @arg {any=} val - value of the node, forwarded to the `Tree.Node` constructor
function BinaryTreeNode:__init(key, val)
  parent.__init(self, key, val)
  self.left = nil
  self.right = nil
end

--- @returns {table} children of this node, left child first; absent children are skipped
function BinaryTreeNode:children()
  local tab = {}
  if self.left ~= nil then table.insert(tab, self.left) end
  if self.right ~= nil then table.insert(tab, self.right) end
  return tab
end

--- Traverses the tree in order.
-- @arg {function=} callback - function to execute at each node
function BinaryTreeNode:walkInOrder(callback)
  callback = callback or function(node) end
  if self.left ~= nil then
    self.left:walkInOrder(callback)
  end
  callback(self)
  if self.right ~= nil then
    self.right:walkInOrder(callback)
  end
end


--- @module BinaryTree
-- Implementation of binary tree.
-- This is a subclass of `Tree`.

--- Constructor.
function BinaryTree:__init()
  self.root = nil
  self._size = 0
end

--- Traverses the binary tree starting from the root in order
-- @arg {function=} callback - function to execute at each node
function BinaryTree:walkInOrder(callback)
  if self.root ~= nil then
    self.root:walkInOrder(callback)
  end
end

return BinaryTree
--------------------------------------------------------------------------------
-- /src/tree/Tree.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'

--- @module Tree
-- Implementation of tree.
local Tree = torch.class('tl.Tree')
local TreeNode = torch.class('tl.Tree.Node')

--- Constructor.
-- @arg {any} key - key of the node
-- @arg {any=} val - value of the node; defaults to `key`
function TreeNode:__init(key, val)
  if val == nil then val = key end
  self.parent = nil
  self.key = key
  self.val = val
  self._size = 0
end

--- @returns {table} children of this node
--
-- Must be implemented by subclasses; the base implementation raises an error.
function TreeNode:children()
  error('not implemented')
end

--- @returns {string} string representation
function TreeNode:__tostring__()
  return torch.type(self) .. '<' .. tostring(self.val) .. '(' .. tostring(self.key) .. ')' .. '>'
end

--- @returns {string} string representation of the subtree rooted at this node, one node per line
-- @arg {string=} prefix - string to add before each line
-- @arg {boolean=} isLeaf - whether this node is the last child of its parent (defaults to `true`)
function TreeNode:subtreeToString(prefix, isLeaf)
  prefix = prefix or ''
  -- `isLeaf or true` would turn an explicit `false` into `true`, so test for nil instead.
  if isLeaf == nil then isLeaf = true end
  local s = prefix
  if isLeaf then s = s .. '|__ ' else s = s .. '|-- ' end
  s = s .. tostring(self) .. "\n"
  local newPrefix = prefix
  if isLeaf then newPrefix = newPrefix .. ' ' else newPrefix = newPrefix .. '| ' end
  local children = self:children()
  -- Each child is rendered exactly once; only the last one is drawn as the closing branch.
  for i = 1, #children do
    s = s .. children[i]:subtreeToString(newPrefix, i == #children)
  end
  return s
end

--- @returns {string} string representation
function Tree:__tostring__()
  local s = torch.type(self)
  if self.root ~= nil then
    s = self.root:subtreeToString()
  end
  return s
end

--- @returns {int} number of nodes in the tree
function Tree:size()
  return self._size
end
--------------------------------------------------------------------------------
-- /src/util/Download.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'
local path = require 'pl.path'

--- @module Downloader
-- A download utility with caching support.
local Downloader = torch.class('tl.Downloader')

--- Constructor.
--
-- @arg {string='/tmp/torchlib'} cache - cache directory
-- @arg {table[string:any]=} opt - options
--
-- Options:
--
-- - `verbose`: prints out progress
function Downloader:__init(cache, opt)
  opt = opt or {}
  self.cache = cache or '/tmp/torchlib'
  self.verbose = opt.verbose
  if not path.exists(self.cache) then
    if self.verbose then print('WARNING: Making cache directory at '..self.cache) end
    path.mkdir(self.cache)
  end
end

-- Wraps `s` in single quotes for a POSIX shell, escaping embedded single quotes,
-- so that spaces, `&`, `;` and the like in paths or URLs are passed through literally.
local function shell_quote(s)
  local escaped = tostring(s):gsub("'", "'\\''")
  return "'" .. escaped .. "'"
end

--- Retrieves a file from cache, downloading it from `url` if it doesn't exist.
-- @arg {string} to - location to download to, relative to the cache directory
-- @arg {string=} url - url to download from; required when the file is not cached or `force` is set
-- @arg {table[string:any]=} opt - options
-- @returns {string} path of the file inside the cache directory
--
-- Options:
--
-- - `force`: overwrite the file if one exists.
--
-- The download is done by running `wget`. An error is raised if the file cannot be opened afterwards.
function Downloader:get(to, url, opt)
  opt = opt or {}
  local to_path = path.join(self.cache, to)
  if (not path.exists(to_path)) or opt.force then
    assert(url, 'Error: ' .. to_path .. ' is not in the cache and no url was given')
    -- Both arguments are quoted: the command goes through the shell.
    os.execute('wget -O ' .. shell_quote(to_path) .. ' ' .. shell_quote(url))
    -- Check that the download left a readable file behind.
    local f = assert(io.open(to_path, 'rb'))
    f:close()
  end
  return to_path
end

return Downloader
--------------------------------------------------------------------------------
-- /src/util/global.lua:
--------------------------------------------------------------------------------
--- @arg {int} from - start index; when `to` is omitted this is the end index and the range starts at 1
-- @arg {int=} to - end index (inclusive)
-- @arg {int=1} inc - value to increment by
-- @returns {table} indices from `from` to `to`, incrementing by `inc`
function tl.range(from, to, inc)
  inc = inc or 1
  if to == nil then
    to = from
    from = 1
  end

  local t = {}
  for i = from, to, inc do
    table.insert(t, i)
  end
  return t
end

--- @arg {any} a - first object
-- @arg {any} b - second object
-- @returns {boolean} whether the two objects are equal to each other
--
-- Objects of different `torch.type` are never equal. Tables that both provide an `equals`
-- method are compared with it; everything else is compared with `==`.
function tl.equals(a, b)
  if torch.type(a) ~= torch.type(b) then return false end
  if type(a) == 'table' and a.equals ~= nil and type(b) == 'table' and b.equals ~= nil then
    return a:equals(b)
  end
  return a == b
end

--- @arg {any} t - object to copy
-- @returns {any} deep copy
--
-- Table values are copied recursively and the metatable is shared with the original.
-- Keys are not copied, and there is no protection against cyclic references.
--
-- from https://gist.github.com/MihailJP/3931841
function tl.deepcopy(t)
  if type(t) ~= "table" then return t end
  local meta = getmetatable(t)
  local target = {}
  for k, v in pairs(t) do
    if type(v) == "table" then
      target[k] = tl.deepcopy(v)
    else
      target[k] = v
    end
  end
  setmetatable(target, meta)
  return target
end

--- @arg {table} t - table to copy
-- @returns {table} shallow copy (the metatable is not carried over)
function tl.copy(t)
  local tab = {}
  for k, v in pairs(t) do
    tab[k] = v
  end
  return tab
end
--------------------------------------------------------------------------------
-- /src/util/string.lua:
--------------------------------------------------------------------------------
--- @arg {string} s - larger string
-- @arg {string} substring - smaller string
-- @returns {boolean} whether the larger string starts with the smaller string
function string.startswith(s, substring)
  return string.sub(s, 1, string.len(substring)) == substring
end

--- @arg {string} s - larger string
-- @arg {string} substring - smaller string
-- @returns {boolean} whether the larger string ends with the smaller string
function string.endswith(s, substring)
  -- `string.sub(s, -0)` yields the whole string, so the empty suffix has to be
  -- handled on its own: every string ends with ''.
  if substring == '' then return true end
  return string.sub(s, -string.len(substring)) == substring
end

--------------------------------------------------------------------------------
-- /src/util/table.lua:
--------------------------------------------------------------------------------
--- @arg {table} t - a table
-- @arg {int=0} indent - nesting depth; nested keys are indented by one space per level
-- @arg {string=''} s - accumulated string
-- @returns {string} string representation for the table
function table.tostring(t, indent, s)
  indent = indent or 0
  s = s or ''
  for k, v in pairs(t) do
    -- `tostring(k)` keeps boolean/table keys from raising a concatenation error.
    local formatting = string.rep(" ", indent) .. tostring(k) .. ": "
    if type(v) == "table" then
      s = s .. formatting .. '\n'
      s = table.tostring(v, indent+1, s)
    else
      s = s .. formatting .. tostring(v) .. '\n'
    end
  end
  return s
end

--- @arg {table} t - table to shuffle in place
-- @returns {table} shuffled table
--
-- Walks from the last index down to 2, swapping each element with one at a
-- random index in `1..i` drawn from `math.random`.
function table.shuffle(t)
  local iter = #t
  local j
  for i = iter, 2, -1 do
    j = math.random(i)
    t[i], t[j] = t[j], t[i] -- swap
  end
  return t
end

--- @arg {table[any]} t1 - first table
-- @arg {table[any]} t2 - second table
-- @returns {boolean} whether the keys and values of each table are equal
function table.equals(t1, t2)
  for k1, v1 in pairs(t1) do
    if not tl.equals(t2[k1], v1) then
      return false
    end
  end
  for k2, v2 in pairs(t2) do
    if not tl.equals(t1[k2], v2) then
      return false
    end
  end
  return true
end

--- @arg {table[any]} t1 - first table
-- @arg {table[any]} t2 - second table
-- @returns {boolean} whether the values of each table are equal, disregarding order
--
-- Only membership is checked in both directions, so repeated values are not counted.
function table.valuesEqual(t1, t2)
  for _, v1 in pairs(t1) do
    if not table.contains(t2, v1) then
      return false
    end
  end
  for _, v2 in pairs(t2) do
    if not table.contains(t1, v2) then
      return false
    end
  end
  return true
end

--- @arg {table} t - table to reverse
-- @returns {table} A copy of the table, reversed.
--
-- Only the sequence part visited by `ipairs` is copied.
function table.reverse(t)
  local tab = {}
  for i, e in ipairs(t) do
    tab[i] = e
  end
  -- Swap from both ends towards the middle; this is linear, whereas inserting
  -- every element at position 1 shifts the whole table each time.
  local lo, hi = 1, #tab
  while lo < hi do
    tab[lo], tab[hi] = tab[hi], tab[lo]
    lo, hi = lo + 1, hi - 1
  end
  return tab
end

--- @arg {table} t - table to check
-- @arg {any} val - value to check
-- @returns {boolean} whether the table contains the value
function table.contains(t, val)
  for _, v in pairs(t) do
    if tl.equals(v, val) then
      return true
    end
  end
  return false
end

--- Flattens the table.
-- @arg {table} t - the table to modify
-- @arg {table=} tab - where to store the results. If not given, then a new table will be used.
-- @arg {string=''} prefix - string prepended to every key of `t`. Nested keys are always joined with `'__'`.
-- @returns {table} flattened table
function table.flatten(t, tab, prefix)
  tab = tab or {}
  prefix = prefix or ''
  for k, v in pairs(t) do
    if type(v) == 'table' then
      table.flatten(v, tab, prefix..k..'__')
    else
      tab[prefix..k] = v
    end
  end
  return tab
end

--- Applies `callback` to each element in `t` and returns the results in another table.
-- @arg {table} t - the table to modify
-- @arg {function} callback - function to apply
-- @returns {table} modified table
function table.map(t, callback)
  local results = {}
  for k, v in pairs(t) do
    results[k] = callback(v)
  end
  return results
end

--- Selects items from table `t`.
-- @arg {table} t - table to select from
-- @arg {table} keys - table of keys
-- @arg {boolean=} forget_keys - whether to discard the keys
-- @returns {table} a table of key value pairs where the keys are `keys` and the values are corresponding values from `t`.
--
-- If `forget_keys` is `true`, then the returned table will have integer keys.
function table.select(t, keys, forget_keys)
  local results = {}
  for _, k in ipairs(keys) do
    if forget_keys then
      table.insert(results, t[k])
    else
      results[k] = t[k]
    end
  end
  return results
end

--- Extends the table `t` with another table `another`
-- @arg {table} t - first table
-- @arg {table} another - second table
-- @returns {table} modified first table
function table.extend(t, another)
  for _, v in ipairs(another) do
    table.insert(t, v)
  end
  return t
end

--- Returns all combinations of elements in a table.
--
-- @arg {table[table[any]]} input - a collection of lists to compute the combination for
-- @returns {table[table[any]]} combinations of the input
--
-- Example:
--
-- @code
-- table.combinations{{1, 2}, {'a', 'b', 'c'}}
--
-- This returns `{{1, 'a'}, {2, 'a'}, {1, 'b'}, {2, 'b'}, {1, 'c'}, {2, 'c'}}`
-- (the first list varies fastest). An empty `input` returns `{{}}`.
function table.combinations(input)
  local result = {}
  local recurse
  -- Walks the lists from last to first, prepending one element of the current
  -- list to the varargs accumulated so far.
  recurse = function(tab, idx, ...)
    if idx < 1 then
      -- `{...}` builds the same table as `table.pack(...)` with its `n` field
      -- removed, and is also available on Lua 5.1/LuaJIT where `table.pack`
      -- may be missing.
      table.insert(result, {...})
    else
      local t = tab[idx]
      for i = 1, #t do recurse(tab, idx-1, t[i], ...) end
    end
  end

  recurse(input, #input)
  return result
end

--------------------------------------------------------------------------------
-- /test.lua:
--------------------------------------------------------------------------------
#!/usr/bin/env th
require('test/test_counter')
require('test/test_dataset')
require('test/test_download')
require('test/test_graph')
require('test/test_heap')
require('test/test_list')
require('test/test_map')
require('test/test_model')
require('test/test_prob_table')
require('test/test_queue')
require('test/test_set')
require('test/test_stack')
require('test/test_tree')
require('test/test_variable_tensor')
require('test/test_vocab')
require('test/test_scorer')
require('test/test_util')

--------------------------------------------------------------------------------
-- /test/mock/conll.mock:
--------------------------------------------------------------------------------
# name grade
class1
Adam A
Bob B
Carol C

class2
Zack Z
Yasmin Y
Xavier X

class3
Wesley W
Victor V

--------------------------------------------------------------------------------
-- /test/test_counter.lua:
--------------------------------------------------------------------------------
local Counter = require('torchlib').Counter

local TestCounter = torch.TestSuite()
local tester = torch.Tester()

function TestCounter.test_get()
  local c = Counter()
  tester:asserteq(0, c:get('foo'))
  c.counts['foo'] = 10
  tester:asserteq(10, c:get('foo'))
end

function TestCounter.test_add()
  local c = Counter()
  tester:asserteq(1, c:add('foo'))
  tester:asserteq(1, c.counts.foo)
  tester:asserteq(3, c:add('foo', 2))
  tester:asserteq(3, c.counts.foo)
end

function TestCounter.test_reset()
  local c = Counter()
  c.counts.foo = 10
  c:reset()
  tester:asserteq(nil, c.counts.foo)
end

tester:add(TestCounter)
tester:run()

--------------------------------------------------------------------------------
-- /test/test_dataset.lua:
--------------------------------------------------------------------------------
local Dataset = require('torchlib').Dataset

local TestDataset = torch.TestSuite()
local tester = torch.Tester()

local eps = 1e-5

local torch = require 'torch'
local T = torch.Tensor

-- seed the shuffling
torch.manualSeed(12)

function TestDataset.test_conll()
  local d = Dataset.from_conll('test/mock/conll.mock')
  tester:asserteq(3, d:size())
  tester:asserteq(3, #d.name)
  tester:assertTableEq({'Adam', 'Bob', 'Carol'}, d.name[1])
  tester:assertTableEq({'A', 'B', 'C'}, d.grade[1])
  tester:assertTableEq({'Zack', 'Yasmin', 'Xavier'}, d.name[2])
  tester:assertTableEq({'Z', 'Y', 'X'}, d.grade[2])
  tester:assertTableEq({'Wesley', 'Victor'}, d.name[3])
  tester:assertTableEq({'W', 'V'}, d.grade[3])
  tester:assertTableEq({'class1', 'class2', 'class3'}, d.label)
end

-- shared fixture; `toyDataset` hands every test its own shallow copy
local X = {T{1, 2, 3}, T{2, 3}, T{1, 3, 5, 2}}
local Y = {2, 5, 1}

local toyDataset = function()
  return Dataset{X=tl.copy(X), Y=tl.copy(Y)}
end

function TestDataset.test_shuffle()
  local x = {1, 2, 3, 4, 5, 6, 7}
  local d = Dataset{x=x}
  d:shuffle()
  tester:assertTableNe(x, d.x)
  for _, t in ipairs(x) do
    tester:assert(table.contains(d.x, t))
  end
end

function TestDataset.test_pad()
  local padded = Dataset.pad(X)
  tester:assertTableEq(T{0, 1, 2, 3}:totable(), padded[1]:totable())
  tester:assertTableEq(T{0, 0, 2, 3}:totable(), padded[2]:totable())
  tester:assertTableEq(T{1, 3, 5, 2}:totable(), padded[3]:totable())
end

function TestDataset.test_kfolds()
  local d = toyDataset()
  local folds = d:kfolds(3)
  tester:asserteq(3, #folds)
  local n = 0
  for _, fold in ipairs(folds) do
    n = n + #fold
  end
  tester:asserteq(d:size(), n)
end

function TestDataset.test_view()
  local d = toyDataset()
  local dd = d:view({1, 3})
  tester:asserteq(2, dd:size())
  tester:assertTableEq(d.X[1]:totable(), dd.X[1]:totable())
  tester:assertTableEq(d.X[3]:totable(), dd.X[2]:totable())
  local a, b = d:view({1, 3}, {3, 2, 3})
  tester:assertTableEq(d.X[1]:totable(), a.X[1]:totable())
  tester:assertTableEq(d.X[3]:totable(), a.X[2]:totable())
  tester:assertTableEq(d.X[3]:totable(), b.X[1]:totable())
  tester:assertTableEq(d.X[2]:totable(), b.X[2]:totable())
  tester:assertTableEq(d.X[3]:totable(), b.X[3]:totable())
end

function TestDataset.test_train_dev_split()
  local d = toyDataset()
  local train, test = d:train_dev_split{1, 3}
  tester:asserteq(2, train:size())
  tester:assertTableEq(d.X[1]:totable(), train.X[1]:totable())
  tester:assertTableEq(d.X[3]:totable(), train.X[2]:totable())
  tester:asserteq(1, test:size())
  tester:assertTableEq(d.X[2]:totable(), test.X[1]:totable())
end

function TestDataset.test_batches()
  local d = toyDataset()
  local i = 1
  for batch, batch_end in d:batches(2) do
    if i == 1 then
      tester:assertTableEq(T{{1, 2, 3}, {0, 2, 3}}:totable(), Dataset.pad(batch.X):totable())
      tester:assertTableEq(T{2, 5}:totable(), batch.Y)
      tester:asserteq(2, batch_end)
    end
    if i == 2 then
      tester:assertTableEq(T{{1, 3, 5, 2}}:totable(), Dataset.pad(batch.X):totable())
      tester:assertTableEq({1}, batch.Y)
      tester:asserteq(3, batch_end)
    end
    i = i + 1
  end
end

function TestDataset.test_single_batch()
  local d = toyDataset()
  local i = 1
  for batch, batch_end in d:batches(1) do
    tester:assertTableEq({X[i]:totable()}, Dataset.pad(batch.X):totable())
    tester:assertTableEq({Y[i]}, batch.Y)
    tester:asserteq(batch_end, i)
    i = i + 1
  end
end

function TestDataset.test_transform()
  local d = Dataset{names={'Alice', 'Bob'}, ids={3, 2}}
  local d2 = d:transform{names=string.lower, ids=function(x) return x-1 end}
  -- test not in place
  tester:assertTableEq({'alice', 'bob'}, d2.names)
  tester:assertTableEq({2, 1}, d2.ids)
  tester:assertTableEq({'Alice', 'Bob'}, d.names)
  tester:assertTableEq({3, 2}, d.ids)
  -- test in place
  local d2 = d:transform({names=string.lower}, true)
  tester:assertTableEq({'alice', 'bob'}, d.names)
  tester:assertTableEq({'alice', 'bob'}, d2.names)
end

function TestDataset.test_tostring()
  local d = Dataset{names={'Alice', 'Bob'}, ids={3, 2}}
  -- field order comes from table iteration, so both orders are accepted
  tester:assert('tl.Dataset(names, ids) of size 2' == tostring(d) or 'tl.Dataset(ids, names) of size 2' == tostring(d))
end

function TestDataset.test_sort_by_length()
  local a, b, c = torch.rand(3), torch.rand(2), torch.rand(4)
  local d = Dataset{a={a, b, c}, b={1, 2, 3}}
  d:sort_by_length('a')
  tester:asserteq(d.a[1], b)
  tester:asserteq(d.b[1], 2)
  tester:asserteq(d.a[2], a)
  tester:asserteq(d.b[2], 1)
  tester:asserteq(d.a[3], c)
  tester:asserteq(d.b[3], 3)
end

tester:add(TestDataset)
tester:run()

--------------------------------------------------------------------------------
-- /test/test_download.lua:
--------------------------------------------------------------------------------
local Downloader = require('torchlib').Downloader

local TestDownload = torch.TestSuite()
local tester = torch.Tester()

local T = torch.Tensor

function TestDownload.test_cache()
  local d = Downloader()
  tester:asserteq('/tmp/torchlib', d.cache)
  tester:asserteq(false, path.exists('./foo'))
  d = Downloader('/tmp/torchlib-foo')
  -- `path.exists` yields `false` for a missing path (see the assertion above),
  -- so comparing against nil would always pass; check truthiness instead.
  tester:assert(not not path.exists('/tmp/torchlib-foo'))
  path.rmdir('/tmp/torchlib-foo')
end

function TestDownload.test_get()
  local d = Downloader()
  local ret = d:get('google.txt', 'http://www.google.com/robots.txt')
  tester:asserteq(ret, '/tmp/torchlib/google.txt')
  tester:assert(not not path.exists('/tmp/torchlib/google.txt'))
end

tester:add(TestDownload)
tester:run()

--------------------------------------------------------------------------------
-- /test/test_graph.lua:
--------------------------------------------------------------------------------
local tl = require('torchlib')
local Graph = tl.Graph
local DirectedGraph = tl.DirectedGraph
local UndirectedGraph = tl.UndirectedGraph
local Set = tl.Set


local TestDirectedGraph = torch.TestSuite()
local TestUndirectedGraph = torch.TestSuite()
local tester = torch.Tester()


function TestDirectedGraph.testAddNodeDirected()
  local g = DirectedGraph()
  local na = g:addNode('a')
  local nb = g:addNode('b')
  local nc = g:addNode('c')
  tester:asserteq(3, g:size())

  g:connect(na, nb)
  g:connect(nc, na)
  tester:assertTableEq({nb}, g:connectionsOf(na))
  tester:assertTableEq({na}, g:connectionsOf(nc))
end


function TestUndirectedGraph.testAddNodeUndirected()
  local g = UndirectedGraph.new()
  local na = g:addNode('a')
  local nb = g:addNode('b')
  local nc = g:addNode('c')
  tester:asserteq(3, g:size())

  g:connect(na, nb)
  g:connect(nc, na)
  tester:assert(table.valuesEqual({nb, nc}, g:connectionsOf(na)))
  tester:assertTableEq({na}, g:connectionsOf(nc))
end

function getUndirectedGraph()
  -- figure 22.3 from CLRS
  local g = UndirectedGraph()
  local r = g:addNode('r')
  local s = g:addNode('s')
  local t = g:addNode('t')
  local u = g:addNode('u')
  local v = g:addNode('v')
  local w = g:addNode('w')
  local x = g:addNode('x')
  local y = g:addNode('y')

  g:connect(v, r)
  g:connect(r, s)
  g:connect(s, w)
  g:connect(w, t)
  g:connect(w, x)
  g:connect(t, x)
  g:connect(t, u)
  g:connect(x, u)
  g:connect(x, y)
  g:connect(y, u)

  return g, r, s, t, u, v, w, x, y
end


function TestUndirectedGraph.testBFS()
  local g, r, s, t, u, v, w, x, y = getUndirectedGraph()
  local discovered, finished = {}, {}
  g:breadthFirstSearch(s, {
    discover=function(n) table.insert(discovered, n) end,
    finish=function(n) table.insert(finished, n) end
  })

  tester:asserteq(0, s.timestamp)
  tester:asserteq(nil, s.parent)
  tester:asserteq(Graph.state.FINISHED, s.state)

  tester:asserteq(1, r.timestamp)
  tester:asserteq(s, r.parent)
  tester:asserteq(Graph.state.FINISHED, r.state)

  tester:asserteq(1, w.timestamp)
  tester:asserteq(s, w.parent)
  tester:asserteq(Graph.state.FINISHED, w.state)

  tester:asserteq(2, v.timestamp)
  tester:asserteq(r, v.parent)
  tester:asserteq(Graph.state.FINISHED, v.state)

  tester:asserteq(2, t.timestamp)
  tester:asserteq(w, t.parent)
  tester:asserteq(Graph.state.FINISHED, t.state)

  tester:asserteq(2, x.timestamp)
  tester:asserteq(w, x.parent)
  tester:asserteq(Graph.state.FINISHED, x.state)

  tester:asserteq(3, u.timestamp)
  tester:assert(u.parent == t or u.parent == x)
  tester:asserteq(Graph.state.FINISHED, u.state)

  tester:asserteq(3, y.timestamp)
  tester:asserteq(x, y.parent)
  tester:asserteq(Graph.state.FINISHED, y.state)

  tester:asserteq(g:size(), Set(discovered):size())
  tester:asserteq(g:size(), Set(finished):size())
end

function TestUndirectedGraph.testShortestPath()
  local g, r, s, t, u, v, w, x, y = getUndirectedGraph()
  local got = g:shortestPath(s, y)
  tester:assertTableEq({s, w, x, y}, got)

  got = g:shortestPath(s, t, true)
  tester:assertTableEq({s, w, t}, got)

  got = g:shortestPath(v, y)
  tester:assertTableEq({v, r, s, w, x, y}, got)

  -- two unconnected nodes: no path exists
  g = UndirectedGraph()
  r = g:addNode('r')
  s = g:addNode('s')
  tester:assertErrorPattern(function() g:shortestPath(r, s) end, 'Error: no path from tl.Graph.Node.r. to tl.Graph.Node.s.')
end

function getDirectedGraph()
  -- from CLRS fig 22.4
  local g = DirectedGraph()
  local u = g:addNode('u')
  local v = g:addNode('v')
  local w = g:addNode('w')
  local x = g:addNode('x')
  local y = g:addNode('y')
  local z = g:addNode('z')
  g:connect(u, v)
  g:connect(u, x)
  g:connect(x, v)
  g:connect(v, y)
  g:connect(y, x)
  g:connect(w, y)
  g:connect(w, z)
  g:connect(z, z)
  return g
end

function getDirectedAcyclicGraph()
  -- from CLRS fig 22.7
  local g = DirectedGraph()
  local undershorts = g:addNode('undershorts')
  local pants = g:addNode('pants')
  local belt = g:addNode('belt')
  local shirt = g:addNode('shirt')
  local tie = g:addNode('tie')
  local jacket = g:addNode('jacket')
  local socks = g:addNode('socks')
  local shoes = g:addNode('shoes')
  local watch = g:addNode('watch')
  g:connect(undershorts, pants)
  g:connect(undershorts, shoes)
  g:connect(socks, shoes)
  g:connect(pants, shoes)
  g:connect(pants, belt)
  g:connect(shirt, belt)
  g:connect(shirt, tie)
  g:connect(tie, jacket)
  g:connect(belt, jacket)
  return g, undershorts, pants, belt, shirt, tie, jacket, socks, shoes, watch
end

function TestDirectedGraph.testDFS()
  -- getDirectedGraph returns only the graph
  local g = getDirectedGraph()
  local discovered, finished = {}, {}
  g:depthFirstSearch(g:nodeSet():totable(), {
    discover=function(n) table.insert(discovered, n) end,
    finish=function(n) table.insert(finished, n) end
  })
  tester:asserteq(g:size(), Set(discovered):size())
  tester:asserteq(g:size(), Set(finished):size())
end

function TestDirectedGraph.testTopologicalSort()
  local g, undershorts, pants, belt, shirt, tie, jacket, socks, shoes, watch = getDirectedAcyclicGraph()
  local sorted = g:topologicalSort()
  -- test correctness automatically
  -- table.print(sorted)
end

function TestDirectedGraph.testHasCycle()
  local g = getDirectedGraph()
  tester:assert(g:hasCycle())
  g = getDirectedAcyclicGraph()
  tester:assert(not g:hasCycle())
end

function TestDirectedGraph.testTranspose()
  local g = DirectedGraph.new()
  local a = g:addNode('a')
  local b = g:addNode('b')
  local c = g:addNode('c')
  g:connect(a, b)
  g:connect(c, a)

  local t = g:transpose()
  tester:assertTableEq(t:connectionsOf(a), {c})
  tester:assertTableEq(t:connectionsOf(b), {a})
  tester:assertTableEq(t:connectionsOf(c), {})

  -- the original graph must be left untouched
  tester:assertTableEq(g:connectionsOf(a), {b})
  tester:assertTableEq(g:connectionsOf(b), {})
  tester:assertTableEq(g:connectionsOf(c), {a})
end

function TestDirectedGraph.testStronglyConnectedComponents()
  local g = getDirectedGraph()
  local roots = g:stronglyConnectedComponents()
  -- test correctness automatically
  -- table.print(roots)
end

tester:add(TestDirectedGraph)
tester:add(TestUndirectedGraph)
tester:run()

--------------------------------------------------------------------------------
-- /test/test_heap.lua:
--------------------------------------------------------------------------------
local Heap = require('torchlib').Heap

local TestHeap = torch.TestSuite()
local tester = torch.Tester()

function TestHeap.testIndex()
  tester:asserteq(1, Heap.parent(3))
  tester:asserteq(4, Heap.left(2))
  tester:asserteq(5, Heap.right(2))
end

function TestHeap.testPush()
  local h = Heap.new()
  h:push(10, 'bob')
  local p, v = table.unpack(h:get(1))
  tester:asserteq(10, p)
  tester:asserteq('bob', v)
  tester:asserteq(1, h:size())

  h:push(20, 'bill')
  p, v = table.unpack(h:get(1))
  tester:asserteq(20, p)
  tester:asserteq('bill', v)
  p, v = table.unpack(h:get(2))
  tester:asserteq(10, p)
  tester:asserteq('bob', v)
  tester:asserteq(2, h:size())

  h:push(15, 'ben')
  p, v = table.unpack(h:get(1))
  tester:asserteq(20, p)
  tester:asserteq('bill', v)
  p, v = table.unpack(h:get(2))
  tester:asserteq(15, p)
  tester:asserteq('ben', v)
  p, v = table.unpack(h:get(3))
  tester:asserteq(10, p)
  tester:asserteq('bob', v)
  tester:asserteq(3, h:size())
end

function TestHeap.testPop()
  local h = Heap.new()
  h:push(3, 'c')
  h:push(5, 'a')
  h:push(1, 'e')
  h:push(2, 'd')
  h:push(4, 'b')

  -- values come out in order of decreasing priority
  tester:asserteq('a', h:pop())
  tester:asserteq(4, h:size())
  tester:asserteq('b', h:pop())
  tester:asserteq(3, h:size())
  tester:asserteq('c', h:pop())
  tester:asserteq(2, h:size())
  tester:asserteq('d', h:pop())
  tester:asserteq(1, h:size())
  tester:asserteq('e', h:pop())
  tester:asserteq(0, h:size())
end

function TestHeap.testToString()
  local h = Heap.new()
  h:push(3, 'c')
  h:push(5, 'a')
  h:push(2, 'd')
  h:push(4, 'b')

  tester:asserteq('tl.Heap[a(5), b(4), d(2), c(3)]', tostring(h))

  h:push(4, 'v')
  tester:asserteq('tl.Heap[a(5), v(4), b(4), d(2), c(3), ...]', tostring(h))
end

function TestHeap.testSort()
  local h = Heap.new()
  h:push(3, 'c')
  h:push(5, 'a')
  h:push(2, 'd')
  h:push(4, 'b')
  h:sort()
  tester:asserteq('tl.Heap[d(2), c(3), b(4), a(5)]', tostring(h))
end

function TestHeap.testPeek()
  local h = Heap.new()
  h:push(3, 'c')
  tester:asserteq('c', h:peek())
  h:push(5, 'a')
  tester:asserteq('a', h:peek())
end

tester:add(TestHeap)
tester:run()

--------------------------------------------------------------------------------
-- /test/test_list.lua:
--------------------------------------------------------------------------------
local tl = require('torchlib')
local List = tl.List
local LinkedList = tl.LinkedList
local ArrayList = tl.ArrayList

local TestList = torch.TestSuite()
local TestArrayList = torch.TestSuite()
local TestLinkedList = torch.TestSuite()
local tester = torch.Tester()

-- Tests shared by every List implementation; each entry receives the class
-- under test and is registered once per implementation at the bottom of the file.
local TestGeneric = {}

function TestGeneric.testAdd(ListClass)
  local l = ListClass()
  tester:asserteq(0, l:size())

  l:add(10)
  tester:asserteq(1, l:size())
  tester:asserteq(10, l:get(1))

  l:add(7, 1)
  tester:asserteq(2, l:size())
  tester:asserteq(10, l:get(2))
  tester:asserteq(7, l:get(1))

  local v = l:remove(1)
  tester:asserteq(7, v)
  tester:asserteq(1, l:size())

  l:addMany(1, 3, 2, 4)
  tester:asserteq(5, l:size())
  tester:asserteq(10, l:get(1))
  tester:asserteq(1, l:get(2))
  tester:asserteq(3, l:get(3))
  tester:asserteq(2, l:get(4))
  tester:asserteq(4, l:get(5))

  l:add(0, 3)
  tester:asserteq(6, l:size())
  tester:asserteq(10, l:get(1))
  tester:asserteq(1, l:get(2))
  tester:asserteq(0, l:get(3))
  tester:asserteq(3, l:get(4))
  tester:asserteq(2, l:get(5))
  tester:asserteq(4, l:get(6))

  tester:asserteq(true, l:contains(3))
  tester:asserteq(false, l:contains(11))
end

function TestGeneric.testSet(ListClass)
  local l = ListClass{10, 1, 3, 2, 4}
  l:set(2, 100)
  tester:asserteq(100, l:get(2))
end

-- Formerly a second `testCopy`; the later definition of that name replaced it,
-- so it was never registered. Renamed so that it runs.
function TestGeneric.testEqualsLength(ListClass)
  local l = ListClass{10, 1, 3, 2, 4}
  tester:assert(ListClass{10, 1, 3, 2, 4}:equals(l))
  tester:assert(not ListClass{10, 1, 3, 4}:equals(l))
  tester:assert(not ListClass{10, 1, 3, 2, 4, 5}:equals(l))
end

function TestGeneric.testSublist(ListClass)
  local l = ListClass{10, 1, 3, 2, 4}
  local s = l:sublist(2, 4)
  tester:asserteq(3, s:size())
  tester:asserteq(1, s:get(1))
  tester:asserteq(3, s:get(2))
  tester:asserteq(2, s:get(3))

  s = l:sublist(4)
  tester:asserteq(2, s:size())
  tester:asserteq(2, s:get(1))
  tester:asserteq(4, s:get(2))
end

function TestGeneric.testEquals(ListClass)
  local a = ListClass{1, 2, 3}
  local b = ListClass{1, 2}
  tester:asserteq(false, a:equals(ListClass{1, 2}))
  tester:asserteq(false, a:equals(ListClass{1, 2, 4}))
  tester:asserteq(true, a:equals(ListClass{1, 2, 3}))
end

function TestGeneric.testRemove(ListClass)
  local l = ListClass{10, 1, 3, 2, 4}
  l:remove(3)
  tester:asserteq(true, l:equals(ListClass{10, 1, 2, 4}))
  l:remove(4)
  tester:asserteq(true, l:equals(ListClass{10, 1, 2}))
end

function TestGeneric.testCopy(ListClass)
  local l = ListClass{1, 2, 3}
  local l2 = l:copy()
  tester:assertTableEq({1, 2, 3}, l:totable())
  tester:assertTableEq({1, 2, 3}, l2:totable())
  tester:assert(l ~= l2)
end

function TestGeneric.testSwap(ListClass)
  local l = ListClass{'a', 'b', 'c', 'd', 'e'}
  local expect = ListClass{'a', 'e', 'c', 'd', 'b'}
  tester:assert(l:swap(2, 5):equals(expect))
end

function TestGeneric.testSort(ListClass)
  local l = ListClass{5, 4, 2, 3, 1}
  local expect = ListClass{1, 2, 3, 4, 5}
  l:sort()
  tester:assert(l:equals(expect))
end

function TestGeneric.testToTable(ListClass)
  local l = ListClass{5, 4, 2, 3, 1}
  tester:assertTableEq({5, 4, 2, 3, 1}, l:totable())
end

function TestList.testAbstractMethods()
  local funcs = {'__init', 'add', 'get', 'set', 'remove', 'equals', 'swap', 'totable'}
  for _, fname in ipairs(funcs) do
    tester:assertErrorPattern(List[fname], 'not implemented', fname..' should be a virtual method')
  end
end

function TestArrayList.testToStringArrayList()
  local l = ArrayList{1, 2, 3}
  tester:asserteq('tl.ArrayList[1, 2, 3]', tostring(l))
  l = ArrayList{1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5}
  tester:asserteq('tl.ArrayList[1, 2, 3, 1, 2, ...]', tostring(l))
end

function TestLinkedList.testToStringLinkedList()
  local l = LinkedList{1, 2, 3}
  tester:asserteq('tl.LinkedList[1, 2, 3]', tostring(l))
end

function TestList.testLinkedListNode()
  tester:asserteq('LinkedListNode(1)', tostring(tl.LinkedList.Node.new(1)))
end

-- register every generic test once per concrete list class
for k, f in pairs(TestGeneric) do
  TestArrayList[k..'ArrayList'] = function() f(ArrayList) end
  TestLinkedList[k..'LinkedList'] = function() f(LinkedList) end
end

tester:add(TestList)
tester:add(TestArrayList)
tester:add(TestLinkedList)
tester:run()

--------------------------------------------------------------------------------
-- /test/test_map.lua:
--------------------------------------------------------------------------------
local Map = require('torchlib').Map
local HashMap = require('torchlib').HashMap

local TestMap = torch.TestSuite()
local tester = torch.Tester()

function TestMap.testAdd()
  local m = HashMap()
  m:add(10, 'hi')
  local t = {}
  t[10] = 'hi'
  tester:assertTableEq(t, m._map)

  m:add(10, 'bye')
  t[10] = 'bye'
  tester:assertTableEq(t, m._map)

  m:add('hi', 1)
  t['hi'] = 1
  tester:assertTableEq(t, m._map)

  local tab = {'foo'}
  m:add(tab, 'bar')
  t[tab] = 'bar'
  tester:assertTableEq(t, m._map)

  t['a'] = 1
  t['b'] = 2
  t['c'] = 3
  m:addMany({a=1, b=2, c=3})
  tester:assertTableEq(t, m._map)
end

function TestMap.testCopy()
  -- `equals` yields a boolean (see testEquals), so comparing the result
  -- against nil would pass even for unequal maps.
  tester:assert(HashMap{a=1, b=2, c=3}:equals(HashMap{a=1, b=2, c=3}))
  tester:assert(not HashMap{a=1, b=2, c=3}:equals(HashMap{a=1, c=3}))
  tester:assert(not HashMap{a=1, b=2, c=3}:equals(HashMap{a=1, b=2, c=3, d=4}))
end

function TestMap.testEquals()
  local m = HashMap{foo=1, bar=2, baz=3}
  local n = HashMap{foo=1, baz=3}
  tester:asserteq(false, m:equals(n))
  n:add('bar', 10)
  tester:asserteq(false, m:equals(n))
  n:add('bar', 2)
  tester:asserteq(true, m:equals(n))
end

function TestMap.testContains()
  local m = HashMap()
  tester:assert(not m:contains('bar'))
  m:add('bar', 'foo')
  tester:assert(m:contains('bar'))
end

function TestMap.testGet()
  local m = HashMap()
  m:add(10, 'hi')
  m:add(20, 'bye')
  tester:asserteq('hi', m:get(10))
  tester:asserteq('bye', m:get(20))

  tester:assertErrorPattern(function() m:get('bad') end, 'Error: key bad not found in HashMap', 'get invalid key should error')
  tester:assert(m:get('bad', true) == nil)
end

function TestMap.testRemove()
  local m = HashMap()
  m:add(10, 20)
  m:add(20, 30)
  m:remove(10)
  local t = {}
  t[20] = 30
  tester:assertTableEq(t, m._map)

  local s, e = pcall(m.remove, m, 'bad')
  tester:assert(string.match(e, 'Error: key bad not found in HashMap') ~= nil)
end

function TestMap.testSize()
  local m = HashMap()
  tester:asserteq(0, m:size())
  m:add(10, 20)
  tester:asserteq(1, m:size())
  m:add(20, 10)
  tester:asserteq(2, m:size())
  m:remove(20)
  tester:asserteq(1, m:size())
end

function TestMap.testToString()
  local m = HashMap()
  tester:asserteq('tl.HashMap{}', tostring(m))
  m:add('foo', 'bar')
  tester:asserteq('tl.HashMap{foo -> bar}', tostring(m))
  m:add(1, 2)
  -- only checks that mixed key types do not raise
  tostring(m)
end

function TestMap.testToTable()
  local m = HashMap{foo=1, bar=2, baz=3}
  tester:assertTableEq({foo=1, bar=2, baz=3}, m:totable())
end

function TestMap.testAbstractMethods()
  local funcs = {'__init', 'add', 'addMany', 'copy', 'contains', 'get', 'remove', 'keys', 'equals', 'totable'}
  for _, fname in ipairs(funcs) do
    tester:assertErrorPattern(Map[fname], 'not implemented', fname..' should be a virtual method')
  end
end

tester:add(TestMap)
tester:run()

--------------------------------------------------------------------------------
-- /test/test_model.lua:
--------------------------------------------------------------------------------
local Model = require('torchlib').Model
local Dataset = require('torchlib').Dataset
local optim = require 'optim'
local nn = require 'nn'
local torch = require 'torch'
local dir = require 'pl.dir'

local TestModel = torch.TestSuite()
local tester = torch.Tester()

-- seed the shuffling
torch.manualSeed(12)

-- minimal tl.Model subclass used as the fixture for every test below
local MyModel = torch.class('MyModel', 'tl.Model')

function MyModel:required_params()
  return {'d_in', 'd_hid'}
end

function MyModel:get_net()
  return nn.Sequential()
    :add(nn.Linear(self.opt.d_in, self.opt.d_hid))
    :add(nn.Tanh())
    :add(nn.Linear(self.opt.d_hid, 1))
end

function MyModel:get_criterion()
  return nn.MSECriterion()
end

-- builds `n` random 2-d inputs with target y = 2 * x1 + x2
local get_dataset = function(n)
  n = n or 100
  local fields = {X={}, Y={}}
  for i = 1, n do
    fields.X[i] = torch.rand(2)
    fields.Y[i] = fields.X[i][1] * 2 + fields.X[i][2]
  end
  return Dataset(fields)
end

function TestModel.test_init()
  -- init while missing required parameter
  tester:assertErrorPattern(function() MyModel.new() end, 'missing required parameter d_in')
  tester:assertErrorPattern(function() MyModel.new{d_in=1} end, 'missing required parameter d_hid')

  local model = MyModel.new{d_in=2, d_hid=10}
  tester:asserteq('nn.Linear', torch.type(model.net.modules[1]))
  tester:asserteq('nn.Tanh', torch.type(model.net.modules[2]))
  tester:asserteq('nn.Linear', torch.type(model.net.modules[3]))
  tester:asserteq('nn.MSECriterion', torch.type(model.criterion))

  -- NOTE(review): as written these only check min <= 0.08 and max >= -0.08;
  -- if the intent is "all params within [-0.08, 0.08]" the bounds look swapped - confirm.
  tester:assertge(0.08, model.params:min())
  tester:assertle(-0.08, model.params:max())
  tester:assert(model.dparams:eq(0):any())
end

function TestModel.test_train()
  torch.manualSeed(12)
  local model = MyModel.new{d_in=2, d_hid=10}
  local dataset = get_dataset(100)
  local opt, optimize, optim_opt
  opt = {batch_size=128, silent=true, pad=0}
  optimize = optim.adam
  optim_opt = {learningRate = 1e-3}
  -- hardcoded check
  tester:assertGeneralEq(2.2879896, model:train(dataset, opt, optimize, optim_opt), 1e-5)
end

function TestModel.test_evaluate()
  torch.manualSeed(12)
  local model = MyModel.new{d_in=2, d_hid=10}
  local dataset = get_dataset(100)
  local opt = {batch_size=128, silent=true, pad=0}
  local ret = model:evaluate(dataset, opt)
  tester:assertGeneralEq(2.2879896, ret.loss, 1e-5)
  tester:asserteq(100, ret.pred:size(1))
  tester:asserteq(100, ret.targ:size(1))
  tester:asserteq(100, ret.max_scores:size(1))
  tester:asserteq(100, ret.raw_scores:size(1))
  tester:asserteq(1, ret.raw_scores:size(2))
end

83 | function TestModel.test_fit() 84 | torch.manualSeed(12) 85 | local counter = {} 86 | local callbacks = { 87 | counter = function(split, res) 88 | counter[split] = (counter[split] or 0) + 1 89 | return counter[split] 90 | end, 91 | res_check = function(split, res) 92 | tester:assert(res.loss ~= nil) 93 | end 94 | } 95 | local model = MyModel.new{d_in=2, d_hid=3} 96 | local dataset = {train=get_dataset(100), dev=get_dataset(20), test=get_dataset(10)} 97 | local opt, optimize, optim_opt 98 | opt = {batch_size=2, silent=true, pad=0, n_epoch=10, save='./tmp'} 99 | optimize = optim.adam 100 | optim_opt = 
{learningRate = 1e-2} 101 | local best = model:fit(dataset, opt, callbacks) 102 | tester:assertGeneralEq(0.225155, best.dev.loss, 1e-5) 103 | tester:assertGeneralEq(0.342600, best.train.loss, 1e-5) 104 | tester:assertGeneralEq(0.432770, best.test.loss, 1e-5) 105 | tester:asserteq(opt.n_epoch, counter.train) 106 | tester:asserteq(opt.n_epoch, counter.dev) 107 | tester:asserteq(1, counter.test) 108 | dir.rmtree('./tmp') 109 | end 110 | 111 | function TestModel.test_default() 112 | local M = torch.class('TmpModel', 'tl.Model') 113 | function M:get_net() return nn.Sequential() end 114 | local m = M.new() 115 | tester:assertTableEq({}, m:required_params()) 116 | tester:assertErrorPattern(function() Model.get_net(m) end, 'not implemented') 117 | tester:asserteq('nn.CrossEntropyCriterion', torch.type(m:get_criterion())) 118 | end 119 | 120 | tester:add(TestModel) 121 | tester:run() 122 | -------------------------------------------------------------------------------- /test/test_prob_table.lua: -------------------------------------------------------------------------------- 1 | local Table = require('torchlib').ProbTable 2 | 3 | local TestTable = torch.TestSuite() 4 | local tester = torch.Tester() 5 | 6 | local t1 = function() 7 | return Table.new(torch.Tensor{{0.2, 0.8}, {0.4, 0.6}, {0.1, 0.9}}, {'a', 'b'}) 8 | end 9 | 10 | local t2 = function() 11 | return Table.new(torch.Tensor{0.3, 0.7}, 'b') 12 | end 13 | 14 | local t3 = function() 15 | return Table.new(torch.Tensor{0.1, 0.2}, 'c') 16 | end 17 | 18 | function TestTable.test_init() 19 | local t = t1() 20 | tester:assertTableEq({'a', 'b'}, t.names) 21 | tester:assertTableEq({a=1, b=2}, t.name2index) 22 | tester:assertTensorEq(torch.Tensor{{0.2, 0.8}, {0.4, 0.6}, {0.1, 0.9}}, t.P) 23 | end 24 | 25 | function TestTable.test_clone() 26 | local t = t1() 27 | local tt = t:clone() 28 | tester:assertTableEq(t.names, tt.names) 29 | tester:assertTableEq(t.name2index, tt.name2index) 30 | tester:assertTensorEq(t.P, tt.P, 1e-5) 31 
| end 32 | 33 | function TestTable.test_query() 34 | local t = t1() 35 | tester:assertTensorEq(torch.Tensor{0.2, 0.8}, t:query{a=1}) 36 | tester:assertTensorEq(torch.Tensor{0.1, 0.9}, t:query{a=3}) 37 | tester:assertTensorEq(torch.Tensor{0.8, 0.6, 0.9}, t:query{b=2}) 38 | tester:asserteq(0.8, t:query{a=1, b=2}) 39 | tester:asserteq(0.4, t:query{a=2, b=1}) 40 | tester:asserteq(0.1, t:query{a=3, b=1}) 41 | end 42 | 43 | function TestTable.test_mul_subset() 44 | local t = t1():mul(t2()) 45 | tester:assertTableEq({'b', 'a'}, t.names) 46 | tester:assertTableEq({b=1, a=2}, t.name2index) 47 | tester:assertTensorEq(torch.Tensor{{0.06, 0.12, 0.03}, {0.56, 0.42, 0.63}}, t.P, 1e-5) 48 | end 49 | 50 | function TestTable.test_mul_extra() 51 | local t = t1():mul(t3()) 52 | tester:assertTableEq({'c', 'a', 'b'}, t.names) 53 | tester:assertTableEq({c=1, a=2, b=3}, t.name2index) 54 | tester:assertTensorEq(torch.Tensor{ 55 | { 56 | {0.02, 0.08}, {0.04, 0.06}, {0.01, 0.09}, 57 | }, 58 | { 59 | {0.04, 0.16}, {0.08, 0.12}, {0.02, 0.18}, 60 | } 61 | }, t.P, 1e-5) 62 | end 63 | 64 | function TestTable.test_mul_overlap() 65 | local t = t2():mul(t1()) 66 | tester:assertTableEq({'a', 'b'}, t.names) 67 | tester:assertTableEq({a=1, b=2}, t.name2index) 68 | tester:assertTensorEq(torch.Tensor{{0.06, 0.56}, {0.12, 0.42}, {0.03, 0.63}}, t.P, 1e-5) 69 | end 70 | 71 | function TestTable.test_mul1() 72 | local t = Table(torch.Tensor{2, 2, 2}, 'a') 73 | local tt = Table(torch.Tensor{{0.2, 0.8}, {0.4, 0.6}, {0.1, 0.9}}, {'a', 'b'}) 74 | local r = t:mul(tt) 75 | tester:assertTableEq({'a', 'b'}, r.names) 76 | tester:assertTableEq({a=1, b=2}, r.name2index) 77 | tester:assertTensorEq(torch.Tensor{{0.4, 1.6}, {0.8, 1.2}, {0.2, 1.8}}, r.P, 1e-5) 78 | end 79 | 80 | function TestTable.test_marginalize() 81 | local t = t1():marginalize('a') 82 | tester:assertTableEq({'b'}, t.names) 83 | tester:assertTableEq({b=1}, t.name2index) 84 | tester:assertTensorEq(torch.Tensor{0.7, 2.3}, t.P, 1e-5) 85 | t = 
t1():marginalize('b') 86 | tester:assertTableEq({'a'}, t.names) 87 | tester:assertTableEq({a=1}, t.name2index) 88 | tester:assertTensorEq(torch.Tensor{1, 1, 1}, t.P, 1e-5) 89 | end 90 | 91 | function TestTable:test_marginal() 92 | local t = t1():marginal('b') 93 | tester:assertTableEq({'b'}, t.names) 94 | tester:assertTableEq({b=1}, t.name2index) 95 | tester:assertTensorEq(torch.Tensor{0.7, 2.3}, t.P, 1e-5) 96 | t = t1():marginal('a') 97 | tester:assertTableEq({'a'}, t.names) 98 | tester:assertTableEq({a=1}, t.name2index) 99 | tester:assertTensorEq(torch.Tensor{1, 1, 1}, t.P, 1e-5) 100 | end 101 | 102 | function TestTable.test_normalize() 103 | local t = t1():normalize() 104 | tester:assertTableEq({'a', 'b'}, t.names) 105 | tester:assertTableEq({a=1, b=2}, t.name2index) 106 | tester:assertTensorEq(torch.Tensor{{0.2/3, 0.8/3}, {0.4/3, 0.6/3}, {0.1/3, 0.9/3}}, t.P, 1e-5) 107 | end 108 | 109 | function TestTable.test_tostring() 110 | local t = t1() 111 | local expect = [[a b P 112 | - - - 113 | 1 1 0.2 114 | 2 1 0.4 115 | 3 1 0.1 116 | 1 2 0.8 117 | 2 2 0.6 118 | 3 2 0.9 119 | ]] 120 | tester:asserteq(expect, tostring(t)) 121 | end 122 | 123 | tester:add(TestTable) 124 | tester:run() 125 | -------------------------------------------------------------------------------- /test/test_queue.lua: -------------------------------------------------------------------------------- 1 | local Queue = require('torchlib').Queue 2 | 3 | local TestQueue = torch.TestSuite() 4 | local tester = torch.Tester() 5 | 6 | function testQueue(q) 7 | tester:asserteq(0, q:size()) 8 | 9 | q:enqueue(10) 10 | tester:asserteq(1, q:size()) 11 | q:enqueue(20) 12 | tester:asserteq(2, q:size()) 13 | 14 | tester:asserteq('tl.Queue[10, 20]', tostring(q)) 15 | 16 | v = q:dequeue() 17 | tester:asserteq(10, v) 18 | tester:asserteq(1, q:size()) 19 | v = q:dequeue() 20 | tester:asserteq(20, v) 21 | tester:asserteq(0, q:size()) 22 | end 23 | 24 | 25 | function TestQueue.testQueue() 26 | local q = Queue() 27 | 
testQueue(q) 28 | end 29 | 30 | 31 | tester:add(TestQueue) 32 | tester:run() 33 | -------------------------------------------------------------------------------- /test/test_scorer.lua: -------------------------------------------------------------------------------- 1 | local Scorer = require('torchlib').Scorer 2 | 3 | local TestScorer = torch.TestSuite() 4 | local tester = torch.Tester() 5 | 6 | local toyScorer = function() 7 | local s = Scorer() 8 | s:add_pred('a', 'b', 1) 9 | s:add_pred('b', 'b', 2) 10 | s:add_pred('c', 'a', 3) 11 | return s 12 | end 13 | 14 | function TestScorer.test_add_pred() 15 | local s = toyScorer() 16 | tester:assertTableEq({'a', 'b', 'c'}, s.ind2class) 17 | tester:assertTableEq({a=1, b=2, c=3}, s.class2ind) 18 | tester:assertTableEq({1, 2, 3}, s.gold) 19 | tester:assertTableEq({2, 2, 1}, s.pred) 20 | end 21 | 22 | function TestScorer.test_reset() 23 | local s = toyScorer() 24 | s:reset() 25 | tester:assertTableEq({}, s.ind2class) 26 | tester:assertTableEq({}, s.class2ind) 27 | tester:assertTableEq({}, s.gold) 28 | tester:assertTableEq({}, s.pred) 29 | end 30 | 31 | function TestScorer.test_precision_recall_f1() 32 | -- these numbers are tested against Python's sklearn.metrics 33 | local s = toyScorer() 34 | local ignore = 'c' 35 | local micro, macro, all_stats = s:precision_recall_f1(ignore) 36 | tester:assertTableEq({precision=1/3, recall=1/2, f1=0.4}, micro) 37 | tester:assertTableEq({precision=0.25, recall=0.5, f1=1/3}, macro) 38 | tester:assertTableEq({ 39 | a = {precision=0, recall=0, f1=0}, 40 | b = {precision=0.5, recall=1, f1=2/3}, 41 | c = {precision=0, recall=0, f1=0}, 42 | }, all_stats) 43 | end 44 | 45 | tester:add(TestScorer) 46 | tester:run() 47 | -------------------------------------------------------------------------------- /test/test_set.lua: -------------------------------------------------------------------------------- 1 | local Set = require('torchlib').Set 2 | 3 | local TestSet = torch.TestSuite() 4 | local tester 
= torch.Tester() 5 | 6 | -- seed the shuffling 7 | torch.manualSeed(12) 8 | 9 | 10 | function TestSet.testAdd() 11 | local s = Set() 12 | tester:assertTableEq({}, s._map) 13 | 14 | -- add number 15 | s:add(15) 16 | local t = {} 17 | t[15] = 15 18 | tester:assertTableEq(t, s._map) 19 | 20 | -- add duplicate number 21 | s:add(15) 22 | tester:assertTableEq(t, s._map) 23 | 24 | -- add table 25 | local tab = {'foo'} 26 | s:add(tab) 27 | t[torch.pointer(tab)] = tab 28 | tester:assertTableEq(t, s._map) 29 | 30 | -- add table with duplicate content 31 | local tab2 = {'foo'} 32 | s:add(tab2) 33 | t[torch.pointer(tab2)] = tab2 34 | tester:assertTableEq(t, s._map) 35 | 36 | -- add many 37 | s:addMany(1, 2, 3) 38 | t[1] = 1 39 | t[2] = 2 40 | t[3] = 3 41 | tester:assertTableEq(t, s._map) 42 | end 43 | 44 | function TestSet.testToTable() 45 | local s = Set() 46 | s:add(5) 47 | s:add('bar') 48 | local tab = {'foo', 'bar'} 49 | s:add(tab) 50 | 51 | local expect = {5, 'bar', tab} 52 | local got = {} 53 | 54 | -- make got a dictionary for easy lookup 55 | for i, v in ipairs(s:totable()) do 56 | got[v] = true 57 | end 58 | 59 | -- check length equal 60 | tester:asserteq(#expect, #s:totable()) 61 | 62 | -- check each expected item is here 63 | for i, e in ipairs(expect) do 64 | tester:asserteq(true, got[e] ~= nil) 65 | end 66 | end 67 | 68 | function TestSet.testSize() 69 | local s = Set() 70 | tester:asserteq(0, s:size()) 71 | 72 | s:add(15) 73 | tester:asserteq(1, s:size()) 74 | 75 | s:add(16) 76 | tester:asserteq(2, s:size()) 77 | end 78 | 79 | function TestSet.testContains() 80 | local s = Set() 81 | tester:asserteq(false, s:contains(nil)) 82 | 83 | s:add(15) 84 | tester:asserteq(true, s:contains(15)) 85 | 86 | s:add('foo') 87 | tester:asserteq(true, s:contains('foo')) 88 | 89 | local tab = {'foo'} 90 | s:add(tab) 91 | tester:asserteq(true, s:contains(tab)) 92 | end 93 | 94 | function TestSet.testRemove() 95 | local s = Set() 96 | ret, err = pcall(Set.remove, s, 'bar') 97 | 
tester:assert(string.match(err, 'Error: value bar not found in Set') ~= nil) 98 | tester:asserteq(0, s:size()) 99 | 100 | s:add('bar') 101 | tester:asserteq(1, s:size()) 102 | s:remove('bar') 103 | tester:asserteq(0, s:size()) 104 | end 105 | 106 | function TestSet.testEquals() 107 | local s = Set() 108 | local t = Set() 109 | tester:asserteq(true, s:equals(t)) 110 | 111 | s:add(5) 112 | tester:asserteq(false, s:equals(t)) 113 | t:add(5) 114 | tester:asserteq(true, s:equals(t)) 115 | 116 | tester:asserteq(false, Set():addMany(1, 2):equals(Set():addMany(1))) 117 | tester:asserteq(false, Set():addMany(1):equals(Set():addMany(1, 2))) 118 | tester:asserteq(false, Set():addMany(1, 2):equals(Set():addMany(1, 3))) 119 | 120 | s:add(5) 121 | tester:asserteq(true, s:equals(t)) 122 | 123 | s:add(10) 124 | tester:asserteq(false, s:equals(t)) 125 | end 126 | 127 | function getToySets() 128 | local s = Set{5, 6} 129 | local t = Set{5, 7, 8} 130 | return s, t 131 | end 132 | 133 | function TestSet.testUnion() 134 | local s, t = getToySets() 135 | local expect = Set{5, 6, 7, 8} 136 | tester:assert(expect:equals(s:union(t))) 137 | end 138 | 139 | function TestSet.testIntersect() 140 | local s, t = getToySets() 141 | local expect = Set{5} 142 | tester:assert(expect:equals(s:intersect(t))) 143 | end 144 | 145 | function TestSet.testSubtract() 146 | local s, t = getToySets() 147 | local expect = Set{6} 148 | tester:assert(expect:equals(s:subtract(t))) 149 | end 150 | 151 | function TestSet.testToString() 152 | local s = Set():addMany(5) 153 | tester:asserteq('tl.Set(5)', tostring(s)) 154 | s:add(6) 155 | tester:assert(tostring(s) == 'tl.Set(5, 6)' or tostring(s) == 'tl.Set(6, 5)') 156 | end 157 | 158 | function TestSet.testCopy() 159 | local s = Set{5, 6, 7} 160 | tester:assert(s:equals(Set{5, 6, 7})) 161 | tester:assert(not s:equals(Set{5, 7})) 162 | tester:assert(not s:equals(Set{5, 8, 6, 7})) 163 | end 164 | 165 | tester:add(TestSet) 166 | tester:run() 167 | 
--------------------------------------------------------------------------------
/test/test_stack.lua:
--------------------------------------------------------------------------------
local Stack = require('torchlib').Stack

local TestStack = torch.TestSuite()
local tester = torch.Tester()

-- LIFO scenario: push 10 then 20, check size/tostring, pop in reverse order.
function TestStack.testStack()
  local s = Stack.new()
  tester:asserteq(0, s:size())

  s:push(10)
  tester:asserteq(1, s:size())
  s:push(20)
  tester:asserteq(2, s:size())

  tester:asserteq('tl.Stack[10, 20]', tostring(s))

  -- `v` was an accidental global in the original.
  local v = s:pop()
  tester:asserteq(20, v)
  tester:asserteq(1, s:size())
  v = s:pop()
  tester:asserteq(10, v)
  tester:asserteq(0, s:size())
end


tester:add(TestStack)
tester:run()

--------------------------------------------------------------------------------
/test/test_tree.lua:
--------------------------------------------------------------------------------
local BinarySearchTree = require('torchlib').BinarySearchTree
local BinaryTreeNode = require('torchlib').BinaryTree.Node
local BinarySearchTreeNode = require('torchlib').BinarySearchTree.Node
local TreeNode = require('torchlib').Tree.Node

local TestTree = torch.TestSuite()
local TestBinaryTree = torch.TestSuite()
local tester = torch.Tester()

-- Keys of the fixture tree in sorted order; used by the walk, successor and
-- predecessor tests.
local SORTED_KEYS = {2, 5, 9, 12, 13, 15, 17, 18, 19}

-- Fixture: binary search tree built by inserting, in this order, the keys
-- 12, 5, 2, 9, 18, 15, 13, 17, 19 with values 'n1' .. 'n9'.
local dummyTree = function()
  local t = BinarySearchTree.new()
  local insertion_order = {12, 5, 2, 9, 18, 15, 13, 17, 19}
  for i, key in ipairs(insertion_order) do
    t:insert(BinarySearchTreeNode.new(key, 'n'..i))
  end
  return t
end

-- tostring of a plain binary tree node is its class name.
function TestTree.testToString()
  local node = BinaryTreeNode(5, 'hi')
  tester:asserteq('tl.BinaryTree.Node', tostring(node))
end

-- In-order traversal visits the keys in ascending order.
function TestTree.testWalkInOrder()
  local t = dummyTree()
  local ordered = {}
  t:walkInOrder(function(n) table.insert(ordered, n) end)
  for i, v in ipairs(SORTED_KEYS) do
    tester:asserteq(v, ordered[i].key)
  end
end

-- children() on the abstract tree node is not implemented.
function TestTree.testTreeNode()
  local node = TreeNode.new()
  tester:assertErrorPattern(function() node:children() end, 'not implemented')
end

-- A binary tree node stores key/val and reports left/right as its children.
function TestBinaryTree.testBinaryTreeNode()
  local t = BinaryTreeNode(1, 2)
  tester:asserteq(1, t.key)
  tester:asserteq(2, t.val)

  t.left = BinaryTreeNode(3, 4)
  t.right = BinaryTreeNode(5, 6)
  tester:assertTableEq({t.left, t.right}, t:children())
end

-- insert places smaller keys to the left, larger keys to the right, and
-- updates size.
function TestBinaryTree.testInsert()
  local a = BinarySearchTreeNode(2, 'a')
  local tree = BinarySearchTree()
  tester:asserteq(0, tree:size())

  tree:insert(a)
  tester:asserteq(1, tree:size())
  tester:asserteq(a, tree.root)

  local b = BinarySearchTreeNode(1, 'b')
  tree:insert(b)
  tester:asserteq(2, tree:size())
  tester:asserteq(b, tree.root.left)
  tester:asserteq(b, a.left)

  local c = BinarySearchTreeNode(3, 'c')
  tree:insert(c)
  tester:asserteq(3, tree:size())
  tester:asserteq(c, tree.root.right)
  tester:asserteq(c, a.right)

  local d = BinarySearchTreeNode(4, 'd')
  tree:insert(d)
  tester:asserteq(4, tree:size())
  tester:asserteq(d, tree.root.right.right)
  tester:asserteq(d, c.right)
end

-- search returns nil for a missing key and the node for a present key.
function TestBinaryTree.testSearch()
  local tree = dummyTree()
  tester:asserteq(nil, tree:search(-1))
  tester:asserteq(12, tree:search(12).key)
end

function TestBinaryTree.testMin()
  local tree = dummyTree()
  tester:asserteq(2, tree:min().key)
end

function TestBinaryTree.testMax()
  local tree = dummyTree()
  tester:asserteq(19, tree:max().key)
end

-- successor of each key is the next key in sorted order; the maximum has none.
function TestBinaryTree.testSuccessor()
  local tree = dummyTree()
  for i = 1, #SORTED_KEYS - 1 do
    tester:asserteq(SORTED_KEYS[i + 1], tree:search(SORTED_KEYS[i]):successor().key)
  end
  tester:asserteq(nil, tree:search(19):successor())
end

-- predecessor of each key is the previous key in sorted order; the minimum
-- has none.
-- NOTE(review): the test name is misspelled ("Predecssor"); it is kept as-is
-- so the registered test name does not change.
function TestBinaryTree.testPredecssor()
  local tree = dummyTree()
  tester:asserteq(nil, tree:search(2):predecessor())
  for i = 2, #SORTED_KEYS do
    tester:asserteq(SORTED_KEYS[i - 1], tree:search(SORTED_KEYS[i]):predecessor().key)
  end
end

-- delete: checks the full tree shape after removing, in turn, keys
-- 13, 12 (the root), 9 and 5.
function TestBinaryTree.testDelete()
  local tree = dummyTree()
  tree:delete(tree:search(13))
  tester:asserteq(12, tree.root.key)
  tester:asserteq(5, tree.root.left.key)
  tester:asserteq(2, tree.root.left.left.key)
  tester:asserteq(9, tree.root.left.right.key)
  tester:asserteq(18, tree.root.right.key)
  tester:asserteq(15, tree.root.right.left.key)
  tester:asserteq(17, tree.root.right.left.right.key)
  tester:asserteq(19, tree.root.right.right.key)

  -- deleting the root
  tree:delete(tree:search(12))
  tester:asserteq(15, tree.root.key)
  tester:asserteq(5, tree.root.left.key)
  tester:asserteq(2, tree.root.left.left.key)
  tester:asserteq(9, tree.root.left.right.key)
  tester:asserteq(18, tree.root.right.key)
  tester:asserteq(17, tree.root.right.left.key)
  tester:asserteq(19, tree.root.right.right.key)

  tree:delete(tree:search(9))
  tester:asserteq(15, tree.root.key)
  tester:asserteq(5, tree.root.left.key)
  tester:asserteq(2, tree.root.left.left.key)
  tester:asserteq(18, tree.root.right.key)
  tester:asserteq(17, tree.root.right.left.key)
  tester:asserteq(19, tree.root.right.right.key)

  tree:delete(tree:search(5))
  tester:asserteq(15, tree.root.key)
  tester:asserteq(2, tree.root.left.key)
  tester:asserteq(18, tree.root.right.key)
  tester:asserteq(17, tree.root.right.left.key)
  tester:asserteq(19, tree.root.right.right.key)
end

-- tostring of a tree prints one line per node.
-- NOTE(review): leading whitespace inside this literal (the per-depth
-- indentation, if any) was collapsed in the listing this was recovered from;
-- restore it from BinarySearchTree's __tostring before relying on this test.
function TestBinaryTree.testToStringBinarySearchTree()
  local tree = dummyTree()
  local expect = [[|__ tl.BinarySearchTree.Node
|__ tl.BinarySearchTree.Node
|__ tl.BinarySearchTree.Node
|__ tl.BinarySearchTree.Node
|__ tl.BinarySearchTree.Node
|__ tl.BinarySearchTree.Node
|__ tl.BinarySearchTree.Node
|__ tl.BinarySearchTree.Node
|__ tl.BinarySearchTree.Node
]]
  tester:asserteq(expect, tostring(tree))
end

tester:add(TestTree)
tester:add(TestBinaryTree)
tester:run()

--------------------------------------------------------------------------------
/test/test_util.lua:
--------------------------------------------------------------------------------
local tl = require('torchlib')

local TestGlobal = torch.TestSuite()
local TestTable = torch.TestSuite()
local TestString = torch.TestSuite()
local tester = torch.Tester()

-- test_shuffle below expects a specific permutation, so the seed is fixed
math.randomseed(123)

-- range(from, to, step) and the single-argument form range(n).
function TestGlobal.test_range()
  tester:assertTableEq({2, 3, 4}, tl.range(2, 4, 1))
  tester:assertTableEq({3, 2}, tl.range(3, 2, -1))
  tester:assertTableEq({1, 2, 3}, tl.range(3))
end

-- equals: equal strings are equal, two distinct plain tables are not, and two
-- Sets holding the same element are.
function TestGlobal.test_equals()
  tester:asserteq(true, tl.equals('a', 'a'))
  tester:asserteq(false, tl.equals({}, 'a'))
  tester:asserteq(false, tl.equals({}, {}))
  tester:asserteq(true, tl.equals(tl.Set():add(5), tl.Set():add(5)))
end

-- copy is shallow: a new outer table whose nested table is shared.
function TestGlobal.test_copy()
  local s = {}
  local a = {a=1, b=s}
  local b = tl.copy(a)
  tester:asserteq(false, a == b)
  tester:asserteq(true, a.a == b.a)
  tester:asserteq(true, a.b == b.b)
  tester:asserteq(true, a.b == s)
end

-- deepcopy also copies the nested table and leaves the source untouched.
function TestGlobal.test_deep_copy()
  local s = {}
  local a = {a=1, b=s}
  local b = tl.deepcopy(a)
  tester:asserteq(false, a == b)
  tester:asserteq(true, a.a == b.a)
  tester:asserteq(false, a.b == b.b)
  tester:asserteq(true, a.b == s)
end

function TestString.test_startswith()
  tester:asserteq(true, string.startswith('bob', 'bo'))
  tester:asserteq(false, string.startswith('bob', 'o'))
end

function TestString.test_endswith()
  tester:asserteq(true, string.endswith('bob', 'ob'))
  tester:asserteq(false, string.endswith('bob', 'o'))
end

-- table.tostring renders nested tables one key per line.
-- NOTE(review): leading whitespace inside this literal (indentation of the
-- nested `c: 3` line, if any) was collapsed in the listing this was recovered
-- from; restore it from table.tostring before relying on this test.
function TestTable.test_print()
  local a = {a=1, b={c=3}}
  local expect = [[a: 1
b:
c: 3
]]
  local got = table.tostring(a)
  tester:asserteq(expect, got)
end

-- shuffle works in place; the expected permutation is tied to the seed set at
-- the top of this file.
function TestTable.test_shuffle()
  local t = {1, 2, 3, 4}
  table.shuffle(t)
  tester:assertTableEq({4, 2, 3, 1}, t)
end

function TestTable.test_table_equals()
  tester:asserteq(true, table.equals({1, 2}, {1, 2}))
  tester:asserteq(false, table.equals({1}, {1, 2}))
  tester:asserteq(false, table.equals({1, 2}, {1}))
end

-- valuesEqual for array-like and dictionary-like tables.
function TestTable.test_values_equal()
  tester:asserteq(true, table.valuesEqual({1, 2}, {1, 2}))
  tester:asserteq(false, table.valuesEqual({1}, {1, 2}))
  tester:asserteq(false, table.valuesEqual({1, 2}, {1}))
  tester:asserteq(true, table.valuesEqual({a=1, b=2}, {a=1, b=2}))
  tester:asserteq(false, table.valuesEqual({a=1}, {a=1, b=2}))
  tester:asserteq(false, table.valuesEqual({a=1, b=2}, {a=1}))
end

function TestTable.test_contains()
  tester:asserteq(true, table.contains({'a', 'b'}, 'b'))
  tester:asserteq(false, table.contains({'a', 'b'}, 'c'))
end

-- flatten joins nested keys with '__'.
function TestTable.test_flatten()
  tester:assertTableEq({a=1, ['b__c']=3}, table.flatten{a=1, b={c=3}})
  tester:assertTableEq({a=1, b=2}, table.flatten{a=1, b=2})
end

function TestTable.test_map()
  tester:assertTableEq({2, 3, 4}, table.map({1, 2, 3}, function(x) return x+1 end))
end

-- select picks entries by key; with the third argument true the result is an
-- array of the selected values.
function TestTable.test_select()
  tester:assertTableEq({'b', 'c'}, table.select({'a', 'b', 'c', 'd'}, {2, 3}, true))
  tester:assertTableEq({a=1}, table.select({a=1, b=2}, {'a'}))
  tester:assertTableEq({2}, table.select({a=1, b=2}, {'b'}, true))
end

function TestTable.test_extend()
  tester:assertTableEq({2, 3, 4, 5, 6}, table.extend({2, 3, 4}, {5, 6}))
end

-- combinations enumerates the cartesian product, first list varying fastest.
function TestTable.test_combinations()
  local got = table.combinations{{1, 2}, {'a', 'b', 'c'}}
  local expect = {{1, 'a'}, {2, 'a'}, {1, 'b'}, {2, 'b'}, {1, 'c'}, {2, 'c'}}
  tester:assertTableEq(expect, got)
end

tester:add(TestGlobal)
tester:add(TestTable)
tester:add(TestString)
tester:run()

--------------------------------------------------------------------------------
/test/test_variable_tensor.lua:
--------------------------------------------------------------------------------
local VariableTensor = require('torchlib').VariableTensor

local TestVariableTensor = torch.TestSuite()
local tester = torch.Tester()

math.randomseed(123)

-- Constructor options pre-size the index and storage tensors.
function TestVariableTensor.testInit()
  local s = VariableTensor({preinit_size=3, preinit_store_size=5})
  tester:asserteq(3, s.indices:size(1))
  tester:asserteq(5, s.store:size(1))
end

-- size counts pushed tensors, not preallocated capacity.
function TestVariableTensor.testSize()
  local s = VariableTensor()
  tester:asserteq(s:size(), 0)
  s = VariableTensor({preinit_size=3, preinit_store_size=5})
  tester:asserteq(s:size(), 0)
end

-- batch gathers the requested tensors into one 2-d tensor; the assertions show
-- shorter rows occupying the trailing columns (left padding).
-- Declared with '.' like its siblings (the original used ':').
function TestVariableTensor.testBatch()
  local s = VariableTensor()
  local a = torch.rand(5)
  local b = torch.rand(10)
  local c = torch.rand(7)
  local d = torch.rand(1)
  s:push(a)
  s:push(b)
  s:push(c)
  s:push(d)

  tester:assertTensorEq(a, s:get(1), 1e-5)
  tester:assertTensorEq(b, s:get(2), 1e-5)
  tester:assertTensorEq(c, s:get(3), 1e-5)
  tester:assertTensorEq(d, s:get(4), 1e-5)

  local got = s:batch({1, 2}, -1)
  tester:asserteq(got:size(1), 2)
  tester:assertTensorEq(got[1][{{6, 10}}], a, 1e-5)
  tester:assertTensorEq(got[2], b, 1e-5)

  got = s:batch({3, 4}, -1)
  tester:asserteq(got:size(1), 2)
  tester:assertTensorEq(got[1], c, 1e-5)
  tester:asserteq(got[{2, 7}], d[1])

  got = s:batch({4, 1, 3}, -1)
  tester:asserteq(got:size(1), 3)
  tester:asserteq(got[{1, 7}], d[1])
  tester:assertTensorEq(got[2][{{3, 7}}], a, 1e-5)
  tester:assertTensorEq(got[3], c, 1e-5)
end

-- push 100 tensors of random length; each must be retrievable both through the
-- raw indices/store pair and through get.
function TestVariableTensor.testPush()
  local s = VariableTensor()
  for i = 1, 100 do
    local n = torch.random(100)
    local t = torch.rand(n)
    s:push(t)
    local start, finish = s.indices[s:size()][1], s.indices[s:size()][2]
    local got = s.store[{{start, finish}}]
    tester:assertTensorEq(t, got, 1e-5)
    tester:assertTensorEq(t, s:get(i), 1e-5)
  end
  tester:asserteq(s:size(), 100)
end

-- shuffle is only exercised for errors; its result is not checked.
function TestVariableTensor.testShuffle()
  local s = VariableTensor()
  local a, b, c, d = torch.rand(2), torch.rand(3), torch.rand(4), torch.rand(5)
  s:push(a)
  s:push(b)
  s:push(c)
  s:push(d)

  s:shuffle()
end

tester:add(TestVariableTensor)
tester:run()

--------------------------------------------------------------------------------
/test/test_vocab.lua:
--------------------------------------------------------------------------------
local Vocab = require('torchlib').Vocab

local TestVocab = torch.TestSuite()
local tester = torch.Tester()

-- A fresh vocab holds only the unknown token (default 'UNK', or the one given
-- to the constructor) at index 1 with count 0.
function TestVocab.testInit()
  local v = Vocab.new()
  tester:assertTableEq(v.index2word, {'UNK'}, "index2word of empty")
  tester:assertTableEq(v.word2index, {UNK=1}, "word2index of empty")
  tester:assertTableEq(v.counter, {UNK=0}, "count of empty vocab")

  -- with unk
  v = Vocab("UNKNOWN")
  tester:assertTableEq(v.index2word, {"UNKNOWN"}, "index2word of empty vocab with unk")
  tester:assertTableEq(v.word2index, {UNKNOWN=1}, "word2index of empty vocab with unk")
  tester:assertTableEq(v.counter, {UNKNOWN=0}, "counter of empty vocab with unk")
end

-- add returns the new word's index and updates all three tables.
function TestVocab.testAdd()
  local v = Vocab()
  local index = v:add("foo")
  tester:asserteq(index, 2)
  tester:assertTableEq(v.index2word, {'UNK', "foo"}, "index2word after add")
  tester:assertTableEq(v.word2index, {UNK=1, foo=2}, "word2index after add")
  tester:assertTableEq(v.counter, {UNK=0, foo=1}, "counter after add")
end

function TestVocab.testContains()
  local v = Vocab('unk')
  tester:assert(v:contains('unk') == true, 'contains of unk')
  tester:assert(v:contains('foo') == false, 'contains before add')
  v:add("foo")
  tester:assert(v:contains('foo') == true, 'contains after add')
end

-- count errors for unknown words and accumulates the count passed to add.
function TestVocab.testCount()
  local v = Vocab('unk')
  tester:asserteq(v:count('unk'), 0)
  local _, err = pcall(v.count, v, 'foo')
  tester:assert(string.match(err, 'Error: attempted to get count of word foo which is not in the vocabulary') ~= nil)
  v:add('foo', 2)
  tester:asserteq(v:count('foo'), 2)
  v:add('foo', 0)
  tester:asserteq(v:count('foo'), 2)
  v:add('foo', 1)
  tester:asserteq(v:count('foo'), 3)
end

-- size counts distinct words, including the unknown token.
function TestVocab.testSize()
  -- `v` was an accidental global in the original.
  local v = Vocab()
  tester:asserteq(v:size(), 1)
  v:add('foo', 2)
  tester:asserteq(v:size(), 2)
  v:add('foo', 2)
  tester:asserteq(v:size(), 2)
  v:add('bar', 1)
  tester:asserteq(v:size(), 3)
end

-- wordAt errors past the vocab size and returns the word otherwise.
function TestVocab.testWordAt()
  local v = Vocab('unk')
  local _, err = pcall(v.wordAt, v, 2)
  tester:assert(string.match(err, 'Error: attempted to get word at index 2 which exceeds the vocab size') ~= nil)
  tester:asserteq(v:wordAt(1), 'unk')
  v:add('bar')
  tester:asserteq(v:wordAt(2), 'bar')
end

-- indexOf: plain lookup, lookup-with-add, and fallback to the unknown token.
function TestVocab.testIndexOf()
  local v = Vocab()
  v:add('bar')
  tester:asserteq(v:indexOf('bar'), 2)

  -- with add
  v = Vocab()
  v:indexOf('bar', true)
  tester:asserteq(v:indexOf('bar'), 2)

  -- with unk
  v = Vocab('unk')
  tester:asserteq(v:indexOf('bar'), v:indexOf('unk'))
  v:add('bar')
  tester:asserteq(v:indexOf('unk'), 1)
  tester:asserteq(v:indexOf('bar'), 2)
end

-- indicesOf with add=true adds unseen words and increments counts.
function TestVocab.testIndicesOf()
  local v = Vocab()
  v:add('foo')

  -- with add
  tester:assertTableEq(v:indicesOf({'foo', 'bar', 'this', 'this', 'this'}, true), {2, 3, 4, 4, 4})
  tester:asserteq(v:count('foo'), 2)
  tester:asserteq(v:count('bar'), 1)
  tester:asserteq(v:count('this'), 3)
end

-- Same as testIndicesOf, but the result is a tensor.
function TestVocab.testTensorIndicesOf()
  local v = Vocab()
  v:add('foo')

  -- with add
  tester:assertTensorEq(v:tensorIndicesOf({'foo', 'bar', 'this', 'this', 'this'}, true), torch.Tensor{2, 3, 4, 4, 4}, 1e-5)
  tester:asserteq(v:count('foo'), 2)
  tester:asserteq(v:count('bar'), 1)
  tester:asserteq(v:count('this'), 3)
end

-- wordsAt errors on an out-of-range index and maps indices back to words.
function TestVocab.testWordsAt()
  local v = Vocab()
  v:indicesOf({'foo', 'bar', 'this'}, true)
  local _, err = pcall(v.wordsAt, v, {1, 5, 3})
  tester:assert(string.match(err, 'Error: attempted to get word at index 5 which exceeds the vocab size') ~= nil)
  tester:assertTableEq(v:wordsAt({4, 2, 3}), {'this', 'foo', 'bar'})
end

-- Same as testWordsAt, but the indices are given as a tensor.
function TestVocab.testTensorWordsAt()
  local v = Vocab()
  v:indicesOf({'foo', 'bar', 'this'}, true)
  local _, err = pcall(v.tensorWordsAt, v, torch.Tensor{1, 5, 3})
  tester:assert(string.match(err, 'Error: attempted to get word at index 5 which exceeds the vocab size') ~= nil)
  tester:assertTableEq(v:tensorWordsAt(torch.Tensor{4, 2, 3}), {'this', 'foo', 'bar'})
end

-- copyAndPruneRares(2) drops words seen fewer than twice and keeps the unknown
-- token.
-- NOTE(review): both scenarios construct Vocab('unk'); the first was possibly
-- meant to cover the default unknown token -- confirm intent.
function TestVocab.testCopyAndPruneRares()
  local v = Vocab('unk')
  v:indicesOf({'foo', 'bar', 'this', 'this', 'this', 'bar'}, true)
  local pruned = v:copyAndPruneRares(2)
  tester:assertTableEq(pruned.counter, {unk=0, bar=2, this=3})

  -- with unk
  v = Vocab('unk')
  v:indicesOf({'foo', 'bar', 'this', 'this', 'this', 'bar'}, true)
  pruned = v:copyAndPruneRares(2)
  tester:assertTableEq(pruned.counter, {unk=0, bar=2, this=3})
  tester:asserteq(pruned.unk, 'unk')
end

function TestVocab.testToString()
  local v = Vocab('hi')
  v:add('boo')
  tester:asserteq('tl.Vocab(2 words, unk=hi)', tostring(v))
end

tester:add(TestVocab)
tester:run()

--------------------------------------------------------------------------------
/torchlib-scm-1.rockspec:
--------------------------------------------------------------------------------
package = "torchlib"
version = "scm-1"

source = {
  url = "git://github.com/vzhong/torchlib",
  tag = "master"
}

description = {
  summary = "libraries for torch",
  detailed = [[
    libraries for torch
  ]],
  homepage = "https://github.com/vzhong/torchlib"
}

dependencies = {
} 19 | 20 | build = { 21 | type = "command", 22 | build_command = [[ 23 | cmake -E make_directory build; 24 | cd build; 25 | cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)"; 26 | $(MAKE) 27 | ]], 28 | install_command = "cd build && $(MAKE) install" 29 | } 30 | --------------------------------------------------------------------------------