├── .gitignore
├── LICENSE
├── README.md
├── build.sbt
├── project
│   ├── build.properties
│   └── plugins.sbt
└── src
    └── main
        └── scala
            └── kr
                └── ac
                    └── kaist
                        └── ir
                            └── deep
                                ├── field
                                │   └── package.scala
                                ├── fn
                                │   ├── Activation.scala
                                │   ├── Objective.scala
                                │   ├── ScalarMatrix.scala
                                │   ├── WeightUpdater.scala
                                │   └── package.scala
                                ├── layer
                                │   ├── BasicLayer.scala
                                │   ├── Dropout.scala
                                │   ├── DropoutOperation.scala
                                │   ├── FullTensorLayer.scala
                                │   ├── GaussianRBFLayer.scala
                                │   ├── LowerTriangularLayer.scala
                                │   ├── Normalize.scala
                                │   ├── NormalizeOperation.scala
                                │   ├── Rank3TensorLayer.scala
                                │   ├── ReconBasicLayer.scala
                                │   ├── Reconstructable.scala
                                │   ├── SplitTensorLayer.scala
                                │   └── package.scala
                                ├── network
                                │   ├── AutoEncoder.scala
                                │   ├── BasicNetwork.scala
                                │   ├── StackedAutoEncoder.scala
                                │   └── package.scala
                                ├── package.scala
                                ├── rec
                                │   ├── BinaryTree.scala
                                │   ├── Leaf.scala
                                │   ├── Node.scala
                                │   ├── WildcardLeaf.scala
                                │   └── package.scala
                                ├── train
                                │   ├── AEType.scala
                                │   ├── DistBeliefTrainStyle.scala
                                │   ├── ManipulationType.scala
                                │   ├── MultiThreadTrainStyle.scala
                                │   ├── RAEType.scala
                                │   ├── RandomEqualPartitioner.scala
                                │   ├── SingleThreadTrainStyle.scala
                                │   ├── StandardRAEType.scala
                                │   ├── TrainStyle.scala
                                │   ├── Trainer.scala
                                │   ├── TrainingCriteria.scala
                                │   ├── TreeType.scala
                                │   ├── URAEType.scala
                                │   ├── VectorType.scala
                                │   └── package.scala
                                └── wordvec
                                    ├── PrepareCorpus.scala
                                    ├── StringToVectorType.scala
                                    ├── StringType.scala
                                    └── package.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ScalaNetwork 1.3.0
2 | ====================
3 |
4 | ## We no longer maintain this repository. The code was not well designed, so we redesigned the entire logic. For the latest library, see [DeepSpark](https://github.com/nearbydelta/deepspark).
5 |
6 | A *Neural Network implementation* with Scala, [Breeze](https://github.com/scalanlp/breeze) & [Spark](http://spark.apache.org)
7 |
8 | ScalaNetwork follows the [GPL v2 license](http://choosealicense.com/licenses/gpl-2.0/).
9 |
10 | # Features
11 |
12 | ## Network
13 |
14 | ScalaNetwork supports the following layered neural network implementations (a construction sketch follows the list below):
15 |
16 | * *Fully-connected* Neural Network : f(Wx + b)
17 | * *Fully-connected* Rank-3 Tensor Network : f(v1<sup>T</sup>Q<sup>[1:k]</sup>v2 + L<sup>[1:k]</sup>v + b)
18 | * *Fully-connected* Auto Encoder
19 | * *Fully-connected* Stacked Auto Encoder
20 |
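
For example, here is a minimal construction sketch for a small fully-connected network. It is not an excerpt from the original examples; it assumes the `BasicLayer(in → out, activation)` constructor and a `Sigmoid` activation object, as they appear in the sources further below.

```scala
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.layer.BasicLayer
import kr.ac.kaist.ir.deep.network.BasicNetwork

// A 2-4-1 network; each layer computes f(Wx + b).
val net = new BasicNetwork(IndexedSeq(
  new BasicLayer(2 → 4, Sigmoid),
  new BasicLayer(4 → 1, Sigmoid)
))

val input  = ScalarMatrix $0(2, 1) // a 2 × 1 column vector (all zeros)
val output = net(input)           // forward computation
```
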
21 | You can also implement the following recursive networks via the training tools:
22 |
23 | * Traditional *Recursive* Auto Encoder (RAE)
24 | * Standard *Recursive* Auto Encoder (RAE)
25 | * Unfolding *Recursive* Auto Encoder (RAE) [EXPERIMENTAL]
26 |
27 | ## Training Methodology
28 |
29 | ScalaNetwork supports the following training methodologies (a generic sketch of the AdaGrad rule follows the list):
30 |
31 | * Stochastic Gradient Descent with L1/L2 regularization and momentum
32 | * [AdaGrad](http://www.magicbroom.info/Papers/DuchiHaSi10.pdf)
33 | * [AdaDelta](http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf)
34 |
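For reference, AdaGrad accumulates squared gradients per parameter and shrinks the effective learning rate accordingly. The sketch below shows the generic rule from the paper above; it is illustrative only and is not the `WeightUpdater` API of this repository.

```scala
// Generic AdaGrad step (illustrative; not ScalaNetwork's WeightUpdater API).
def adaGradStep(w: Array[Float], grad: Array[Float], sqSum: Array[Float],
                lr: Float = 0.01f, eps: Float = 1e-6f): Unit = {
  var i = 0
  while (i < w.length) {
    sqSum(i) += grad(i) * grad(i)                               // accumulate g_t^2 per parameter
    w(i) -= lr / (math.sqrt(sqSum(i)).toFloat + eps) * grad(i)  // w -= lr / sqrt(G_t) * g_t
    i += 1
  }
}
```
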
35 | ScalaNetwork supports the following training environments:
36 |
37 | * Single-Threaded Training Environment.
38 | * Spark-based distributed environment, with a modified version of Downpour SGD from [DistBelief](http://research.google.com/archive/large_deep_networks_nips2012.html)
39 |
40 | You can also add negative examples with `Trainer.setNegativeSampler()`.
41 |
42 | ## Activation Function
43 |
44 | ScalaNetwork supports the following activation functions:
45 |
46 | * Linear
47 | * Sigmoid
48 | * HyperbolicTangent
49 | * Rectifier
50 | * Softplus
51 | * HardSigmoid
52 | * HardTanh
53 | * Softmax
54 |
55 | You can also create new activation functions by combining several operations.
56 |
57 | # Usage
58 |
59 | Here are some examples of basic usage. If you want to extend this package or use it more precisely, please refer to the [ScalaDoc](http://nearbydelta.github.io/ScalaNetwork/api/#kr.ac.kaist.ir.deep.package).
60 |
61 | ## Download
62 |
63 | Currently, ScalaNetwork supports Scala versions 2.10 and 2.11.
64 |
65 | * The stable release is 1.3.0.
66 |
67 | If you are using SBT, add a dependency as described below:
68 |
69 | ```scala
70 | libraryDependencies += "kr.ac.kaist.ir" %% "scalanetwork" % "1.3.0"
71 | ```
72 |
73 | If you are using Maven, add a dependency as described below:
74 | ```xml
75 |
Backward computation.
61 | *
62 | * @note
63 | * Let this layer have function F composed with function X(x) = W.x + b
64 | * and higher layer have function G.
65 | *
68 | * Weight is updated with: dG/dW
69 | * and propagate dG/dx
70 | *
73 | * For the computation, we only used denominator layout. (cf. Wikipedia Page of Matrix Computation)
74 | * For the computation rules, see "Matrix Cookbook" from MIT.
75 | *
76 | *
77 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
78 | * In this function, (bias :: weight) ::: lowerStack
79 | * Thus dWeight is app
80 | * @param error to be propagated (dG / dF is propagated from higher layer)
81 | * @return propagated error (in this case, dG/dx)
82 | */
83 | def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix = {
84 | /*
85 | * Chain Rule : dG/dX_ij = tr[ ( dG/dF ).t * dF/dX_ij ].
86 | *
87 | * Note 1. X, dG/dF, dF/dX_ij are row vectors. Therefore tr(.) can be omitted.
88 | *
89 | * Thus, dG/dX = [ (dG/dF).t * dF/dX ].t, because [...] is 1 × fanOut matrix.
90 | * Therefore dG/dX = dF/dX * dG/dF, because dF/dX is symmetric in our case.
91 | */
92 | val dGdX: ScalarMatrix = dFdX * error
93 |
94 | // For bias, input is always 1. We only need dG/dX
95 | delta.next += dGdX
96 |
97 | /*
98 | * Chain Rule : dG/dW_ij = tr[ ( dG/dX ).t * dX/dW_ij ].
99 | *
100 | * dX/dW_ij is a fan-Out dimension column vector with all zero but (i, 1) = X_j.
101 | * Thus, tr(.) can be omitted, and dG/dW_ij = (dX/dW_ij).t * dG/dX
102 | * Then {j-th column of dG/dW} = X_j * dG/dX = dG/dX * X_j.
103 | *
104 | * Therefore dG/dW = dG/dX * X.t
105 | */
106 | val dGdW: ScalarMatrix = dGdX * X.t
107 | delta.next += dGdW
108 |
109 | /*
110 | * Chain Rule : dG/dx_ij = tr[ ( dG/dX ).t * dX/dx_ij ].
111 | *
112 | * X is column vector. Thus j is always 1, so dX/dx_i is a W_?i.
113 | * Hence dG/dx_i = tr[ (dG/dX).t * dX/dx_ij ] = (W_?i).t * dG/dX.
114 | *
115 | * Thus dG/dx = W.t * dG/dX
116 | */
117 | val dGdx: ScalarMatrix = weight.t * dGdX
118 | dGdx
119 | }
120 | }
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/Dropout.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import play.api.libs.json.{JsObject, Json}
5 |
6 | /**
7 | * __Layer__ that drop-outs its input.
8 | *
9 | * This layer has a function of "pipeline" with drop-out possibility.
10 | * Because dropping out neurons occurs in the hidden layer, we need an intermediate pipe that handles this feature.
11 | * This layer only conveys its input to its output synapse if that output is alive.
12 | */
13 | trait Dropout extends Layer {
14 | /* On-off matrix */
15 | protected var onoff: ScalarMatrix = null
16 | /** The probability of the neuron is alive. `(Default: 1.0, 100%)` */
17 | private var presence: Probability = 1.0f
18 |
19 | /**
20 | * Set presence probability
21 | * @param p Probability to be set
22 | * @return Layer extended with dropout operation
23 | */
24 | def withProbability(p: Probability) = {
25 | presence = p
26 | this
27 | }
28 |
29 | /**
30 | * Forward computation
31 | *
32 | * @param x input matrix
33 | * @return output matrix
34 | */
35 | abstract override def apply(x: ScalarMatrix): ScalarMatrix =
36 | if (presence >= 1.0) super.apply(x)
37 | else super.apply(x) :* presence.safe
38 |
39 | /**
40 | * Translate this layer into JSON object (in Play! framework)
41 | *
42 | * @return JSON object describes this layer
43 | */
44 | abstract override def toJSON: JsObject = super.toJSON ++ Json.obj("Dropout" → presence)
45 |
46 | /**
47 | * Sugar: Forward computation. Calls apply(x)
48 | *
49 | * @param x input matrix
50 | * @return output matrix
51 | */
52 | abstract override def passedBy(x: ScalarMatrix): ScalarMatrix =
53 | if (presence >= 1.0) super.passedBy(x)
54 | else {
55 | onoff = ScalarMatrix $01(x.rows, x.cols, presence.safe)
56 | super.passedBy(x) :* onoff
57 | }
58 |
59 | /**
60 | * Backward computation.
61 | *
62 | * @note Because this layer only mediates two layers, this layer just remove propagated error for unused elements.
63 | *
64 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
65 | * @param error to be propagated (dG / dF is propagated from higher layer)
66 | * @return propagated error (in this case, dG/dx)
67 | */
68 | abstract override def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix =
69 | if (presence >= 1) super.updateBy(delta, error)
70 | else super.updateBy(delta, error :* onoff)
71 | }
--------------------------------------------------------------------------------
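
A short usage sketch for the `Dropout` trait above: it mirrors the mix-in pattern used by `Layer.revive` in `package.scala` further below; the `Sigmoid` object and the `BasicLayer(in → out, act)` constructor are assumptions based on the other sources in this dump.

```scala
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.layer.{BasicLayer, Dropout}

// Keep each neuron of this layer alive with probability 0.5 during training.
val hidden = new BasicLayer(4 → 3, Sigmoid) with Dropout withProbability 0.5f
```
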
/src/main/scala/kr/ac/kaist/ir/deep/layer/DropoutOperation.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import play.api.libs.json.{JsObject, Json}
5 |
6 | /**
7 | * __Layer__ that drop-outs its input.
8 | *
9 | * This layer has a function of "pipeline" with drop-out possibility.
10 | * Because dropping out neurons occurs in the hidden layer, we need an intermediate pipe that handles this feature.
11 | * This layer only conveys its input to its output synapse if that output is alive.
12 | *
13 | * @note Please extend [[Dropout]] trait to target layer.
14 | *
15 | * @param presence The probability of the neuron is alive. `(Default: 1.0, 100%)`
16 | */
17 | @deprecated
18 | class DropoutOperation(protected val presence: Probability = 1.0f) extends Layer {
19 | /**
20 | * weights for update
21 | *
22 | * @return weights
23 | */
24 | override val W: IndexedSeq[ScalarMatrix] = IndexedSeq.empty
25 | /** Null activation */
26 | protected override val act = null
27 | /* On-off matrix */
28 | protected var onoff: ScalarMatrix = null
29 |
30 | /**
31 | * Forward computation
32 | *
33 | * @param x input matrix
34 | * @return output matrix
35 | */
36 | override def apply(x: ScalarMatrix): ScalarMatrix =
37 | if (presence >= 1.0) x
38 | else x :* presence.safe
39 |
40 | /**
41 | * Translate this layer into JSON object (in Play! framework)
42 | *
43 | * @return JSON object describes this layer
44 | */
45 | override def toJSON: JsObject = Json.obj(
46 | "type" → "DropoutOp",
47 | "presence" → presence.safe
48 | )
49 |
50 | /**
51 | * Sugar: Forward computation. Calls apply(x)
52 | *
53 | * @param x input matrix
54 | * @return output matrix
55 | */
56 | override def into_:(x: ScalarMatrix): ScalarMatrix =
57 | if (presence >= 1.0) x
58 | else {
59 | onoff = ScalarMatrix $01(x.rows, x.cols, presence.safe)
60 | x :* onoff
61 | }
62 |
63 | /**
64 | * Backward computation.
65 | *
66 | * @note Because this layer only mediates two layers, this layer just remove propagated error for unused elements.
67 | *
68 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
69 | * @param error to be propagated (dG / dF is propagated from higher layer)
70 | * @return propagated error (in this case, dG/dx)
71 | */
72 | def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix =
73 | if (presence >= 1) error
74 | else error :* onoff
75 | }
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/FullTensorLayer.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import kr.ac.kaist.ir.deep.fn.{Activation, ScalarMatrix, ScalarMatrixOp}
4 | import play.api.libs.json.{JsArray, JsObject, Json}
5 |
6 | /**
7 | * __Layer__: Basic, Fully-connected Rank 3 Tensor Layer.
8 | *
9 | * @note
10 | * v0 = a column vector
11 | * Q = Rank 3 Tensor with size out, in × in is its entry.
12 | * L = Rank 3 Tensor with size out, 1 × in is its entry.
13 | * b = out × 1 matrix.
14 | *
15 | * output = f( v0'.Q.v0 + L.v0 + b )
16 | *
17 | *
18 | * @param IO is a tuple of the number of input and output, i.e. (2 → 4)
19 | * @param act is an activation function to be applied
20 | * @param quad is initial quadratic-level weight matrix Q for the case that it is restored from JSON (default: Seq())
21 | * @param lin is initial linear-level weight matrix L for the case that it is restored from JSON (default: null)
22 | * @param const is initial bias weight matrix b for the case that it is restored from JSON (default: null)
23 | */
24 | class FullTensorLayer(IO: (Int, Int),
25 | protected override val act: Activation,
26 | quad: Seq[ScalarMatrix] = Seq(),
27 | lin: ScalarMatrix = null,
28 | const: ScalarMatrix = null)
29 | extends Rank3TensorLayer((IO._1, IO._1, IO._1), IO._2, act, quad, lin, const) {
30 |
31 | /**
32 | * Translate this layer into JSON object (in Play! framework)
33 | *
34 | * @return JSON object describes this layer
35 | */
36 | override def toJSON: JsObject = Json.obj(
37 | "type" → "FullTensorLayer",
38 | "in" → fanIn,
39 | "out" → fanOut,
40 | "act" → act.toJSON,
41 | "quadratic" → JsArray.apply(quadratic.map(_.to2DSeq)),
42 | "linear" → linear.to2DSeq,
43 | "bias" → bias.to2DSeq
44 | )
45 |
46 | /**
47 | * Retrieve first input
48 | *
49 | * @param x input to be separated
50 | * @return first input
51 | */
52 | protected override def in1(x: ScalarMatrix): ScalarMatrix = x
53 |
54 | /**
55 | * Retrive second input
56 | *
57 | * @param x input to be separated
58 | * @return second input
59 | */
60 | protected override def in2(x: ScalarMatrix): ScalarMatrix = x
61 |
62 | /**
63 | * Reconstruct error from fragments
64 | * @param in1 error of input1
65 | * @param in2 error of input2
66 | * @return restored error
67 | */
68 | override protected def restoreError(in1: ScalarMatrix, in2: ScalarMatrix): ScalarMatrix = in1 + in2
69 | }
70 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/GaussianRBFLayer.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import breeze.linalg.sum
4 | import breeze.numerics.{exp, pow}
5 | import kr.ac.kaist.ir.deep.fn._
6 | import play.api.libs.json.{JsObject, Json}
7 |
8 | import scala.annotation.tailrec
9 |
10 | /**
11 | * __Layer__ : An Radial Basis Function Layer, with Gaussian function as its radial basis.
12 | *
13 | * @param in Dimension of input
14 | * @param centers A Matrix of Centroids. Each column is a column vector for centroids.
15 | * @param canModifyCenter True if update center during training.
16 | * @param w Initial weight (default: null)
17 | */
18 | class GaussianRBFLayer(val in: Int,
19 | val centers: ScalarMatrix,
20 | val canModifyCenter: Boolean = true,
21 | w: ScalarMatrix = null) extends Layer {
22 | protected final val weight = if (w != null) w else ScalarMatrix of(centers.cols, 1)
23 | protected final val sumCentroidEff = ScalarMatrix $1(centers.cols, 1)
24 | protected final val sumByRow = ScalarMatrix $1(1, in)
25 | override protected val act: Activation = null
26 | override val W: IndexedSeq[ScalarMatrix] = IndexedSeq(centers, weight)
27 |
28 | /**
29 | * Translate this layer into JSON object (in Play! framework)
30 | * @note Please make an LayerReviver object if you're using custom layer.
31 | * In that case, please specify LayerReviver object's full class name as "__reviver__,"
32 | * and fill up LayerReviver.revive method.
33 | *
34 | * @return JSON object describes this layer
35 | */
36 | override def toJSON: JsObject = Json.obj(
37 | "type" → "GaussianRBF",
38 | "in" → in,
39 | "center" → centers.to2DSeq,
40 | "canModifyCenter" → canModifyCenter,
41 | "weight" → weight.to2DSeq
42 | )
43 |
44 | /**
45 | * Forward computation
46 | *
47 | * @param x input matrix
48 | * @return output matrix
49 | */
50 | override def apply(x: ScalarMatrix): ScalarMatrix = {
51 | val sqWeight: ScalarMatrix = pow(weight, 2f) :* 2f
52 | exp(applyCoord(x, sqWeight, ScalarMatrix $0(centers.cols, 1), centers.cols - 1))
53 | }
54 |
55 | /**
56 | * Backward computation.
57 | *
58 | * @note
59 | * Let X ~ N(c_i, s_i) be the Gaussian distribution, and let N_i be the pdf of it.
60 | * Then the output of this layer will be : y_i = N_i(x) = exp(-[x-c_i]*[x-c_i]/[2*s_i*s_i]).
61 | * Call function on the higher layers as G.
62 | *
65 | * Centers are updated with: dG/dC_ij = dG/dN_i * dN_i/dc_ij.
66 | * Weights are updated with: dG/dW_i = dG/dN_i * dN_i/dw_i.
67 | * and propagate dG/dx_j = \sum_i dG/dN_i * dN_i/dx_ij.
68 | *
72 | * @param error to be propagated (dG / dF is propagated from higher layer)
73 | * @return propagated error (in this case, dG/dx)
74 | */
75 | def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix = {
76 | val multiplier: ScalarMatrix = (error :* dFdX) :/ pow(weight, 2f)
77 |
78 | val dGdC = ScalarMatrix.$0(centers.rows, centers.cols)
79 | val dWeight = updateCoord(multiplier, dGdC, centers.cols - 1)
80 |
81 | // Update Weight
82 | delta.next += ScalarMatrix(dWeight: _*)
83 |
84 | if (canModifyCenter)
85 | delta.next += dGdC
86 | else
87 | delta.next()
88 |
89 | -dGdC * sumCentroidEff
90 | }
91 |
92 | @tailrec
93 | private def applyCoord(x: ScalarMatrix, sqWeight: ScalarMatrix, out: ScalarMatrix, i: Int): ScalarMatrix =
94 | if (i >= 0) {
95 | val d: Scalar = sum(pow(x - centers(::, i to i), 2f))
96 | val in = -d / sqWeight(i, 0)
97 |
98 | out(i, 0) = in
99 | applyCoord(x, sqWeight, out, i - 1)
100 | } else
101 | out
102 |
103 | @tailrec
104 | private def updateCoord(multiplier: ScalarMatrix, dGdC: ScalarMatrix,
105 | i: Int, dWeight: Seq[Scalar] = Seq.empty): Seq[Scalar] =
106 | if (i >= 0) {
107 | val d: ScalarMatrix = X - centers(::, i to i)
108 |
109 | val w = weight(i, 0)
110 | val m = multiplier(i, 0)
111 |
112 | /* Compute dNi/dCij.
113 | * Since Ni = exp(-|x-ci|^2/(2si^2)), dNi/dCij = (xj-cij)/si^2 * Ni.
114 | * Therefore dNi/dCi = (x-ci)/si^2 * Ni.
115 | * dG/dCi = dG/dNi * dNi/dCi.
116 | * Note that dNi/dX = -dNi/dCi, and dG/dX = - \sum (dG/dNi * dNi/dCi)
117 | */
118 | dGdC(::, i to i) := d * m
119 |
120 | /* Compute dG/dSi.
121 | * dNi/dSi = |x-ci|^2/si^3 * Ni.
122 | * dG/dSi = dG/dNi * dNi/dSi.
123 | */
124 | val wUpdate = sum(pow(d, 2f)) * (m / w)
125 | // This update entry is the topmost row entry.
126 | updateCoord(multiplier, dGdC, i - 1, wUpdate +: dWeight)
127 | } else
128 | dWeight
129 |
130 | }
131 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/LowerTriangularLayer.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import play.api.libs.json.{JsObject, Json}
5 |
6 | /**
7 | * __Layer__: Basic, Fully-connected Layer
8 | *
9 | * @param IO a pair of __input & output__, such as 2 -> 3
10 | * @param act an __activation function__ to be applied
11 | * @param w initial weight matrix for the case that it is restored from JSON `(default: null)`
12 | * @param b inital bias matrix for the case that it is restored from JSON `(default: null)`
13 | */
14 | class LowerTriangularLayer(IO: (Int, Int),
15 | protected override val act: Activation,
16 | w: ScalarMatrix = null,
17 | b: ScalarMatrix = null)
18 | extends Layer {
19 | /** Number of Fan-ins */
20 | protected final val fanIn = IO._1
21 | /** Number of output */
22 | protected final val fanOut = IO._2
23 | /* Initialize weight */
24 | protected final val weight =
25 | if (w != null) w
26 | else
27 | act.initialize(fanIn, fanOut).mapActivePairs {
28 | case ((r, c), x) ⇒ if (c > r) 0f else x
29 | }
30 | protected final val bias = if (b != null) b else act.initialize(fanIn, fanOut, fanOut, 1)
31 | /** weights for update */
32 | override val W: IndexedSeq[ScalarMatrix] = IndexedSeq(weight, bias)
33 |
34 | /**
35 | * Forward computation
36 | *
37 | * @param x input matrix
38 | * @return output matrix
39 | */
40 | override def apply(x: ScalarMatrix): ScalarMatrix = {
41 | val wx: ScalarMatrix = weight * x
42 | val wxb: ScalarMatrix = wx + bias
43 | act(wxb)
44 | }
45 |
46 | /**
47 | * Translate this layer into JSON object (in Play! framework)
48 | *
49 | * @return JSON object describes this layer
50 | */
51 | override def toJSON: JsObject = Json.obj(
52 | "type" → "LowerTriangularLayer",
53 | "in" → fanIn,
54 | "out" → fanOut,
55 | "act" → act.toJSON,
56 | "weight" → weight.to2DSeq,
57 | "bias" → bias.to2DSeq
58 | )
59 |
60 | /**
61 | * Backward computation.
62 | *
63 | * @note
64 | * Let this layer have function F composed with function X(x) = W.x + b
65 | * and higher layer have function G.
66 | *
69 | * Weight is updated with: dG/dW
70 | * and propagate dG/dx
71 | *
74 | * For the computation, we only used denominator layout. (cf. Wikipedia Page of Matrix Computation)
75 | * For the computation rules, see "Matrix Cookbook" from MIT.
76 | *
77 | *
78 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
79 | * In this case, bias :: weight ::: lowerStack
80 | * @param error to be propagated (dG / dF is propagated from higher layer)
81 | * @return propagated error (in this case, dG/dx)
82 | */
83 | def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix = {
84 | /*
85 | * Chain Rule : dG/dX_ij = tr[ ( dG/dF ).t * dF/dX_ij ].
86 | *
87 | * Note 1. X, dG/dF, dF/dX_ij are row vectors. Therefore tr(.) can be omitted.
88 | *
89 | * Thus, dG/dX = [ (dG/dF).t * dF/dX ].t, because [...] is 1 × fanOut matrix.
90 | * Therefore dG/dX = dF/dX * dG/dF, because dF/dX is symmetric in our case.
91 | */
92 | val dGdX: ScalarMatrix = dFdX * error
93 |
94 | // For bias, input is always 1. We only need dG/dX
95 | delta.next += dGdX
96 |
97 | /*
98 | * Chain Rule : dG/dW_ij = tr[ ( dG/dX ).t * dX/dW_ij ].
99 | *
100 | * dX/dW_ij is a fan-Out dimension column vector with all zero but (i, 1) = X_j.
101 | * Thus, tr(.) can be omitted, and dG/dW_ij = (dX/dW_ij).t * dG/dX
102 | * Then {j-th column of dG/dW} = X_j * dG/dX = dG/dX * X_j.
103 | *
104 | * Therefore dG/dW = dG/dX * X.t
105 | * Except the upper triangular region.
106 | */
107 | val dGdWp: ScalarMatrix = (dGdX * X.t)
108 | val dGdW = dGdWp.mapActivePairs {
109 | case ((r, c), x) ⇒ if (c > r) 0f else x
110 | }
111 | delta.next += dGdW
112 |
113 | /*
114 | * Chain Rule : dG/dx_ij = tr[ ( dG/dX ).t * dX/dx_ij ].
115 | *
116 | * X is column vector. Thus j is always 1, so dX/dx_i is a W_?i.
117 | * Hence dG/dx_i = tr[ (dG/dX).t * dX/dx_ij ] = (W_?i).t * dG/dX.
118 | *
119 | * Thus dG/dx = W.t * dG/dX
120 | */
121 | val dGdx: ScalarMatrix = weight.t * dGdX
122 | dGdx
123 | }
124 | }
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/Normalize.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import breeze.linalg.sum
4 | import breeze.numerics.pow
5 | import kr.ac.kaist.ir.deep.fn._
6 | import play.api.libs.json.{JsObject, Json}
7 |
8 | /**
9 | * __Layer__ that normalizes its input.
10 | */
11 | trait Normalize extends Layer {
12 | /**
13 | * weights for update
14 | *
15 | * @return weights
16 | */
17 | override val W: IndexedSeq[ScalarMatrix] = IndexedSeq.empty
18 | /** Null activation */
19 | protected override val act = null
20 |
21 | /**
22 | * Forward computation
23 | *
24 | * @param x input matrix
25 | * @return output matrix
26 | */
27 | abstract override def apply(x: ScalarMatrix): ScalarMatrix = {
28 | val raw = super.apply(x)
29 | val len = Math.sqrt(sum(pow(raw, 2.0f))).toFloat
30 | raw :/ len
31 | }
32 |
33 | /**
34 | * Translate this layer into JSON object (in Play! framework)
35 | *
36 | * @return JSON object describes this layer
37 | */
38 | abstract override def toJSON: JsObject = super.toJSON ++ Json.obj("Normalize" → "")
39 |
40 | /**
41 | * Backward computation.
42 | *
43 | * @note Because this layer only mediates two layers, this layer just remove propagated error for unused elements.
44 | *
45 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
46 | * @param error to be propagated (dG / dF is propagated from higher layer)
47 | * @return propagated error (in this case, dG/dx)
48 | */
49 | abstract override def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix = {
50 | val Xsq = pow(X, 2.0f)
51 | val lenSq = sum(Xsq)
52 | val len: Scalar = Math.sqrt(lenSq).toFloat
53 |
54 | // Note that length is the function of x_i.
55 | // Let z_i := x_i / len(x_i).
56 | // Then d z_i / d x_i = (len^2 - x_i^2) / len^3 = (1 - z_i^2) / len,
57 | // d z_j / d x_i = - x_i * x_j / len^3 = - z_i * z_j / len
58 | val rows = dFdX.rows
59 | val dZdX = ScalarMatrix $0(rows, rows)
60 | var r = 0
61 | while (r < rows) {
62 | //dZ_r
63 | var c = 0
64 | while (c < rows) {
65 | if (r == c) {
66 | //dX_c
67 | dZdX.update(r, c, (1.0f - Xsq(r, 0) / lenSq) / len)
68 | } else {
69 | dZdX.update(r, c, (-X(r, 0) * X(c, 0)) / (len * lenSq))
70 | }
71 | c += 1
72 | }
73 | r += 1
74 | }
75 |
76 | // un-normalize the error
77 | super.updateBy(delta, dZdX * error)
78 | }
79 | }
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/NormalizeOperation.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import breeze.linalg.sum
4 | import breeze.numerics.pow
5 | import kr.ac.kaist.ir.deep.fn._
6 | import play.api.libs.json.{JsObject, Json}
7 |
8 | /**
9 | * __Layer__ that normalizes its input.
10 | *
11 | * @param factor The multiplication factor of the normalized output `(Default 1.0)`
12 | */
13 | @deprecated
14 | class NormalizeOperation(protected val factor: Scalar = 1.0f) extends Layer {
15 | /**
16 | * weights for update
17 | *
18 | * @return weights
19 | */
20 | override val W: IndexedSeq[ScalarMatrix] = IndexedSeq.empty
21 | /** Null activation */
22 | protected override val act = null
23 |
24 | /**
25 | * Translate this layer into JSON object (in Play! framework)
26 | *
27 | * @return JSON object describes this layer
28 | */
29 | override def toJSON: JsObject = Json.obj(
30 | "type" → "NormOp",
31 | "factor" → factor
32 | )
33 |
34 | /**
35 | * Backward computation.
36 | *
37 | * @note Because this layer only mediates two layers, this layer just remove propagated error for unused elements.
38 | *
39 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
40 | * In this case, centers :: weight :: REMAINDER-SEQ
41 | * @param error to be propagated (dG / dF is propagated from higher layer)
42 | * @return propagated error (in this case, dG/dx) and remainder of delta sequence
43 | */
44 | def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix = {
45 | val len: Scalar = Math.sqrt(sum(pow(X, 2.0f))).toFloat
46 | val output: ScalarMatrix = apply(X)
47 |
48 | // Note that length is the function of x_i.
49 | // Let z_i := x_i / len(x_i).
50 | // Then d z_i / d x_i = (len^2 - x_i^2) / len^3 = (1 - z_i^2) / len,
51 | // d z_j / d x_i = - x_i * x_j / len^3 = - z_i * z_j / len
52 | val rows = dFdX.rows
53 | val dZdX = ScalarMatrix $0(rows, rows)
54 | var r = 0
55 | while (r < rows) {
56 | //dZ_r
57 | var c = 0
58 | while (c < rows) {
59 | if (r == c) {
60 | //dX_c
61 | dZdX.update(r, c, (1.0f - output(r, 0) * output(r, 0)) / len)
62 | } else {
63 | dZdX.update(r, c, (-output(r, 0) * output(c, 0)) / len)
64 | }
65 | c += 1
66 | }
67 | r += 1
68 | }
69 |
70 | // un-normalize the error
71 | dZdX * error
72 | }
73 |
74 | /**
75 | * Forward computation
76 | *
77 | * @param x input matrix
78 | * @return output matrix
79 | */
80 | override def apply(x: ScalarMatrix): ScalarMatrix = {
81 | val len = Math.sqrt(sum(pow(x, 2.0f))).toFloat
82 | val normalized: ScalarMatrix = x :/ len
83 |
84 | if (factor != 1.0f)
85 | normalized :* factor
86 | else
87 | normalized
88 | }
89 | }
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/Rank3TensorLayer.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 |
5 | /**
6 | * __Layer__: Basic, Fully-connected Rank 3 Tensor Layer.
7 | *
8 | * @note
9 | * v0 = a column vector concatenate v2 after v1 (v11, v12, ... v1in1, v21, ...)
10 | * Q = Rank 3 Tensor with size out, in1 × in2 is its entry.
11 | * L = Rank 3 Tensor with size out, 1 × (in1 + in2) is its entry.
12 | * b = out × 1 matrix.
13 | *
14 | * output = f( v1'.Q.v2 + L.v0 + b )
15 | *
16 | *
17 | * @param fanIns is the number of input. (vector1, vector2, entire).
18 | * @param fanOut is the number of output
19 | * @param act is an activation function to be applied
20 | * @param quad is initial quadratic-level weight matrix Q for the case that it is restored from JSON (default: Seq())
21 | * @param lin is initial linear-level weight matrix L for the case that it is restored from JSON (default: null)
22 | * @param const is initial bias weight matrix b for the case that it is restored from JSON (default: null)
23 | */
24 | abstract class Rank3TensorLayer(protected val fanIns: (Int, Int, Int),
25 | protected val fanOut: Int,
26 | protected override val act: Activation,
27 | quad: Seq[ScalarMatrix] = Seq(),
28 | lin: ScalarMatrix = null,
29 | const: ScalarMatrix = null)
30 | extends Layer {
31 | /* Number of Fan-ins */
32 | protected final val fanInA = fanIns._1
33 | protected final val fanInB = fanIns._2
34 | protected final val fanIn = fanIns._3
35 | /* Initialize weight */
36 | protected final val quadratic: IndexedSeq[ScalarMatrix] =
37 | if (quad.nonEmpty) quad.toIndexedSeq
38 | else (0 until fanOut).map(_ ⇒ act.initialize(fanIn, fanOut, fanInA, fanInB))
39 | protected final val linear: ScalarMatrix = if (lin != null) lin else act.initialize(fanIn, fanOut, fanOut, fanIn)
40 | protected final val bias: ScalarMatrix = if (const != null) const else act.initialize(fanIn, fanOut, fanOut, 1)
41 |
42 | /**
43 | * Retrieve first input
44 | *
45 | * @param x input to be separated
46 | * @return first input
47 | */
48 | protected def in1(x: ScalarMatrix): ScalarMatrix
49 |
50 | /**
51 | * Retrive second input
52 | *
53 | * @param x input to be separated
54 | * @return second input
55 | */
56 | protected def in2(x: ScalarMatrix): ScalarMatrix
57 |
58 | /**
59 | * Reconstruct error from fragments
60 | * @param in1 error of input1
61 | * @param in2 error of input2
62 | * @return restored error
63 | */
64 | protected def restoreError(in1: ScalarMatrix, in2: ScalarMatrix): ScalarMatrix
65 |
66 | /**
67 | * Forward computation
68 | *
69 | * @param x input matrix
70 | * @return output matrix
71 | */
72 | override def apply(x: ScalarMatrix): ScalarMatrix = {
73 | val inA = in1(x)
74 | val inB = in2(x)
75 |
76 | val intermediate: ScalarMatrix = linear * x
77 | intermediate += bias
78 |
79 | val quads = quadratic.map { q ⇒
80 | val xQ: ScalarMatrix = inA.t * q
81 | val xQy: ScalarMatrix = xQ * inB
82 | xQy(0, 0)
83 | }
84 | intermediate += ScalarMatrix(quads: _*)
85 |
86 | act(intermediate)
87 | }
88 |
89 | /**
90 | * weights for update
91 | *
92 | * @return weights
93 | */
94 | override val W: IndexedSeq[ScalarMatrix] = (quadratic :+ linear) :+ bias
95 |
96 | /**
97 | * Backward computation.
98 | *
99 | * @note
100 | * Let this layer have function F composed with function X(x) = x1'.Q.x2 + L.x + b
101 | * and higher layer have function G. (Each output is treated as separately except propagation)
102 | *
105 | * Weight is updated with: dG/dW
106 | * and propagate dG/dx
107 | *
110 | * For the computation, we only used denominator layout. (cf. Wikipedia Page of Matrix Computation)
111 | * For the computation rules, see "Matrix Cookbook" from MIT.
112 | *
113 | *
114 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
115 | * In this case, bias :: linear :: quadratic(K to 0) ::: lowerStack
116 | * @param error to be propagated (dG / dF is propagated from higher layer)
117 | * @return propagated error (in this case, dG/dx)
118 | */
119 | def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix = {
120 | val inA = in1(X)
121 | val inB = in2(X)
122 |
123 | /*
124 | * Chain Rule : dG/dX_ij = tr[ ( dG/dF ).t * dF/dX_ij ].
125 | *
126 | * Note 1. X, dG/dF, dF/dX_ij are row vectors. Therefore tr(.) can be omitted.
127 | *
128 | * Thus, dG/dX = [ (dG/dF).t * dF/dX ].t, because [...] is 1 × fanOut matrix.
129 | * Therefore dG/dX = dF/dX * dG/dF, because dF/dX is symmetric in our case.
130 | */
131 | val dGdX: ScalarMatrix = dFdX * error
132 |
133 | // For bias, input is always 1. We only need dG/dX
134 | delta.next += dGdX
135 |
136 | /*
137 | * Chain Rule (Linear weight case) : dG/dW_ij = tr[ ( dG/dX ).t * dX/dW_ij ].
138 | *
139 | * dX/dW_ij is a fan-Out dimension column vector with all zero but (i, 1) = X_j.
140 | * Thus, tr(.) can be omitted, and dG/dW_ij = (dX/dW_ij).t * dG/dX
141 | * Then {j-th column of dG/dW} = X_j * dG/dX = dG/dX * X_j.
142 | *
143 | * Therefore dG/dW = dG/dX * X.t
144 | */
145 | val dGdL = dGdX * X.t
146 | delta.next += dGdL
147 | /*
148 | * Chain Rule (Linear weight part) : dG/dx_ij = tr[ ( dG/dX ).t * dX/dx_ij ].
149 | *
150 | * X is column vector. Thus j is always 1, so dX/dx_i is a W_?i.
151 | * Hence dG/dx_i = tr[ (dG/dX).t * dX/dx_ij ] = (W_?i).t * dG/dX.
152 | *
153 | * Thus dG/dx (linear part) = W.t * dG/dX.
154 | */
155 | val dGdx = linear.t * dGdX
156 |
157 | /*
158 | * Because X = inA.t * Q * inB, dX/dQ = inA * inB.t
159 | */
160 | val dXdQ: ScalarMatrix = inA * inB.t //d tr(axb)/dx = a'b'
161 |
162 | // Add dG/dx quadratic part.
163 | updateQuadratic(inA, inB, dGdX, dXdQ, dGdx, delta)
164 | }
165 |
166 | private def updateQuadratic(inA: ScalarMatrix, inB: ScalarMatrix,
167 | dGdXAll: ScalarMatrix, dXdQ: ScalarMatrix,
168 | acc: ScalarMatrix, delta: Iterator[ScalarMatrix], id: Int = fanOut - 1): ScalarMatrix =
169 | if (id >= 0) {
170 | // This is scalar
171 | val dGdX = dGdXAll(id, 0)
172 |
173 | /*
174 | * Chain Rule (Quadratic weight case) : dG/dQ_ij = tr[ ( dG/dX ).t * dX/dQ_ij ].
175 | *
176 | * dX/dQ_ij = (inA * inB.t)_ij, and so dG/dQ_ij = (dG/dX).t * dX/dQ_ij.
177 | * They are scalar, so dG/dQ = dG/dX * dX/dQ.
178 | */
179 | val dGdQ: ScalarMatrix = dXdQ :* dGdX
180 | delta.next += dGdQ
181 |
182 | /*
183 | * Chain Rule (Linear weight part) : dG/dx_ij = tr[ ( dG/dX ).t * dX/dx_ij ].
184 | *
185 | * X is column vector. Thus j is always 1, so dX/dx_i is a W_?i.
186 | * Hence dG/dx_i = tr[ (dG/dX).t * dX/dx_ij ] = (W_?i).t * dG/dX.
187 | *
188 | * Thus dG/dx = W.t * dG/dX.
189 | *
190 | * Chain Rule (Quadratic weight part) : dG/dx_ij = tr[ ( dG/dX ).t * dX/dx_ij ].
191 | *
192 | * Note that x is a column vector with inA, inB as parts.
193 | * Because X = inA.t * Q * inB, dX/dxA = inB.t * Q.t and dX/dxB = inA.t * Q
194 | * Since dG/dX is scalar, we obtain dG/dx by scalar multiplication.
195 | */
196 | val dXdxQ1: ScalarMatrix = inB.t * quadratic(id).t //d tr(ax')/dx = d tr(x'a)/dx = a'
197 | val dXdxQ2: ScalarMatrix = inA.t * quadratic(id) //d tr(ax)/dx = d tr(xa)/dx = a
198 | val dGdx: ScalarMatrix = restoreError(dXdxQ1, dXdxQ2) :* dGdX
199 | acc += dGdx
200 |
201 | updateQuadratic(inA, inB, dGdXAll, dXdQ, acc, delta, id - 1)
202 | } else
203 | acc
204 | }
--------------------------------------------------------------------------------
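
The two concrete subclasses of `Rank3TensorLayer` differ only in how they carve up the input. A construction sketch, assuming a `Sigmoid` activation object and the constructors shown in `FullTensorLayer.scala` and `SplitTensorLayer.scala`:

```scala
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.layer.{FullTensorLayer, SplitTensorLayer}

// FullTensorLayer feeds the whole input to both sides of the quadratic form: f(x'.Q.x + L.x + b)
val full  = new FullTensorLayer(3 → 4, Sigmoid)

// SplitTensorLayer splits the input column vector into a 3-row part and a 2-row part: f(v1'.Q.v2 + L.v0 + b)
val split = new SplitTensorLayer((3, 2) → 4, Sigmoid)
```
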
/src/main/scala/kr/ac/kaist/ir/deep/layer/ReconBasicLayer.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import play.api.libs.json.JsObject
5 |
6 | /**
7 | * __Layer__ : Reconstructable Basic Layer
8 | *
9 | * @param IO is a pair of input & output, such as 2 -> 3
10 | * @param act is an activation function to be applied
11 | * @param w is initial weight matrix for the case that it is restored from JSON (default: null)
12 | * @param b is inital bias matrix for the case that it is restored from JSON (default: null)
13 | * @param rb is initial reconstruct bias matrix for the case that it is restored from JSON (default: null)
14 | */
15 | class ReconBasicLayer(IO: (Int, Int),
16 | act: Activation,
17 | w: ScalarMatrix = null,
18 | b: ScalarMatrix = null,
19 | rb: ScalarMatrix = null)
20 | extends BasicLayer(IO, act, w, b) with Reconstructable {
21 | protected final val reBias = if (rb != null) rb else act initialize(fanIn, fanOut, fanIn, 1)
22 | /**
23 | * weights for update
24 | *
25 | * @return weights
26 | */
27 | override val W: IndexedSeq[ScalarMatrix] = IndexedSeq(bias, weight, weight, reBias)
28 |
29 | /**
30 | * Sugar: reconstruction
31 | *
32 | * @param x hidden layer output matrix
33 | * @return tuple of reconstruction output
34 | */
35 | override def decodeFrom(x: ScalarMatrix): ScalarMatrix = {
36 | val wx: ScalarMatrix = weight.t[ScalarMatrix, ScalarMatrix] * x
37 | val wxb: ScalarMatrix = wx + reBias
38 | act(wxb)
39 | }
40 |
41 | /**
42 | * Translate this layer into JSON object (in Play! framework)
43 | *
44 | * @return JSON object describes this layer
45 | */
46 | override def toJSON: JsObject = super.toJSON + ("reconst_bias" → reBias.to2DSeq)
47 |
48 | /**
49 | * Backpropagation of reconstruction. For the information about backpropagation calculation, see [[kr.ac.kaist.ir.deep.layer.Layer]]
50 | *
51 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
52 | * In this case, reBias :: weight ::: lowerStack
53 | * @param error error matrix to be propagated
54 | * @return propagated error
55 | */
56 | protected[deep] def decodeUpdateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix = {
57 | /*
58 | * Chain Rule : dG/dX_ij = tr[ ( dG/dF ).t * dF/dX_ij ].
59 | *
60 | * Note 1. X, dG/dF, dF/dX_ij are row vectors. Therefore tr(.) can be omitted.
61 | *
62 | * Thus, dG/dX = [ (dG/dF).t * dF/dX ].t, because [...] is 1 × fanOut matrix.
63 | * Therefore dG/dX = dF/dX * dG/dF, because dF/dX is symmetric in our case.
64 | */
65 | val dGdX: ScalarMatrix = decdFdX * error
66 |
67 | // For bias, input is always 1. We only need dG/dX
68 | delta.next += dGdX
69 |
70 | /*
71 | * Chain Rule : dG/dW_ij = tr[ ( dG/dX ).t * dX/dW_ij ].
72 | *
73 | * dX/dW_ij is a fan-Out dimension column vector with all zero but (i, 1) = X_j.
74 | * Thus, tr(.) can be omitted, and dG/dW_ij = (dX/dW_ij).t * dG/dX
75 | * Then {j-th column of dG/dW} = X_j * dG/dX = dG/dX * X_j.
76 | *
77 | * Therefore dG/dW = dG/dX * X.t
78 | */
79 | val dGdW: ScalarMatrix = dGdX * decX.t
80 | delta.next += dGdW.t // Because we used transposed weight for reconstruction, we need to transpose it.
81 |
82 | /*
83 | * Chain Rule : dG/dx_ij = tr[ ( dG/dX ).t * dX/dx_ij ].
84 | *
85 | * X is column vector. Thus j is always 1, so dX/dx_i is a W_?i.
86 | * Hence dG/dx_i = tr[ (dG/dX).t * dX/dx_ij ] = (W_?i).t * dG/dX.
87 | *
88 | * Thus dG/dx = W.t * dG/dX
89 | */
90 | val dGdx: ScalarMatrix = weight * dGdX // Because we used transposed weight for reconstruction.
91 | dGdx
92 | }
93 | }
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/Reconstructable.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 |
5 | /**
6 | * __Trait__ of Layer that can be used for autoencoder
7 | */
8 | trait Reconstructable extends Layer {
9 | protected var decX: ScalarMatrix = _
10 | protected var decdFdX: ScalarMatrix = _
11 | /**
12 | * Reconstruction
13 | *
14 | * @param x hidden layer output matrix
15 | * @return tuple of reconstruction output
16 | */
17 | def decodeFrom(x: ScalarMatrix): ScalarMatrix
18 |
19 | /**
20 | * Sugar: reconstruction
21 | *
22 | * @param x hidden layer output matrix
23 | * @return tuple of reconstruction output
24 | */
25 | def decodeBy(x: ScalarMatrix): ScalarMatrix = {
26 | decX = x
27 | val out = decodeFrom(x)
28 | decdFdX = act.derivative(out)
29 | out
30 | }
31 |
32 | /**
33 | * Sugar: reconstruction
34 | *
35 | * @param x hidden layer output matrix
36 | * @return tuple of reconstruction output
37 | */
38 | @deprecated
39 | def decodeBy_:(x: ScalarMatrix): ScalarMatrix = decodeBy(x)
40 |
41 | /**
42 | * Backpropagation of reconstruction. For the information about backpropagation calculation, see [[kr.ac.kaist.ir.deep.layer.Layer]]
43 | *
44 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
45 | * @param error error matrix to be propagated
46 | * @return propagated error
47 | */
48 | protected[deep] def decodeUpdateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix
49 | }
50 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/SplitTensorLayer.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.layer
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import play.api.libs.json.{JsArray, JsObject, Json}
5 |
6 | /**
7 | * __Layer__: Basic, Fully-connected Rank 3 Tensor Layer.
8 | *
9 | * @note
10 | * v0 = a column vector concatenate v2 after v1 (v11, v12, ... v1in1, v21, ...)
11 | * Q = Rank 3 Tensor with size out, in1 × in2 is its entry.
12 | * L = Rank 3 Tensor with size out, 1 × (in1 + in2) is its entry.
13 | * b = out × 1 matrix.
14 | *
15 | * output = f( v1'.Q.v2 + L.v0 + b )
16 | *
17 | *
18 | * @param IO is a tuple of the number of input and output, i.e. ((2, 3) → 4)
19 | * @param act is an activation function to be applied
20 | * @param quad is initial quadratic-level weight matrix Q for the case that it is restored from JSON (default: Seq())
21 | * @param lin is initial linear-level weight matrix L for the case that it is restored from JSON (default: null)
22 | * @param const is initial bias weight matrix b for the case that it is restored from JSON (default: null)
23 | */
24 | class SplitTensorLayer(IO: ((Int, Int), Int),
25 | protected override val act: Activation,
26 | quad: Seq[ScalarMatrix] = Seq(),
27 | lin: ScalarMatrix = null,
28 | const: ScalarMatrix = null)
29 | extends Rank3TensorLayer((IO._1._1, IO._1._2, IO._1._1 + IO._1._2), IO._2, act, quad, lin, const) {
30 |
31 | /**
32 | * Translate this layer into JSON object (in Play! framework)
33 | *
34 | * @return JSON object describes this layer
35 | */
36 | override def toJSON: JsObject = Json.obj(
37 | "type" → "SplitTensorLayer",
38 | "in" → Json.arr(fanInA, fanInB),
39 | "out" → fanOut,
40 | "act" → act.toJSON,
41 | "quadratic" → JsArray.apply(quadratic.map(_.to2DSeq)),
42 | "linear" → linear.to2DSeq,
43 | "bias" → bias.to2DSeq
44 | )
45 |
46 | /**
47 | * Retrieve first input
48 | *
49 | * @param x input to be separated
50 | * @return first input
51 | */
52 | protected override def in1(x: ScalarMatrix): ScalarMatrix = x(0 until fanInA, ::)
53 |
54 | /**
55 | * Retrive second input
56 | * @param x input to be separated
57 | * @return second input
58 | */
59 | protected override def in2(x: ScalarMatrix): ScalarMatrix = x(fanInA to -1, ::)
60 |
61 | /**
62 | * Reconstruct error from fragments
63 | * @param in1 error of input1
64 | * @param in2 error of input2
65 | * @return restored error
66 | */
67 | override protected def restoreError(in1: ScalarMatrix, in2: ScalarMatrix): ScalarMatrix = in1 col_+ in2
68 | }
69 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/layer/package.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import play.api.libs.json.{JsObject, JsValue}
5 |
6 | import scala.reflect.runtime._
7 |
8 | /**
9 | * Package for layer implementation
10 | */
11 | package object layer {
12 |
13 | /**
14 | * __Trait__ that describes layer-level computation
15 | *
16 | * Layer is an instance of ScalaMatrix => ScalaMatrix function.
17 | * Therefore "layers" can be composed together.
18 | */
19 | trait Layer extends (ScalarMatrix ⇒ ScalarMatrix) with Serializable {
20 | /** Activation Function */
21 | protected val act: Activation
22 | protected var X: ScalarMatrix = _
23 | protected var dFdX: ScalarMatrix = _
24 |
25 | /**
26 | * Forward computation
27 | *
28 | * @param x input matrix
29 | * @return output matrix
30 | */
31 | override def apply(x: ScalarMatrix): ScalarMatrix
32 |
33 | /**
34 | * Backward computation.
35 | *
36 | * @note
37 | * Let this layer have function F composed with function X(x) = W.x + b
38 | * and higher layer have function G.
39 | *
42 | * Weight is updated with: dG/dW
43 | * and propagate dG/dx
44 | *
47 | * For the computation, we only used denominator layout. (cf. Wikipedia Page of Matrix Computation)
48 | * For the computation rules, see "Matrix Cookbook" from MIT.
49 | *
50 | *
51 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
52 | * @param error to be propagated (dG / dF is propagated from higher layer)
53 | * @return propagated error (in this case, dG/dx)
54 | */
55 | def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix
56 |
57 | /**
58 | * Sugar: Forward computation. Calls apply(x)
59 | *
60 | * @param x input matrix
61 | * @return output matrix
62 | */
63 | def passedBy(x: ScalarMatrix) = {
64 | this.X = x
65 | val out = apply(x)
66 | dFdX =
67 | if (act != null)
68 | act.derivative(out)
69 | else
70 | out
71 | out
72 | }
73 |
74 | /**
75 | * Translate this layer into JSON object (in Play! framework)
76 | * @note Please make an LayerReviver object if you're using custom layer.
77 | * In that case, please specify LayerReviver object's full class name as "__reviver__,"
78 | * and fill up LayerReviver.revive method.
79 | * @return JSON object describes this layer
80 | */
81 | def toJSON: JsObject
82 |
83 | /**
84 | * Sugar: Forward computation. Calls apply(x)
85 | *
86 | * @param x input matrix
87 | * @return output matrix
88 | */
89 | @deprecated
90 | protected[deep] def into_:(x: ScalarMatrix) = passedBy(x)
91 |
92 | /**
93 | * weights for update
94 | *
95 | * @return weights
96 | */
97 | val W: IndexedSeq[ScalarMatrix]
98 | }
99 |
100 | /**
101 | * __Trait__ that revives layer from JSON value
102 | */
103 | trait LayerReviver extends Serializable {
104 | /**
105 | * Revive layer using given JSON value
106 | * @param obj JSON value to be revived
107 | * @return Revived layer.
108 | */
109 | def revive(obj: JsValue): Layer
110 | }
111 |
112 | /**
113 | * Companion object of Layer
114 | */
115 | object Layer extends LayerReviver {
116 | @transient val runtimeMirror = universe.synchronized(universe.runtimeMirror(getClass.getClassLoader))
117 |
118 | /**
119 | * Load layer from JsObject
120 | *
121 | * @param obj JsObject to be parsed
122 | * @return New layer reconstructed from this object
123 | */
124 | def apply(obj: JsValue) = {
125 | val companion =
126 | universe.synchronized {
127 | (obj \ "reviver").asOpt[String] match {
128 | case Some(clsName) ⇒
129 | val module = runtimeMirror.staticModule(clsName)
130 | runtimeMirror.reflectModule(module).instance.asInstanceOf[LayerReviver]
131 | case None ⇒
132 | this
133 | }
134 | }
135 | companion.revive(obj)
136 | }
137 |
138 | /**
139 | * Load layer from JsObject
140 | *
141 | * @param obj JsObject to be parsed
142 | * @return New layer reconstructed from this object
143 | */
144 | def revive(obj: JsValue) = {
145 | val in = obj \ "in"
146 | val out = obj \ "out"
147 | val typeStr = (obj \ "type").as[String]
148 |
149 | val act = if (typeStr.endsWith("Layer")) {
150 | Activation.apply(obj \ "act")
151 | } else null
152 |
153 | val dropout = (obj \ "Dropout").asOpt[Probability]
154 | val normalize = (obj \ "Normalize").asOpt[String]
155 |
156 | typeStr match {
157 | case "NormOp" ⇒
158 | val factor = (obj \ "factor").as[Scalar]
159 | new NormalizeOperation(factor)
160 | case "DropoutOp" ⇒
161 | val presence = (obj \ "presence").as[Probability]
162 | new DropoutOperation(presence)
163 | case "GaussianRBF" ⇒
164 | val w = ScalarMatrix restore (obj \ "weight").as[IndexedSeq[IndexedSeq[String]]]
165 | val c = ScalarMatrix restore (obj \ "center").as[IndexedSeq[IndexedSeq[String]]]
166 | val modifiable = (obj \ "canModifyCenter").as[Boolean]
167 | (dropout, normalize) match {
168 | case (Some(p), Some(_)) ⇒
169 | new GaussianRBFLayer(in.as[Int], c, modifiable, w) with Dropout with Normalize withProbability p
170 | case (Some(p), None) ⇒
171 | new GaussianRBFLayer(in.as[Int], c, modifiable, w) with Dropout withProbability p
172 | case (None, Some(_)) ⇒
173 | new GaussianRBFLayer(in.as[Int], c, modifiable, w) with Normalize
174 | case _ ⇒
175 | new GaussianRBFLayer(in.as[Int], c, modifiable, w)
176 | }
177 |
178 | case "BasicLayer" ⇒
179 | val i = in.as[Int]
180 | val o = out.as[Int]
181 | val b = ScalarMatrix restore (obj \ "bias").as[IndexedSeq[IndexedSeq[String]]]
182 | val w = ScalarMatrix restore (obj \ "weight").as[IndexedSeq[IndexedSeq[String]]]
183 | (obj \ "reconst_bias").asOpt[IndexedSeq[IndexedSeq[String]]] match {
184 | case Some(rbraw) ⇒
185 | val rb = ScalarMatrix restore rbraw
186 | (dropout, normalize) match {
187 | case (Some(p), Some(_)) ⇒
188 | new ReconBasicLayer(i → o, act, w, b, rb) with Dropout with Normalize withProbability p
189 | case (Some(p), None) ⇒
190 | new ReconBasicLayer(i → o, act, w, b, rb) with Dropout withProbability p
191 | case (None, Some(_)) ⇒
192 | new ReconBasicLayer(i → o, act, w, b, rb) with Normalize
193 | case _ ⇒
194 | new ReconBasicLayer(i → o, act, w, b, rb)
195 | }
196 | case None ⇒
197 | (dropout, normalize) match {
198 | case (Some(p), Some(_)) ⇒
199 | new BasicLayer(i → o, act, w, b) with Dropout with Normalize withProbability p
200 | case (Some(p), None) ⇒
201 | new BasicLayer(i → o, act, w, b) with Dropout withProbability p
202 | case (None, Some(_)) ⇒
203 | new BasicLayer(i → o, act, w, b) with Normalize
204 | case _ ⇒
205 | new BasicLayer(i → o, act, w, b)
206 | }
207 | }
208 |
209 | case "LowerTriangularLayer" ⇒
210 | val i = in.as[Int]
211 | val o = out.as[Int]
212 | val b = ScalarMatrix restore (obj \ "bias").as[IndexedSeq[IndexedSeq[String]]]
213 | val w = ScalarMatrix restore (obj \ "weight").as[IndexedSeq[IndexedSeq[String]]]
214 |
215 | (dropout, normalize) match {
216 | case (Some(p), Some(_)) ⇒
217 | new LowerTriangularLayer(i → o, act, w, b) with Dropout with Normalize withProbability p
218 | case (Some(p), None) ⇒
219 | new LowerTriangularLayer(i → o, act, w, b) with Dropout withProbability p
220 | case (None, Some(_)) ⇒
221 | new LowerTriangularLayer(i → o, act, w, b) with Normalize
222 | case _ ⇒
223 | new LowerTriangularLayer(i → o, act, w, b)
224 | }
225 |
226 | case "SplitTensorLayer" ⇒
227 | val tuple = in.as[Seq[Int]]
228 | val i = (tuple.head, tuple(1))
229 | val o = out.as[Int]
230 | val b = ScalarMatrix restore (obj \ "bias").as[IndexedSeq[IndexedSeq[String]]]
231 | val quad = (obj \ "quadratic").as[Seq[IndexedSeq[IndexedSeq[String]]]] map ScalarMatrix.restore
232 | val linear =
233 | try {
234 | ScalarMatrix restore (obj \ "linear").as[IndexedSeq[IndexedSeq[String]]]
235 | } catch {
236 | case _: Throwable ⇒
237 | (obj \ "linear").as[Seq[IndexedSeq[IndexedSeq[String]]]].map(ScalarMatrix.restore)
238 | .zipWithIndex.foldLeft(ScalarMatrix.$0(out.as[Int], tuple.sum)) {
239 | case (matx, (row, id)) ⇒
240 | matx(id to id, ::) := row
241 | matx
242 | }
243 | }
244 |
245 | (dropout, normalize) match {
246 | case (Some(p), Some(_)) ⇒
247 | new SplitTensorLayer(i → o, act, quad, linear, b) with Dropout with Normalize withProbability p
248 | case (Some(p), None) ⇒
249 | new SplitTensorLayer(i → o, act, quad, linear, b) with Dropout withProbability p
250 | case (None, Some(_)) ⇒
251 | new SplitTensorLayer(i → o, act, quad, linear, b) with Normalize
252 | case _ ⇒
253 | new SplitTensorLayer(i → o, act, quad, linear, b)
254 | }
255 | case "FullTensorLayer" ⇒
256 | val i = in.as[Int]
257 | val o = out.as[Int]
258 | val b = ScalarMatrix restore (obj \ "bias").as[IndexedSeq[IndexedSeq[String]]]
259 | val quad = (obj \ "quadratic").as[Seq[IndexedSeq[IndexedSeq[String]]]] map ScalarMatrix.restore
260 | val linear =
261 | try {
262 | ScalarMatrix restore (obj \ "linear").as[IndexedSeq[IndexedSeq[String]]]
263 | } catch {
264 | case _: Throwable ⇒
265 | (obj \ "linear").as[Seq[IndexedSeq[IndexedSeq[String]]]].map(ScalarMatrix.restore)
266 | .zipWithIndex.foldLeft(ScalarMatrix.$0(out.as[Int], in.as[Int])) {
267 | case (matx, (row, id)) ⇒
268 | matx(id to id, ::) := row
269 | matx
270 | }
271 | }
272 |
273 | (dropout, normalize) match {
274 | case (Some(p), Some(_)) ⇒
275 | new FullTensorLayer(i → o, act, quad, linear, b) with Dropout with Normalize withProbability p
276 | case (Some(p), None) ⇒
277 | new FullTensorLayer(i → o, act, quad, linear, b) with Dropout withProbability p
278 | case (None, Some(_)) ⇒
279 | new FullTensorLayer(i → o, act, quad, linear, b) with Normalize
280 | case _ ⇒
281 | new FullTensorLayer(i → o, act, quad, linear, b)
282 | }
283 | }
284 | }
285 | }
286 |
287 | }
288 |
--------------------------------------------------------------------------------
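
A minimal sketch of a custom layer against the `Layer` contract above: a pass-through layer with no weights, in the same spirit as `NormalizeOperation` and `DropoutOperation`. This is an assumption-laden example, not part of the repository; a real custom layer would also need a `LayerReviver`, as the trait's scaladoc notes.

```scala
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.layer.Layer
import play.api.libs.json.{JsObject, Json}

// Hypothetical identity layer: no weights, no activation; forwards both input and error unchanged.
class IdentityLayer extends Layer {
  override val W: IndexedSeq[ScalarMatrix] = IndexedSeq.empty // nothing to update
  override protected val act: Activation = null               // passedBy will keep the raw output as dFdX

  override def apply(x: ScalarMatrix): ScalarMatrix = x

  // No delta entries are consumed; the propagated error is returned as-is.
  override def updateBy(delta: Iterator[ScalarMatrix], error: ScalarMatrix): ScalarMatrix = error

  // Restoring this layer from JSON would additionally require a LayerReviver ("reviver" field).
  override def toJSON: JsObject = Json.obj("type" → "IdentityLayer")
}
```
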
/src/main/scala/kr/ac/kaist/ir/deep/network/AutoEncoder.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.network
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.layer.Reconstructable
5 | import play.api.libs.json.Json
6 |
7 | /**
8 | * __Network__: Single-layer Autoencoder
9 | *
10 | * @param layer A __reconstructable__ layer for this network
11 | * @param presence the probability of non-dropped neurons (for drop-out training). `(default : 100% = 1.0)`
12 | */
13 | class AutoEncoder(val layer: Reconstructable,
14 | private val presence: Probability = 1.0f)
15 | extends Network {
16 | /**
17 | * All weights of layers
18 | *
19 | * @return all weights of layers
20 | */
21 | override val W: IndexedSeq[ScalarMatrix] = layer.W
22 |
23 | /**
24 | * Compute output of neural network with given input (without reconstruction)
25 | * If drop-out is used, to average drop-out effect, we need to multiply output by presence probability.
26 | *
27 | * @param in an input vector
28 | * @return output of the vector
29 | */
30 | override def apply(in: ScalarMatrix): ScalarMatrix = layer(in)
31 |
32 | /**
33 | * Serialize network to JSON
34 | *
35 | * @return JsObject of this network
36 | */
37 | override def toJSON = Json.obj(
38 | "type" → this.getClass.getSimpleName,
39 | "presence" → presence.safe,
40 | "layers" → Json.arr(layer.toJSON)
41 | )
42 |
43 | /**
44 | * Reconstruct the given hidden value
45 | *
46 | * @param x hidden value to be reconstructed.
47 | * @return reconstruction value.
48 | */
49 | def reconstruct(x: ScalarMatrix): ScalarMatrix = layer.decodeBy(x)
50 |
51 | /**
52 | * Backpropagation algorithm
53 | *
54 | * @param delta Sequence of delta amount of weight. The order must be the reverse of [[W]]
55 | * @param err backpropagated error from error function
56 | */
57 | override def updateBy(delta: Iterator[ScalarMatrix], err: ScalarMatrix): ScalarMatrix = {
58 | val e = decode_!(delta)(err)
59 | encode_!(delta)(e)
60 | }
61 |
62 | /**
63 | * Backpropagation algorithm for decoding phase
64 | *
65 | * @param err backpropagated error from error function
66 | */
67 | def decode_!(delta: Iterator[ScalarMatrix])(err: ScalarMatrix) = {
68 | layer decodeUpdateBy(delta, err)
69 | }
70 |
71 | /**
72 | * Backpropagation algorithm for encoding phase
73 | *
74 | * @param err backpropagated error from error function
75 | */
76 | def encode_!(delta: Iterator[ScalarMatrix])(err: ScalarMatrix) = {
77 | layer updateBy(delta, err)
78 | }
79 |
80 | /**
81 | * Forward computation for training.
82 | * If drop-out is used, we need to drop-out entry of input vector.
83 | *
84 | * @param x input matrix
85 | * @return output matrix
86 | */
87 | override def passedBy(x: ScalarMatrix): ScalarMatrix = decode(encode(x))
88 |
89 | /**
90 | * Encode computation for training.
91 | * If drop-out is used, we need to drop-out entry of input vector.
92 | *
93 | * @param x input matrix
94 | * @return hidden values
95 | */
96 | def encode(x: ScalarMatrix): ScalarMatrix = {
97 | layer.passedBy(x)
98 | }
99 |
100 | /**
101 | * Decode computation for training.
102 |    * If drop-out is used, we need to drop out entries of the input vector.
103 | *
104 | * @param x hidden values
105 | * @return output matrix
106 | */
107 | def decode(x: ScalarMatrix): ScalarMatrix = {
108 | layer.decodeBy(x)
109 | }
110 |
111 | /**
112 | * Sugar: Forward computation for validation. Calls apply(x)
113 | *
114 | * @param x input matrix
115 | * @return output matrix
116 | */
117 | override def of(x: ScalarMatrix): ScalarMatrix = {
118 | layer.decodeFrom(layer(x))
119 | }
120 | }
121 |
122 |
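
A minimal usage sketch of the class above (not part of the repository). The concrete `Reconstructable` layer construction is elided because its constructor is not shown in this listing, the 10×1 input shape is illustrative, and `ScalarMatrix.$0` is the zero-matrix helper used elsewhere in this listing (see MultiThreadTrainStyle):

{{{
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.layer.Reconstructable
import kr.ac.kaist.ir.deep.network.AutoEncoder

val layer: Reconstructable = ???              // any reconstructable layer, e.g. a 10 → 4 encoder
val ae = new AutoEncoder(layer)               // presence defaults to 1.0 (no drop-out averaging)

val x = ScalarMatrix.$0(10 → 1)               // placeholder 10x1 input column
val hidden = ae.encode(x)                     // training-path encoding
val recon  = ae.reconstruct(hidden)           // decode the hidden value back to input space
val output = ae of x                          // validation path: encode, then decodeFrom
}}}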
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/network/BasicNetwork.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.network
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.layer.Layer
5 | import play.api.libs.json.{JsArray, Json}
6 |
7 | /**
8 | * __Network__: A basic network implementation
9 | * @param layers __Sequence of layers__ of this network
10 | */
11 | class BasicNetwork(val layers: IndexedSeq[Layer])
12 | extends Network {
13 | /**
14 | * All weights of layers
15 | *
16 | * @return all weights of layers
17 | */
18 | override val W: IndexedSeq[ScalarMatrix] = layers flatMap (_.W)
19 |
20 | /**
21 | * Compute output of neural network with given input
22 |    * If drop-out is used, the output must be multiplied by the presence probability to average out the drop-out effect.
23 | *
24 | * @param in an input vector
25 | * @return output of the vector
26 | */
27 | override def apply(in: ScalarMatrix): ScalarMatrix = {
28 | layers.foldLeft(in) {
29 | case (v, l) ⇒ l apply v
30 | }
31 | }
32 |
33 | /**
34 | * Serialize network to JSON
35 | *
36 | * @return JsObject of this network
37 | */
38 | override def toJSON = Json.obj(
39 | "type" → this.getClass.getSimpleName,
40 | "layers" → JsArray(layers map (_.toJSON))
41 | )
42 |
43 | /**
44 | * Backpropagation algorithm
45 | *
46 |    * @param delta Sequence of weight deltas (ΔW). The order must be the reverse of [[W]]
47 | * @param err backpropagated error from error function
48 | */
49 | override def updateBy(delta: Iterator[ScalarMatrix], err: ScalarMatrix): ScalarMatrix = {
50 | layers.foldRight(err) {
51 | case (l, e) ⇒ l updateBy(delta, e)
52 | }
53 | }
54 |
55 | /**
56 | * Forward computation for training.
57 |    * If drop-out is used, we need to drop out entries of the input vector.
58 | *
59 | * @param x input matrix
60 | * @return output matrix
61 | */
62 | override def passedBy(x: ScalarMatrix): ScalarMatrix = {
63 | layers.foldLeft(x) {
64 | case (v, l) ⇒ l passedBy v
65 | }
66 | }
67 | }
68 |
69 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/network/StackedAutoEncoder.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.network
2 |
3 | import kr.ac.kaist.ir.deep.fn.ScalarMatrix
4 | import play.api.libs.json.{JsObject, Json}
5 |
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 | /**
9 | * __Network__: Stack of autoencoders.
10 | *
11 | * @param encoders __Sequence of AutoEncoders__ to be stacked.
12 | */
13 | class StackedAutoEncoder(val encoders: Seq[AutoEncoder]) extends Network {
14 | /**
15 | * All weights of layers
16 | *
17 | * @return all weights of layers
18 | */
19 | override val W: IndexedSeq[ScalarMatrix] = {
20 | val matrices = ArrayBuffer[ScalarMatrix]()
21 | encoders.flatMap(_.W).foreach(matrices += _)
22 | matrices
23 | }
24 |
25 | /**
26 | * Serialize network to JSON
27 | *
28 | * @return JsObject of this network
29 | */
30 | override def toJSON: JsObject =
31 | Json.obj(
32 | "type" → this.getClass.getSimpleName,
33 | "stack" → Json.arr(encoders map (_.toJSON))
34 | )
35 |
36 | /**
37 | * Compute output of neural network with given input (without reconstruction)
38 |    * If drop-out is used, the output must be multiplied by the presence probability to average out the drop-out effect.
39 | *
40 | * @param in an input vector
41 | * @return output of the vector
42 | */
43 | override def apply(in: ScalarMatrix): ScalarMatrix = {
44 | encoders.foldLeft(in) {
45 | case (v, l) ⇒ l apply v
46 | }
47 | }
48 |
49 | /**
50 | * Sugar: Forward computation for training. Calls apply(x)
51 | *
52 | * @param x input matrix
53 | * @return output matrix
54 | */
55 | override def passedBy(x: ScalarMatrix): ScalarMatrix = {
56 | encoders.foldLeft(x) {
57 | case (v, l) ⇒ l passedBy v
58 | }
59 | }
60 |
61 | /**
62 | * Backpropagation algorithm
63 | *
64 |    * @param delta Sequence of weight deltas (ΔW). The order must be the reverse of [[W]]
65 | * @param err backpropagated error from error function
66 | */
67 | override def updateBy(delta: Iterator[ScalarMatrix], err: ScalarMatrix): ScalarMatrix = {
68 | encoders.foldRight(err) {
69 | case (l, e) ⇒ l updateBy(delta, e)
70 | }
71 | }
72 | }
73 |
74 |
75 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/network/package.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep
2 |
3 | import java.io.Serializable
4 |
5 | import kr.ac.kaist.ir.deep.fn.{Activation, Probability, ScalarMatrix}
6 | import kr.ac.kaist.ir.deep.layer.{BasicLayer, Layer, Reconstructable}
7 | import play.api.libs.json.{JsArray, JsObject, JsValue, Json}
8 |
9 | import scala.collection.mutable.ArrayBuffer
10 | import scala.io.Codec
11 | import scala.reflect.io.{File, Path}
12 | import scala.reflect.runtime.universe
13 |
14 | /**
15 | * Package for network structure
16 | */
17 | package object network {
18 |
19 | /**
20 | * __Trait__: Network interface
21 | */
22 | trait Network extends (ScalarMatrix ⇒ ScalarMatrix) with Serializable {
23 | /**
24 | * All weights of layers
25 | *
26 | * @return all weights of layers
27 | */
28 | val W: IndexedSeq[ScalarMatrix]
29 |
30 | /**
31 | * Serialize network to JSON
32 |      * @note Please create a NetReviver object if you're using a custom network.
33 |      *       In that case, specify the NetReviver object's fully-qualified class name as "__reviver__,"
34 |      *       and implement the NetReviver.revive method.
35 | *
36 | * @return JsObject of this network
37 | */
38 | def toJSON: JsObject
39 |
40 | /**
41 | * Backpropagation algorithm
42 | *
43 |      * @param delta Sequence of weight deltas (ΔW). The order must be the reverse of [[W]]
44 | * @param err backpropagated error from error function
45 | */
46 | def updateBy(delta: Iterator[ScalarMatrix], err: ScalarMatrix): ScalarMatrix
47 |
48 | /**
49 | * Forward computation for training
50 | *
51 | * @param x input matrix
52 | * @return output matrix
53 | */
54 | def passedBy(x: ScalarMatrix): ScalarMatrix
55 |
56 | /**
57 | * Forward computation for training
58 | *
59 | * @param x input matrix
60 | * @return output matrix
61 | */
62 | @deprecated
63 | def into_:(x: ScalarMatrix): ScalarMatrix = passedBy(x)
64 |
65 | /**
66 | * Sugar: Forward computation for validation. Calls apply(x)
67 | *
68 | * @param x input matrix
69 | * @return output matrix
70 | */
71 | def of(x: ScalarMatrix): ScalarMatrix = apply(x)
72 |
73 | /**
74 | * Save given network into given file.
75 | * @param path Path to save this network.
76 | * @param codec Codec used for writer. `(Default: Codec.UTF8)`
77 | */
78 | def saveAsJsonFile(path: Path, codec: Codec = Codec.UTF8): Unit = {
79 | val writer = File(path).bufferedWriter(append = false, codec = codec)
80 | writer.write(Json.prettyPrint(this.toJSON))
81 | writer.close()
82 | }
83 | }
84 |
85 | /**
86 | * __Trait__ of Network Reviver (Companion) objects
87 | */
88 | trait NetReviver extends Serializable {
89 | /**
90 | * Revive network using given JSON value
91 | * @param obj JSON value to be revived
92 | * @return Revived network.
93 | */
94 | def revive(obj: JsValue): Network
95 | }
96 |
97 | /**
98 | * Companion object of BasicNetwork
99 | */
100 | object Network extends NetReviver {
101 | @transient lazy val runtimeMirror = universe.synchronized(universe.runtimeMirror(getClass.getClassLoader))
102 | /**
103 | * Construct network from given layer size information
104 | *
105 |      * @param act Activation function applied to each layer
106 |      * @param layerSizes Sizes used to construct the layers
107 | */
108 | def apply(act: Activation, layerSizes: Int*): Network = {
109 | val layers = ArrayBuffer[Layer]()
110 | layers ++= layerSizes.indices.tail.map {
111 | i ⇒ new BasicLayer(layerSizes(i - 1) → layerSizes(i), act)
112 | }
113 | new BasicNetwork(layers)
114 | }
115 |
116 | /**
117 | * Load network from given file.
118 |      * @param path Path to load the network from.
119 |      * @param codec Codec used for the reader. `(Default: Codec.UTF8)`
120 |      *
121 |      * @tparam T Type the network is cast into.
122 | */
123 | def jsonFile[T >: Network](path: Path, codec: Codec = Codec.UTF8): T = {
124 | val line = File(path).lines(codec).mkString("")
125 | val json = Json.parse(line)
126 | apply(json).asInstanceOf[T]
127 | }
128 |
129 | /**
130 | * Load network from JsObject
131 | *
132 | * @param obj JsObject to be parsed
133 | * @return New Network reconstructed from this object
134 | */
135 | def apply(obj: JsValue): Network = {
136 | val companion =
137 | universe.synchronized {
138 | (obj \ "reviver").asOpt[String] match {
139 | case Some(clsName) ⇒
140 | val module = runtimeMirror.staticModule(clsName)
141 | runtimeMirror.reflectModule(module).instance.asInstanceOf[NetReviver]
142 | case None ⇒
143 | this
144 | }
145 | }
146 | companion.revive(obj)
147 | }
148 |
149 | /**
150 | * Revive network using given JSON value
151 | * @param obj JSON value to be revived
152 | * @return Revived network.
153 | */
154 | override def revive(obj: JsValue): Network = {
155 | (obj \ "type").as[String] match {
156 | case "AutoEncoder" ⇒ AutoEncoder(obj)
157 | case "BasicNetwork" ⇒ BasicNetwork(obj)
158 | case "StackedAutoEncoder" ⇒ StackedAutoEncoder(obj)
159 | }
160 | }
161 |
162 | /**
163 | * Load network from JsObject
164 | *
165 | * @param obj JsObject to be parsed
166 | * @return New AutoEncoder reconstructed from this object
167 | */
168 | def AutoEncoder(obj: JsValue): AutoEncoder = {
169 | val layers = (obj \ "layers").as[JsArray].value map Layer.apply
170 | val presence = (obj \ "presence").as[Probability]
171 | new AutoEncoder(layers.head.asInstanceOf[Reconstructable], presence)
172 | }
173 |
174 | /**
175 | * Load network from JsObject
176 | *
177 | * @param obj JsObject to be parsed
178 | * @return New Basic Network reconstructed from this object
179 | */
180 | def BasicNetwork(obj: JsValue): BasicNetwork = {
181 | val layers = ArrayBuffer[Layer]()
182 | layers ++= (obj \ "layers").as[JsArray].value.map(Layer.apply)
183 | new BasicNetwork(layers)
184 | }
185 |
186 | /**
187 | * Load network from JsObject
188 | *
189 | * @param obj JsObject to be parsed
190 | * @return New Stacked AutoEncoder reconstructed from this object
191 | */
192 | def StackedAutoEncoder(obj: JsValue): StackedAutoEncoder = {
193 | val layers = (obj \ "stack").as[Seq[JsObject]] map Network.AutoEncoder
194 | new StackedAutoEncoder(layers)
195 | }
196 |
197 | }
198 |
199 | }
200 |
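
A short sketch (not from the repository) of the factory, save, and revive round-trip defined above; the file name is illustrative:

{{{
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.network._
import scala.reflect.io.Path

// Build a 2 -> 4 -> 1 fully connected network with the companion factory.
val net: Network = Network(Sigmoid, 2, 4, 1)

// Persist as JSON, then revive it through the NetReviver mechanism.
val path: Path = Path("net.json")
net.saveAsJsonFile(path)
val revived = Network.jsonFile[Network](path)
}}}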
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/package.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir
2 |
3 | /**
4 | * A ''Neural Network implementation'' with Scala, [[https://github.com/scalanlp/breeze Breeze]] & [[http://spark.apache.org Spark]]
5 | *
6 | * @example
7 | * {{{// Define 2 -> 4 -> 1 Layered, Fully connected network.
8 | * val net = Network(Sigmoid, 2, 4, 1)
9 | *
10 | * // Define Manipulation Type. VectorType, AEType, RAEType, StandardRAEType, URAEType, and StringToVectorType.
11 | * val operation = new VectorType(
12 | * corrupt = GaussianCorruption(variance = 0.1)
13 | * )
14 | *
15 | * // Define Training Style. SingleThreadTrainStyle vs DistBeliefTrainStyle
16 | * val style = new SingleThreadTrainStyle(
17 | * net = net,
18 | * algorithm = new StochasticGradientDescent(l2decay = 0.0001),
19 | * make = operation,
20 | * param = SimpleTrainingCriteria(miniBatchFraction = 0.01))
21 | *
22 | * // Define Trainer
23 | * val train = new Trainer(
24 | * style = style,
25 | * stops = StoppingCriteria(maxIter = 100000))
26 | *
27 | * // Do Train
28 | * train.train(set, valid)}}}
29 | */
30 | package object deep
31 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/rec/BinaryTree.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.rec
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.train.Corruption
5 |
6 | /**
7 | * __Node__ for internal structure (non-terminal)
8 | */
9 | class BinaryTree(val left: Node, right: Node) extends Node {
10 |
11 | /**
12 | * Forward computation of Binary Tree
13 | *
14 | * @param fn function to be applied
15 | * @return the result
16 | */
17 | override def forward(fn: ScalarMatrix ⇒ ScalarMatrix): ScalarMatrix = {
18 | val leftMatx = left.forward(fn)
19 | val rightMatx = right.forward(fn)
20 | fn(leftMatx row_+ rightMatx)
21 | }
22 |
23 | /**
24 | * Backward computation of Binary Tree
25 | *
26 | * @param err Matrix to be propagated
27 | * @param fn function to be applied
28 | * @return Sequence of terminal nodes
29 | */
30 | def backward(err: ScalarMatrix, fn: ScalarMatrix ⇒ ScalarMatrix): Seq[Leaf] = {
31 | val error = fn(err)
32 | val rSize = error.rows / 2
33 |
34 | val seqLeft = left.backward(error(0 until rSize, ::), fn)
35 | val seqRight = right.backward(error(rSize to -1, ::), fn)
36 | seqLeft ++ seqRight
37 | }
38 |
39 | /**
40 | * Corrupt this node
41 |    *
42 | * @param corrupt Corruption function to be applied
43 | * @return Corrupted Binary Tree
44 | */
45 | override def through(corrupt: Corruption): Node =
46 | new BinaryTree(left through corrupt, right through corrupt)
47 |
48 | /**
49 | * Replace wildcard node
50 | * @param resolve Wildcard Resolver function
51 | * @return new Node without wildcard
52 | */
53 | override def ?(resolve: (Int) ⇒ Node): Node = {
54 | val newLeft = left ? resolve
55 | val newRight = right ? resolve
56 |
57 | if (left.equals(newLeft) && right.equals(newRight))
58 | this
59 | else
60 | new BinaryTree(newLeft, newRight)
61 | }
62 | }
63 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/rec/Leaf.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.rec
2 |
3 | import kr.ac.kaist.ir.deep.fn.ScalarMatrix
4 | import kr.ac.kaist.ir.deep.train.Corruption
5 |
6 | /**
7 | * __Node of BinaryTree__ whose position is terminal.
8 | *
9 | * This node does not do any computation.
10 | *
11 | * @param x original value matrix
12 | */
13 | class Leaf(val x: ScalarMatrix) extends Node {
14 | var out: ScalarMatrix = x
15 |
16 | /**
17 | * Forward computation of Binary Tree
18 | *
19 | * @param fn function to be applied
20 | * @return the result
21 | */
22 | override def forward(fn: ScalarMatrix ⇒ ScalarMatrix): ScalarMatrix = out
23 |
24 | /**
25 | * Backward computation of Binary Tree
26 | *
27 | * @param err Matrix to be propagated
28 | * @param fn function to be applied
29 | * @return Sequence of terminal nodes
30 | */
31 | def backward(err: ScalarMatrix, fn: ScalarMatrix ⇒ ScalarMatrix): Seq[Leaf] = {
32 | out = err
33 | Seq(this)
34 | }
35 |
36 | /**
37 | * Corrupt this node
38 |    *
39 | * @param corrupt Corruption function to be applied
40 | * @return Corrupted Binary Tree
41 | */
42 | override def through(corrupt: Corruption): Node =
43 | new Leaf(corrupt(x))
44 |
45 | /**
46 | * Replace wildcard node
47 | * @param resolve Wildcard Resolver function
48 | * @return new Node without wildcard
49 | */
50 | override def ?(resolve: (Int) ⇒ Node): Node = this
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/rec/Node.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.rec
2 |
3 | import kr.ac.kaist.ir.deep.fn.ScalarMatrix
4 | import kr.ac.kaist.ir.deep.train.Corruption
5 |
6 | /**
7 | * __Trait__ that describes a node in BinaryTree.
8 | */
9 | trait Node extends Serializable {
10 | /**
11 | * Forward computation of Binary Tree
12 | *
13 | * @param fn function to be applied
14 | * @return the result
15 | */
16 | def forward(fn: ScalarMatrix ⇒ ScalarMatrix): ScalarMatrix
17 |
18 | /**
19 | * Backward computation of Binary Tree
20 | *
21 | * @param err Matrix to be propagated
22 | * @param fn function to be applied
23 | * @return Sequence of terminal nodes
24 | */
25 | def backward(err: ScalarMatrix, fn: ScalarMatrix ⇒ ScalarMatrix): Seq[Leaf]
26 |
27 | /**
28 | * Corrupt this node
29 |    *
30 | * @param corrupt Corruption function to be applied
31 | * @return Corrupted Binary Tree
32 | */
33 | def through(corrupt: Corruption): Node
34 |
35 | /**
36 | * Replace wildcard node
37 | * @param resolve Wildcard Resolver function
38 | * @return new Node without wildcard
39 | */
40 | def ?(resolve: Int ⇒ Node): Node
41 | }
42 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/rec/WildcardLeaf.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.rec
2 |
3 | import kr.ac.kaist.ir.deep.fn.ScalarMatrix
4 | import kr.ac.kaist.ir.deep.train.Corruption
5 |
6 | /**
7 |  * __Wildcard node of BinaryTree__ whose position is terminal.
8 |  *
9 |  * This node does not do any computation; it is replaced by an actual node via the `?` resolver.
10 | *
11 | * @param id ID of wildcard entry
12 | */
13 | class WildcardLeaf(val id: Int) extends Node {
14 | /**
15 | * Forward computation of Binary Tree
16 | *
17 | * @param fn function to be applied
18 | * @return the result
19 | */
20 | override def forward(fn: ScalarMatrix ⇒ ScalarMatrix): ScalarMatrix = null
21 |
22 | /**
23 | * Backward computation of Binary Tree
24 | *
25 | * @param err Matrix to be propagated
26 | * @param fn function to be applied
27 | * @return Sequence of terminal nodes
28 | */
29 | def backward(err: ScalarMatrix, fn: ScalarMatrix ⇒ ScalarMatrix): Seq[Leaf] = Seq()
30 |
31 | /**
32 | * Corrupt this node
33 |    *
34 | * @param corrupt Corruption function to be applied
35 | * @return Corrupted Binary Tree
36 | */
37 | override def through(corrupt: Corruption): Node = this
38 |
39 | /**
40 | * Replace wildcard node
41 | * @param resolve Wildcard Resolver function
42 | * @return new Node without wildcard
43 | */
44 | override def ?(resolve: (Int) ⇒ Node): Node = resolve(id)
45 | }
46 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/rec/package.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep
2 |
3 | /**
4 | * Package object for DAG
5 | */
6 | package object rec
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
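
A small sketch (not from the repository) of how these nodes compose. The 2×1 leaf shape is a placeholder, and `net` is an assumed function mapping the 4×1 row-concatenation back to a 2×1 vector:

{{{
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.rec.{BinaryTree, Leaf}

val left  = new Leaf(ScalarMatrix.$0(2 → 1))   // placeholder 2x1 leaves
val right = new Leaf(ScalarMatrix.$0(2 → 1))
val tree  = new BinaryTree(left, right)

// forward concatenates the children row-wise (row_+) and applies fn at each internal node.
val root = tree.forward(x ⇒ net(x))            // `net: ScalarMatrix ⇒ ScalarMatrix` is assumed

// backward splits the propagated error row-wise and pushes the halves down to the leaves.
val leaves = tree.backward(ScalarMatrix.$0(4 → 1), e ⇒ e)   // identity fn, for illustration only
}}}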
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/AEType.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.network.Network
5 |
6 | /**
7 | * __Input Operation__ : Vector as Input & Auto Encoder Training (no output type)
8 | *
9 | * @param corrupt Corruption that supervises how to corrupt the input matrix. (Default : [[NoCorruption]])
10 | * @param error An objective function (Default: [[kr.ac.kaist.ir.deep.fn.SquaredErr]])
11 | *
12 | * @example
13 | * {{{var make = new AEType(error = CrossEntropyErr)
14 | * var corruptedIn = make corrupted in
15 | * var out = make onewayTrip (net, corruptedIn)}}}
16 | */
17 | class AEType(override val corrupt: Corruption = NoCorruption,
18 | override val error: Objective = SquaredErr)
19 | extends ManipulationType[ScalarMatrix, Null] {
20 |
21 | /**
22 | * Corrupt input
23 | *
24 | * @param x input to be corrupted
25 | * @return corrupted input
26 | */
27 | override def corrupted(x: ScalarMatrix): ScalarMatrix = corrupt(x)
28 |
29 | /**
30 | * Apply & Back-prop given single input
31 | *
32 | * @param net A network that gets input
33 | * @param delta Sequence of delta updates
34 | */
35 | def roundTrip(net: Network, delta: Seq[ScalarMatrix]) = (in: ScalarMatrix, real: Null) ⇒ {
36 | val out = net passedBy in
37 | val err: ScalarMatrix = error.derivative(in, out)
38 | net updateBy(delta.toIterator, err)
39 | }
40 |
41 | /**
42 | * Apply given input and compute the error
43 | *
44 | * @param net A network that gets input
45 | * @param pair (Input, Real output) for error computation.
46 | * @return error of this network
47 | */
48 | def lossOf(net: Network)(pair: (ScalarMatrix, Null)): Scalar = {
49 | val in = pair._1
50 | val out = net of in
51 | error(in, out)
52 | }
53 |
54 | /**
55 | * Apply given single input as one-way forward trip.
56 | *
57 | * @param net A network that gets input
58 | * @param x input to be computed
59 | * @return output of the network.
60 | */
61 | override def onewayTrip(net: Network, x: ScalarMatrix): ScalarMatrix = net of x
62 |
63 |
64 | /**
65 | * Make validation output
66 | *
67 | * @return input as string
68 | */
69 | def stringOf(net: Network, pair: (ScalarMatrix, Null)): String = {
70 | val in = pair._1
71 | val out = net of in
72 | s"IN: ${in.mkString} RECON → OUT: ${out.mkString}"
73 | }
74 | }
75 |
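
A hedged sketch (not from the repository) of wiring AEType into the single-threaded style, mirroring the example in the `deep` package object; `ae` and `set` are placeholders (an AutoEncoder network and a `Seq[(ScalarMatrix, Null)]`, respectively):

{{{
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.train._

val style = new SingleThreadTrainStyle(
  net = ae,                                             // an AutoEncoder (any Network works)
  algorithm = new StochasticGradientDescent(l2decay = 0.0001),
  make = new AEType(error = SquaredErr),
  param = SimpleTrainingCriteria(miniBatchFraction = 0.01))

val trainer = new Trainer(style, stops = StoppingCriteria(maxIter = 100000))
trainer.train(set)                                      // the OUT slot is Null and unused by AEType
}}}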
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/DistBeliefTrainStyle.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.network._
5 | import org.apache.spark.SparkContext
6 | import org.apache.spark.broadcast.Broadcast
7 |
8 | import scala.collection.mutable.ArrayBuffer
9 | import scala.concurrent.ExecutionContext.Implicits.global
10 | import scala.concurrent._
11 | import scala.reflect._
12 |
13 | /**
14 | * __Train Style__ : Semi-DistBelief Style, Spark-based.
15 | *
16 |  * @note Unlike DistBelief, this trainer performs updates and fetches on the '''master''', not the '''workers'''.
17 | *
18 | * @param net __Network__ to be trained
19 | * @param algorithm Weight __update algorithm__ to be applied
20 |  * @param sc A __Spark context__ over which the network will be distributed
21 | * @param make __Input Operation__ that supervises how to manipulate input as matrices.
22 |  *             This also controls how the actual network computation is performed. (default: [[VectorType]])
23 | * @param param __DistBelief-style__ Training criteria (default: [[DistBeliefCriteria]])
24 | */
25 | class DistBeliefTrainStyle[IN: ClassTag, OUT: ClassTag](net: Network,
26 | algorithm: WeightUpdater,
27 | @transient sc: SparkContext,
28 | make: ManipulationType[IN, OUT] = new VectorType(),
29 | param: DistBeliefCriteria = DistBeliefCriteria())
30 | extends MultiThreadTrainStyle[IN, OUT](net, algorithm, sc, make, param) {
31 | /** Flag for batch : Is Batch remaining? */
32 | @transient protected var batchFlag = ArrayBuffer[Future[Unit]]()
33 | /** Flag for fetch : Is fetching? */
34 | @transient protected var fetchFlag: Future[Unit] = null
35 | /** Flag for update : Is updating? */
36 | @transient protected var updateFlag: Future[Unit] = null
37 | /** Spark distributed networks */
38 | protected var bcNet: Broadcast[Network] = _
39 |
40 | /**
41 | * Fetch weights
42 | *
43 | * @param iter current iteration
44 | */
45 | override def fetch(iter: Int): Unit =
46 | if (iter % param.fetchStep == 0) {
47 | if (fetchFlag != null && !fetchFlag.isCompleted) {
48 | logger warn "Fetch command arrived before previous fetch is done. Need more steps between fetch commands!"
49 | }
50 |
51 | fetchFlag =
52 | future {
53 | val oldNet = bcNet
54 | bcNet = sc.broadcast(net)
55 |
56 |           // Because DistBelief submits the fetching job after n_fetch steps,
57 |           // submit this fetch after the already-submitted jobs are done.
58 | // This does not block others because batch can be submitted anyway,
59 | // and that batch does not affect this thread.
60 | stopUntilBatchFinished()
61 |
62 | future {
63 | Thread.sleep(param.submitInterval.toMillis * param.fetchStep)
64 | oldNet.destroy()
65 | }
66 | }
67 | }
68 |
69 | /**
70 | * Non-blocking pending, until all assigned batches are finished
71 | */
72 | override def stopUntilBatchFinished(): Unit = {
73 | AsyncAwait.readyAll(param.submitInterval, batchFlag: _*)
74 | batchFlag = batchFlag.filterNot(_.isCompleted)
75 | }
76 |
77 | /**
78 | * Send update of weights
79 | *
80 | * @param iter current iteration
81 | */
82 | override def update(iter: Int): Unit =
83 | if (iter % param.updateStep == 0) {
84 | if (updateFlag != null && !updateFlag.isCompleted) {
85 | logger warn "Update command arrived before previous update is done. Need more steps between update commands!"
86 | }
87 |
88 | updateFlag =
89 | future {
90 |           // Because DistBelief submits the updating job after n_update steps,
91 |           // submit this update after the already-submitted jobs are done.
92 | // This does not block others because batch can be submitted anyway,
93 | // and that batch does not affect this thread.
94 | stopUntilBatchFinished()
95 |
96 | val dWUpdate = accNet.value.reverse
97 | accNet.setValue(WeightAccumulator.zero(accNet.zero))
98 | val count = accCount.value
99 | accCount.setValue(0)
100 |
101 | dWUpdate :/= count.toFloat
102 | net.W -= dWUpdate
103 | }
104 | }
105 |
106 | /**
107 |    * Indicates whether the asynchronous update is finished or not.
108 | *
109 | * @return future object of update
110 | */
111 | override def isUpdateFinished: Future[_] = updateFlag
112 |
113 | /**
114 | * Do mini-batch
115 | */
116 | override def batch(): Unit = {
117 | val part = partFunction(bcNet)
118 | val x = if (param.miniBatchFraction > 0) {
119 | val rddSet = trainingSet.sample(withReplacement = true, fraction = param.miniBatchFraction)
120 | .repartition(param.numCores)
121 |
122 | val x = rddSet foreachPartitionAsync part
123 | batchFlag += x
124 |
125 | x.onComplete {
126 | _ ⇒ rddSet.unpersist()
127 | }
128 | x
129 | } else {
130 | val x = trainingSet foreachPartitionAsync part
131 | batchFlag += x
132 |
133 | x
134 | }
135 |
136 | try {
137 | Await.ready(x, param.submitInterval)
138 | } catch {
139 | case _: Throwable ⇒
140 | }
141 | }
142 | }
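
A sketch (not from the repository) of constructing this style; it mirrors the SingleThreadTrainStyle example in the `deep` package object, with `net` and `sc` assumed to be an existing Network and SparkContext:

{{{
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.train._
import org.apache.spark.SparkContext

val sc: SparkContext = ???                       // an already-configured Spark context
val style = new DistBeliefTrainStyle(
  net = net,                                     // the Network to distribute
  algorithm = new StochasticGradientDescent(l2decay = 0.0001),
  sc = sc,
  make = new VectorType(),
  param = DistBeliefCriteria())                  // default DistBelief-style criteria
}}}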
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/ManipulationType.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import kr.ac.kaist.ir.deep.fn.{Objective, Scalar, ScalarMatrix}
4 | import kr.ac.kaist.ir.deep.network.Network
5 |
6 | /**
7 | * __Trait__ that describes how to convert input into corrupted matrix
8 | *
9 |  * An input operation corrupts the given input and applies network propagation to the matrix representation of the input.
10 | *
11 | * @tparam IN the type of input
12 | * @tparam OUT the type of output
13 | */
14 | trait ManipulationType[IN, OUT] extends Serializable {
15 | /** Corruption function */
16 | val corrupt: Corruption
17 | /** Objective function */
18 | val error: Objective
19 |
20 |   // We do not hold a "network" field here because of the DistBelief training style: the network is supplied to each method call instead.
21 |
22 | /**
23 | * Corrupt input
24 | *
25 | * @param x input to be corrupted
26 | * @return corrupted input
27 | */
28 | def corrupted(x: IN): IN
29 |
30 | /**
31 | * Apply & Back-prop given single input
32 | *
33 | * @param net A network that gets input
34 | * @param delta Sequence of delta updates
35 | */
36 | def roundTrip(net: Network, delta: Seq[ScalarMatrix]): (IN, OUT) ⇒ Unit
37 |
38 | /**
39 | * Apply given single input as one-way forward trip.
40 | *
41 | * @param net A network that gets input
42 | * @param x input to be computed
43 | * @return output of the network.
44 | */
45 | def onewayTrip(net: Network, x: IN): ScalarMatrix
46 |
47 | /**
48 | * Make validation output
49 | *
50 | * @param net A network that gets input
51 | * @param in (Input, Real output) pair for computation
52 | * @return input as string
53 | */
54 | def stringOf(net: Network, in: (IN, OUT)): String
55 |
56 | /**
57 | * Apply given input and compute the error
58 | *
59 | * @param net A network that gets input
60 | * @param pair (Input, Real output) for error computation.
61 | * @return error of this network
62 | */
63 | def lossOf(net: Network)(pair: (IN, OUT)): Scalar
64 |
65 | /**
66 |    * Check whether the two given outputs differ.
67 | * @param x Out-type object
68 | * @param y Out-type object
69 | * @return True if they are different.
70 | */
71 | def different(x: OUT, y: OUT): Boolean = true
72 | }
73 |
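
A condensed sketch (not from the repository) of how a TrainStyle consumes this trait; compare SingleThreadTrainStyle.batch and validationError later in this listing. `net`, `make`, and `pairs` are placeholders:

{{{
// 1. Allocate a delta buffer shaped like the reversed weight list.
val dW = WeightAccumulator.zero(net.W).reverse

// 2. Obtain a round-trip closure once, then feed every (input, output) pair through it.
val trip = make.roundTrip(net, dW)
pairs.foreach { case (x, y) ⇒ trip(x, y) }

// 3. Validation reuses the same object through lossOf.
val meanLoss = pairs.map(make.lossOf(net)).sum / pairs.size
}}}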
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/MultiThreadTrainStyle.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import kr.ac.kaist.ir.deep.fn.{ScalarMatrix, WeightSeqOp, WeightUpdater}
4 | import kr.ac.kaist.ir.deep.network.Network
5 | import org.apache.spark.SparkContext
6 | import org.apache.spark.broadcast.Broadcast
7 | import org.apache.spark.rdd.RDD
8 |
9 | import scala.concurrent.ExecutionContext.Implicits.global
10 | import scala.concurrent._
11 | import scala.concurrent.duration._
12 | import scala.reflect.ClassTag
13 |
14 | /**
15 | * __Trainer__ : Stochastic-Style, Multi-Threaded using Spark.
16 | *
17 |  * @note This is not an implementation of the DistBelief paper.
18 |  *       This sits between [[DistBeliefTrainStyle]] (DBTS) and [[SingleThreadTrainStyle]] (STTS).
19 |  *       The major difference is whether "updating" is asynchronous (DBTS) or not (MTTS).
20 | *
21 | * @param net __Network__ to be trained
22 | * @param algorithm Weight __update algorithm__ to be applied
23 | * @param make __Input Operation__ that supervises how to manipulate input as matrices.
24 |  *             This also controls how the actual network computation is performed. (default: [[VectorType]])
25 | * @param param __Training criteria__ (default: [[SimpleTrainingCriteria]])
26 | */
27 | class MultiThreadTrainStyle[IN: ClassTag, OUT: ClassTag](override val net: Network,
28 | override val algorithm: WeightUpdater,
29 | @transient val sc: SparkContext,
30 | override val make: ManipulationType[IN, OUT] = new VectorType(),
31 | override val param: DistBeliefCriteria = DistBeliefCriteria())
32 | extends TrainStyle[IN, OUT] {
33 | /** Accumulator variable for networks */
34 | protected val accNet = sc.accumulator(WeightAccumulator.zero(net.W).reverse)(WeightAccumulator)
35 | protected val weightSizes = sc.broadcast(net.W.map(m ⇒ m.rows → m.cols).reverse)
36 | /** Accumulator variable for counter */
37 | protected val accCount = sc.accumulator(0)
38 | /** Training set */
39 | protected var trainingSet: RDD[Pair] = null
40 | /** Test Set */
41 | protected var testSet: RDD[Pair] = null
42 |
43 | /**
44 | * Unpersist all
45 | */
46 | def unpersist(blocking: Boolean = false): Unit = {
47 | if (trainingSet != null)
48 | trainingSet.unpersist(blocking = blocking)
49 | if (testSet != null)
50 | testSet.unpersist(blocking = blocking)
51 | weightSizes.unpersist(blocking = false)
52 | }
53 |
54 | /**
55 | * Fetch weights
56 | *
57 | * @param iter current iteration
58 | */
59 | override def fetch(iter: Int): Unit = {
60 | accNet.value.par.map(_ := 0f)
61 | accCount.setValue(0)
62 | }
63 |
64 | /**
65 | * Send update of weights
66 | *
67 | * @param iter current iteration
68 | */
69 | override def update(iter: Int): Unit = {
70 | val dWUpdate = accNet.value.reverse
71 | val cnt = accCount.value.toFloat
72 | if (cnt > 0) {
73 | dWUpdate :/= cnt
74 | net.W -= dWUpdate
75 | } else {
76 | logger.warn(s"Epoch $iter trained with 0 instances. Please check.")
77 | }
78 | }
79 |
80 | /**
81 | * Do mini-batch
82 | */
83 | override def batch(): Unit = {
84 | val bcNet = sc.broadcast(net)
85 | val part = partFunction(bcNet)
86 | if (param.miniBatchFraction > 0) {
87 | val set = trainingSet.sample(withReplacement = true, fraction = param.miniBatchFraction)
88 | set.foreachPartition(part)
89 | set.unpersist(blocking = false)
90 | } else {
91 | trainingSet.foreachPartition(part)
92 | }
93 | bcNet.unpersist(blocking = false)
94 | }
95 |
96 | protected final def partFunction(net: Broadcast[Network]) = {
97 |
98 | (part: Iterator[(IN, OUT)]) ⇒ {
99 | var count = 0
100 | val f = future {
101 | lazy val dW = weightSizes.value.map(ScalarMatrix.$0)
102 | lazy val trip = make.roundTrip(net.value, dW)
103 |
104 | part.foreach {
105 | case (x, y) ⇒
106 | count += 1
107 | trip(x, y)
108 | }
109 |
110 | accCount += count
111 | accNet += dW
112 | }
113 |
114 | AsyncAwait.ready(f, 1.second)
115 | }
116 | }
117 |
118 | /**
119 | * Set training instances
120 | * @param set Sequence of training set
121 | */
122 | override def setPositiveTrainingReference(set: Seq[(IN, OUT)]): Unit = {
123 | val rdd =
124 | if (param.repartitionOnStart) sc.parallelize(set, param.numCores)
125 | else sc.parallelize(set)
126 | trainingSet = rdd.setName("Positives").persist(param.storageLevel)
127 | validationEpoch = if (param.miniBatchFraction > 0) Math.round(1.0f / param.miniBatchFraction) else 1
128 | }
129 |
130 | /**
131 | * Set training instances
132 | * @param set RDD of training set
133 | */
134 | override def setPositiveTrainingReference(set: RDD[(IN, OUT)]): Unit = {
135 | val rdd =
136 | if (param.repartitionOnStart) set.repartition(param.numCores).persist(param.storageLevel)
137 | else set
138 | trainingSet = rdd.setName(set.name + " (Positives)")
139 | validationEpoch = if (param.miniBatchFraction > 0) Math.round(1.0f / param.miniBatchFraction) else 1
140 | }
141 |
142 | /**
143 | * Set testing instances
144 | * @param set Sequence of testing set
145 | */
146 | override def setTestReference(set: Seq[(IN, OUT)]): Unit = {
147 | val rdd =
148 | if (param.repartitionOnStart) sc.parallelize(set, param.numCores)
149 | else sc.parallelize(set)
150 | testSet = rdd.setName("Validation").persist(param.storageLevel)
151 | }
152 |
153 | /**
154 | * Set testing instances
155 | * @param set RDD of testing set
156 | */
157 | override def setTestReference(set: RDD[(IN, OUT)]): Unit = {
158 | val rdd =
159 | if (param.repartitionOnStart) set.repartition(param.numCores).persist(param.storageLevel)
160 | else set
161 | testSet = rdd.setName(set.name + " (Validation)")
162 | }
163 |
164 | /**
165 | * Iterate over given number of test instances
166 | * @param n number of random sampled instances
167 | * @param fn iteratee function
168 | */
169 | override def foreachTestSet(n: Int)(fn: ((IN, OUT)) ⇒ Unit): Unit = {
170 | var seq = testSet.takeSample(withReplacement = true, num = n)
171 | while (seq.nonEmpty) {
172 | fn(seq.head)
173 | seq = seq.tail
174 | }
175 | }
176 |
177 | /**
178 | * Calculate validation error
179 | *
180 | * @return validation error
181 | */
182 | def validationError() = {
183 | val loss = sc.accumulator(0.0f)
184 | val count = sc.accumulator(0)
185 | val lossOf = make.lossOf(net) _
186 | testSet.foreachPartition {
187 | iter ⇒
188 | val f = future {
189 | var sum = 0.0f
190 | var c = 0
191 | while (iter.hasNext) {
192 | sum += lossOf(iter.next())
193 | c += 1
194 | }
195 | loss += sum
196 | count += c
197 | }
198 |
199 | AsyncAwait.ready(f, 1.second)
200 | }
201 |
202 | loss.value / count.value.toFloat
203 | }
204 | }
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/RAEType.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.network.Network
5 | import kr.ac.kaist.ir.deep.rec.BinaryTree
6 |
7 | /**
8 | * __Input Operation__ : VectorTree as Input & Recursive Auto-Encoder Training (no output type)
9 | *
10 |  * @note We recommend not applying this method to non-AutoEncoder tasks
11 |  * @note This implementation is designed as a replica of the traditional RAE in
12 | * [[http://ai.stanford.edu/~ang/papers/emnlp11-RecursiveAutoencodersSentimentDistributions.pdf this paper]]
13 | *
14 | * @param corrupt Corruption that supervises how to corrupt the input matrix. `(Default : [[kr.ac.kaist.ir.deep.train.NoCorruption]])`
15 | * @param error An objective function `(Default: [[kr.ac.kaist.ir.deep.fn.SquaredErr]])`
16 | *
17 | * @example
18 | * {{{var make = new RAEType(error = CrossEntropyErr)
19 | * var corruptedIn = make corrupted in
20 | * var out = make onewayTrip (net, corruptedIn)}}}
21 | */
22 | class RAEType(override val corrupt: Corruption = NoCorruption,
23 | override val error: Objective = SquaredErr)
24 | extends TreeType {
25 |
26 | /**
27 | * Apply & Back-prop given single input
28 | *
29 | * @param net A network that gets input
30 | * @param delta Sequence of delta updates
31 | */
32 | def roundTrip(net: Network, delta: Seq[ScalarMatrix]) = (in: BinaryTree, real: Null) ⇒ {
33 | in forward {
34 | x ⇒
35 | val err = error.derivative(x, net passedBy x)
36 | net updateBy(delta.toIterator, err)
37 | // propagate hidden-layer value
38 | net(x)
39 | }
40 | }
41 |
42 | /**
43 | * Apply given input and compute the error
44 | *
45 | * @param net A network that gets input
46 | * @param pair (Input, Real output) for error computation.
47 | * @return error of this network
48 | */
49 | def lossOf(net: Network)(pair: (BinaryTree, Null)): Scalar = {
50 | var sum = 0.0f
51 | val in = pair._1
52 | in forward {
53 | x ⇒
54 | sum += error(x, net of x)
55 | //propagate hidden-layer value
56 | net(x)
57 | }
58 | sum
59 | }
60 |
61 | /**
62 | * Make validation output
63 | *
64 | * @return input as string
65 | */
66 | def stringOf(net: Network, pair: (BinaryTree, Null)): String = {
67 | val string = StringBuilder.newBuilder
68 | pair._1 forward {
69 | x ⇒
70 | val out = net of x
71 | val hid = net(x)
72 | string append s"IN: ${x.mkString} RAE → OUT: ${out.mkString}, HDN: ${hid.mkString}; "
73 | // propagate hidden-layer value
74 | hid
75 | }
76 | string.mkString
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/RandomEqualPartitioner.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import org.apache.spark.Partitioner
4 |
5 | /**
6 | * Spark Partitioner that gives almost-equal partitions.
7 | *
8 | * @note Use this with RDD.zipWithUniqueId()
9 | *
10 | * @param numPartition Number of partitions
11 | */
12 | class RandomEqualPartitioner(val numPartition: Int) extends Partitioner {
13 | private var nextNumber = 0
14 |
15 | def refreshRandom() = {
16 | nextNumber += 1
17 | }
18 |
19 | override def numPartitions: Int = numPartition
20 |
21 | override def getPartition(key: Any): Int = {
22 | val i = key.asInstanceOf[Long] + nextNumber
23 | val remain = i % numPartition
24 | val quotient = ((i / numPartition) * nextNumber) % numPartition
25 | val hash = ((remain + quotient) % numPartition).asInstanceOf[Int]
26 | if (hash < 0)
27 | hash + numPartition
28 | else
29 | hash
30 | }
31 | }
32 |
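
A usage sketch (not from the repository), following the @note above; `data` is a placeholder RDD and the partition count is illustrative:

{{{
val partitioner = new RandomEqualPartitioner(numPartition = 8)

// Keys must be the Long ids produced by zipWithUniqueId(), since getPartition casts the key to Long.
val balanced = data.zipWithUniqueId()      // RDD[(T, Long)]
  .map(_.swap)                             // RDD[(Long, T)] so the id becomes the key
  .partitionBy(partitioner)
  .values

partitioner.refreshRandom()                // change the offset before the next repartitioning
}}}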
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/SingleThreadTrainStyle.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import java.util.concurrent.ThreadLocalRandom
4 |
5 | import kr.ac.kaist.ir.deep.fn.{Scalar, WeightSeqOp, WeightUpdater}
6 | import kr.ac.kaist.ir.deep.network.Network
7 | import org.apache.spark.SparkContext
8 | import org.apache.spark.rdd.RDD
9 |
10 | /**
11 | * __Trainer__ : Stochastic-Style, Single-Threaded
12 | *
13 | * @param net __Network__ to be trained
14 | * @param algorithm Weight __update algorithm__ to be applied
15 | * @param make __Input Operation__ that supervises how to manipulate input as matrices.
16 |  *             This also controls how the actual network computation is performed. (default: [[VectorType]])
17 | * @param param __Training criteria__ (default: [[SimpleTrainingCriteria]])
18 | */
19 | class SingleThreadTrainStyle[IN, OUT](override val net: Network,
20 | override val algorithm: WeightUpdater,
21 | override val make: ManipulationType[IN, OUT] = new VectorType(),
22 | override val param: TrainingCriteria = SimpleTrainingCriteria())
23 | extends TrainStyle[IN, OUT] {
24 |
25 | /** dWeight */
26 | private val dW = WeightAccumulator.zero(net.W).reverse
27 | /** Training set */
28 | private var trainingSet: Scalar ⇒ Seq[Pair] = null
29 | /** Test Set */
30 | private var testSet: Int ⇒ Seq[Pair] = null
31 | /** Test Set iterator */
32 | private var testSetMapper: (Pair ⇒ Unit) ⇒ Unit = null
33 | /** Test Set Context. Null if testset is a local seq */
34 | private var testSetSC: SparkContext = null
35 | /** Count */
36 | private var count = 0
37 |
38 | /**
39 | * Fetch weights
40 | *
41 | * @param iter current iteration
42 | */
43 | override def fetch(iter: Int): Unit = {}
44 |
45 | /**
46 | * Send update of weights
47 | *
48 | * @param iter current iteration
49 | */
50 | override def update(iter: Int): Unit = {
51 | dW :/= count.toFloat
52 | net.W -= dW.reverse
53 | count = 0
54 | }
55 |
56 | /**
57 | * Do mini-batch
58 | */
59 | override def batch(): Unit = {
60 | val seq = trainingSet(param.miniBatchFraction)
61 | val trip = make.roundTrip(net, dW)
62 | seq.foreach {
63 | case (x, y) ⇒
64 | count += 1
65 | trip(x, y)
66 | }
67 | }
68 |
69 | /**
70 | * Set training instances
71 | * @param set Sequence of training set
72 | */
73 | override def setPositiveTrainingReference(set: Seq[(IN, OUT)]): Unit = {
74 | trainingSet = (x: Scalar) ⇒
75 | if (x > 0) {
76 | set.filter(_ ⇒ ThreadLocalRandom.current().nextFloat() < x)
77 | } else {
78 | set
79 | }
80 | validationEpoch = if (param.miniBatchFraction > 0) Math.round(1.0f / param.miniBatchFraction) else 1
81 | }
82 |
83 | /**
84 | * Set training instances
85 | * @param set RDD of training set
86 | */
87 | override def setPositiveTrainingReference(set: RDD[(IN, OUT)]): Unit = {
88 | trainingSet = (x: Scalar) ⇒
89 | if (x > 0) set.sample(withReplacement = true, fraction = x).collect().toSeq
90 | else set.collect()
91 | validationEpoch = if (param.miniBatchFraction > 0) Math.round(1.0f / param.miniBatchFraction) else 1
92 | }
93 |
94 | /**
95 | * Set testing instances
96 | * @param set Sequence of testing set
97 | */
98 | override def setTestReference(set: Seq[(IN, OUT)]): Unit = {
99 | testSet = set.take
100 | testSetMapper = (mapper: Pair ⇒ Unit) ⇒ {
101 | var seq = set
102 | while (seq.nonEmpty) {
103 | mapper(seq.head)
104 | seq = seq.tail
105 | }
106 | }
107 | testSetSC = null
108 | }
109 |
110 | /**
111 | * Set testing instances
112 | * @param set RDD of testing set
113 | */
114 | override def setTestReference(set: RDD[(IN, OUT)]): Unit = {
115 | testSet = (n: Int) ⇒ set.takeSample(withReplacement = true, num = n).toSeq
116 | testSetMapper = (mapper: Pair ⇒ Unit) ⇒ {
117 | set.foreach(mapper)
118 | }
119 | testSetSC = set.context
120 | }
121 |
122 | /**
123 | * Calculate validation error
124 | *
125 | * @return validation error
126 | */
127 | def validationError() = {
128 | val lossOf = make.lossOf(net) _
129 |
130 | if (testSetSC == null) {
131 | // If it is from general "local" sequence
132 | var sum = 0.0f
133 | var count = 0
134 | testSetMapper {
135 | item ⇒
136 | sum += lossOf(item)
137 | count += 1
138 | }
139 | sum / count.toFloat
140 | } else {
141 | // If it is from RDD
142 | val sum = testSetSC.accumulator(0.0f)
143 | val count = testSetSC.accumulator(0)
144 | val bcLoss = testSetSC.broadcast(lossOf)
145 | testSetMapper {
146 | item ⇒
147 | sum += bcLoss.value(item)
148 | count += 1
149 | }
150 | bcLoss.destroy()
151 | sum.value / count.value.toFloat
152 | }
153 | }
154 |
155 | /**
156 | * Iterate over given number of test instances
157 | * @param n number of random sampled instances
158 | * @param fn iteratee function
159 | */
160 | override def foreachTestSet(n: Int)(fn: ((IN, OUT)) ⇒ Unit): Unit = {
161 | var set = testSet(n)
162 | while (set.nonEmpty) {
163 | fn(set.head)
164 | set = set.tail
165 | }
166 | }
167 | }
168 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/StandardRAEType.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.layer.NormalizeOperation
5 | import kr.ac.kaist.ir.deep.network.Network
6 | import kr.ac.kaist.ir.deep.rec.BinaryTree
7 |
8 | /**
9 | * __Input Operation__ : VectorTree as Input & Recursive Auto-Encoder Training (no output type)
10 | *
11 |  * @note We recommend not applying this method to non-AutoEncoder tasks
12 |  * @note This implementation is designed as a replica of the standard RAE (RAE + normalization) in
13 | * [[http://ai.stanford.edu/~ang/papers/emnlp11-RecursiveAutoencodersSentimentDistributions.pdf this paper]]
14 | *
15 | * @param corrupt Corruption that supervises how to corrupt the input matrix. `(Default : [[kr.ac.kaist.ir.deep.train.NoCorruption]])`
16 | * @param error An objective function `(Default: [[kr.ac.kaist.ir.deep.fn.SquaredErr]])`
17 | *
18 | * @example
19 |  * {{{var make = new StandardRAEType(error = CrossEntropyErr)
20 | * var corruptedIn = make corrupted in
21 | * var out = make onewayTrip (net, corruptedIn)}}}
22 | */
23 | class StandardRAEType(override val corrupt: Corruption = NoCorruption,
24 | override val error: Objective = SquaredErr)
25 | extends TreeType {
26 | /** Normalization layer */
27 | val normalizeLayer = new NormalizeOperation()
28 |
29 | /**
30 | * Apply & Back-prop given single input
31 | *
32 | * @param net A network that gets input
33 | * @param delta Sequence of delta updates
34 | */
35 | def roundTrip(net: Network, delta: Seq[ScalarMatrix]) = (in: BinaryTree, real: Null) ⇒ {
36 | in forward {
37 | x ⇒
38 | val out = net passedBy x
39 | val zOut = normalizeLayer passedBy out
40 | val dit = delta.toIterator
41 |
42 | // un-normalize the error
43 | val normalErr = error.derivative(x, zOut)
44 | val err = normalizeLayer updateBy(dit, normalErr)
45 |
46 | net updateBy(dit, err)
47 |
48 | // propagate hidden-layer value
49 | net(x)
50 | }
51 | }
52 |
53 | /**
54 | * Apply given input and compute the error
55 | *
56 | * @param net A network that gets input
57 | * @param pair (Input, Real output) for error computation.
58 | * @return error of this network
59 | */
60 | def lossOf(net: Network)(pair: (BinaryTree, Null)): Scalar = {
61 | var total = 0.0f
62 | val in = pair._1
63 | in forward {
64 | x ⇒
65 | val out = net of x
66 | val normalized = normalizeLayer(out)
67 | total += error(x, normalized)
68 | //propagate hidden-layer value
69 | net(x)
70 | }
71 | total
72 | }
73 |
74 | /**
75 | * Make validation output
76 | *
77 | * @return input as string
78 | */
79 | def stringOf(net: Network, pair: (BinaryTree, Null)): String = {
80 | val string = StringBuilder.newBuilder
81 | pair._1 forward {
82 | x ⇒
83 | val out = net of x
84 | val normalized = normalizeLayer(out)
85 | val hid = net(x)
86 | string append s"IN: ${x.mkString} RAE → OUT: ${normalized.mkString}, HDN: ${hid.mkString}; "
87 | // propagate hidden-layer value
88 | hid
89 | }
90 | string.mkString
91 | }
92 | }
93 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/TrainStyle.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.network.Network
5 | import org.apache.log4j.{Level, Logger}
6 | import org.apache.spark.rdd.RDD
7 |
8 | import scala.concurrent.Future
9 |
10 | /**
11 | * __Trait__ that describes style of training
12 | *
13 | * This trait controls how to train, i.e. __Single-threaded__ or __Distributed__.
14 | *
15 | * @tparam IN the type of input
16 | * @tparam OUT the type of output
17 | */
18 | trait TrainStyle[IN, OUT] extends Serializable {
19 |   // Turn off Spark logging.
20 | Logger.getRootLogger.setLevel(Level.WARN)
21 | Logger.getLogger("kr.ac").setLevel(Level.INFO)
22 |
23 | /** Training Pair Type */
24 | type Pair = (IN, OUT)
25 | /** Sampler Type */
26 | type Sampler = Int ⇒ Seq[OUT]
27 | /** Training parameters */
28 | val param: TrainingCriteria
29 | /** Network */
30 | val net: Network
31 | /** Algorithm */
32 | val algorithm: WeightUpdater
33 | /** Set of input manipulations */
34 | val make: ManipulationType[IN, OUT]
35 | /** Logger */
36 | @transient protected val logger = Logger.getLogger(this.getClass)
37 | /** number of epochs for iterating one training set */
38 | var validationEpoch: Int = 0
39 |
40 | /**
41 | * Calculate validation error
42 | *
43 | * @return validation error
44 | */
45 | def validationError(): Scalar
46 |
47 | /**
48 | * Iterate over given number of test instances
49 | * @param n number of random sampled instances
50 | * @param fn iteratee function
51 | */
52 | def foreachTestSet(n: Int)(fn: Pair ⇒ Unit): Unit
53 |
54 | /**
55 | * Set training instances
56 | * @param set Sequence of training set
57 | */
58 | def setPositiveTrainingReference(set: Seq[Pair]): Unit
59 |
60 | /**
61 | * Set training instances
62 | * @param set RDD of training set
63 | */
64 | def setPositiveTrainingReference(set: RDD[Pair]): Unit
65 |
66 | /**
67 | * Set testing instances
68 | * @param set Sequence of testing set
69 | */
70 | def setTestReference(set: Seq[Pair]): Unit
71 |
72 | /**
73 | * Set testing instances
74 | * @param set RDD of testing set
75 | */
76 | def setTestReference(set: RDD[Pair]): Unit
77 |
78 | /**
79 | * Fetch weights
80 | *
81 | * @param iter current iteration
82 | */
83 | def fetch(iter: Int): Unit
84 |
85 | /**
86 | * Do mini-batch
87 | */
88 | def batch(): Unit
89 |
90 | /**
91 | * Send update of weights
92 | *
93 | * @param iter current iteration
94 | */
95 | def update(iter: Int): Unit
96 |
97 | /**
98 | * Indicates whether the asynchronous update is finished or not.
99 | *
100 | * @return future object of update
101 | */
102 | def isUpdateFinished: Future[_] = null
103 |
104 | /**
105 | * Non-blocking pending, until all assigned batches are finished
106 | */
107 | def stopUntilBatchFinished(): Unit = {}
108 |
109 | /**
110 | * Implicit weight operation
111 | *
112 | * @param w Sequence of weight to be applied
113 | */
114 | implicit class WeightOp(w: IndexedSeq[ScalarMatrix]) extends Serializable {
115 | /**
116 | * Sugar: Weight update
117 | *
118 |      * @param dw An amount of update, i.e. __ΔWeight__
119 | */
120 | def -=(dw: IndexedSeq[ScalarMatrix]) = algorithm(dw, w)
121 | }
122 |
123 | }
124 |
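
A simplified, iterative sketch (not from the repository) of how the Trainer below drives any TrainStyle; the real Trainer.trainBatch is tail-recursive and adds patience handling and parameter snapshots, which are omitted here. `style`, `maxEpochs`, and `validationPeriod` are placeholders:

{{{
var epoch = 0
while (epoch < maxEpochs) {                    // `maxEpochs` stands in for the real stopping criteria
  style.fetch(epoch)                           // pull the current weights (no-op for single-thread)
  style.batch()                                // run one mini-batch
  style.update(epoch)                          // push the accumulated weight deltas

  if ((epoch + 1) % validationPeriod == 0) {
    style.stopUntilBatchFinished()             // wait for asynchronous batches, if any
    val loss = style.validationError()
    // compare with the best loss so far, save parameters, adjust patience, ...
  }
  epoch += 1
}
}}}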
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/Trainer.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import java.text.SimpleDateFormat
4 | import java.util.Date
5 |
6 | import kr.ac.kaist.ir.deep.fn._
7 | import org.apache.log4j.Logger
8 | import org.apache.spark.rdd.RDD
9 |
10 | import scala.annotation.tailrec
11 | import scala.concurrent.Await
12 | import scala.concurrent.duration._
13 |
14 |
15 | /**
16 | * __General__ Trainer Implementation.
17 | *
18 | * This class trains with help of Training Style and Input Operation.
19 | *
20 |  * @note This trainer is a generalized class. For implementation details, see the individual training styles.
21 | * @example
22 | * {{{val net:Network = ...
23 | *
24 | * // Define Manipulation Type. VectorType, AEType, RAEType and URAEType.
25 | * val operation = new VectorType(
26 | * corrupt = GaussianCorruption(variance = 0.1)
27 | * )
28 | *
29 |  * // Define Training Style. SingleThreadTrainStyle vs DistBeliefTrainStyle
30 | * val style = new SingleThreadTrainStyle(
31 | * net = net,
32 | * algorithm = new StochasticGradientDescent(l2decay = 0.0001),
33 | * make = operation,
34 | * param = SimpleTrainingCriteria(miniBatchFraction = 0.01))
35 | *
36 | * // Define Trainer
37 | * val train = new Trainer(
38 | * style = style,
39 | * stops = StoppingCriteria(maxIter = 100000))
40 | *
41 | * // Do Train
42 | * train.train(set, valid)}}}
43 | *
44 |  * @note To train an autoencoder, you can provide the same training set as the validation set.
45 | *
46 | * @param style __Training style__ that supervises how to train. There are two styles,
47 | * one is [[SingleThreadTrainStyle]]
48 | * and the other is [[DistBeliefTrainStyle]].
49 | * @param stops __Stopping Criteria__ that controls the threshold for stopping. (Default : [[StoppingCriteria]])
50 | * @param name Name used for logging.
51 | *
52 | * @tparam IN the type of input.
53 | * Currently, [[kr.ac.kaist.ir.deep.fn.ScalarMatrix]] and DAG are supported
54 | * @tparam OUT the type of output
55 | * Currently, [[kr.ac.kaist.ir.deep.fn.ScalarMatrix]] and Null are supported
56 | */
57 | class Trainer[IN, OUT](val style: TrainStyle[IN, OUT],
58 | val stops: StoppingCriteria = StoppingCriteria(),
59 | val name: String = "Trainer")
60 | extends Serializable {
61 | /** import everything in the style */
62 |
63 | import style._
64 |
65 | @transient private final val dateFormatter = new SimpleDateFormat("MM/dd HH:mm:ss")
66 | /** Logger */
67 | @transient protected val logger = Logger.getLogger(this.getClass)
68 | /** Best Parameter History */
69 | @transient protected var bestParam: IndexedSeq[ScalarMatrix] = null
70 | /** Best Loss Iteration Number */
71 | @transient protected var bestIter: Int = 0
72 | /** Period of validation */
73 | @transient protected var validationPeriod: Int = 0
74 | /** Get command line column width */
75 | @transient protected var columns = try {
76 | System.getenv("COLUMNS").toInt
77 | } catch {
78 | case _: Throwable ⇒ 80
79 | }
80 | /** Finish time of last iteration */
81 | @transient protected var startAt: Long = _
82 |
83 | /**
84 | * Train given sequence, and validate with given sequence.
85 | *
86 | * @param set Full Sequence of training set
87 | * @return Training error (loss)
88 | */
89 | def train(set: Seq[Pair]): (Scalar, Scalar, Scalar) = train(set, set)
90 |
91 | /**
92 | * Train given sequence, and validate with another sequence.
93 | *
94 | * @param set Full Sequence of training set
95 | * @param validation Full Sequence of validation set
96 | * @return Training error (loss)
97 | */
98 | def train(set: Seq[Pair],
99 | validation: Seq[Pair]): (Scalar, Scalar, Scalar) = {
100 | setPositiveTrainingReference(set)
101 | setTestReference(validation)
102 |
103 | validationPeriod = (stops.validationFreq * validationEpoch).toInt
104 |
105 | if (validationPeriod > 0) {
106 | logger info f"($name) Starts training. "
107 | logger info f"($name) Every $validationPeriod%5d (${stops.validationFreq * 100}%6.2f%% of TrainingSet), " +
108 | f"validation process will be submitted."
109 |
110 | saveParams()
111 | val err = lossOfTraining
112 | restoreParams()
113 | printValidation()
114 |
115 | err
116 | } else {
117 | logger warn f"($name) Validation Period is zero! Training stopped."
118 | logger warn f"($name) Maybe because miniBatchFraction value is too large. Please check."
119 | (Float.PositiveInfinity, Float.PositiveInfinity, Float.PositiveInfinity)
120 | }
121 | }
122 |
123 | /**
124 | * Train using given RDD sequence.
125 | *
126 | * @param set RDD of training set
127 | */
128 | def train(set: RDD[Pair]): (Scalar, Scalar, Scalar) = train(set, set)
129 |
130 | /**
131 | * Train using given RDD sequence.
132 | *
133 | * @param set RDD of training set
134 | * @param validation RDD of validation set
135 | */
136 | def train(set: RDD[Pair], validation: RDD[Pair]): (Scalar, Scalar, Scalar) = {
137 | setPositiveTrainingReference(set)
138 | setTestReference(validation)
139 |
140 | validationPeriod = (stops.validationFreq * validationEpoch).toInt
141 |
142 | if (validationPeriod > 0) {
143 | logger info f"($name) Starts training. "
144 | logger info f"($name) Every $validationPeriod%5d (${stops.validationFreq * 100}%6.2f%% of TrainingSet), " +
145 | f"validation process will be submitted."
146 |
147 | saveParams()
148 | val err = lossOfTraining
149 | restoreParams()
150 | printValidation()
151 |
152 | err
153 | } else {
154 | logger warn f"($name) Validation Period is zero! Training stopped."
155 |       logger warn f"($name) This may be because the miniBatchFraction value is too large. Please check it."
156 | (Float.PositiveInfinity, Float.PositiveInfinity, Float.PositiveInfinity)
157 | }
158 | }
159 |
160 | /**
161 | * Print validation result into logger
162 | */
163 | protected def printValidation() = {
164 | logger info s"($name) BEST ITERATION : $bestIter"
165 | foreachTestSet(5) {
166 | item ⇒ logger info make.stringOf(net, item)
167 | }
168 | }
169 |
170 | /**
171 | * Store best parameters
172 | *
173 | * @param iteration current iteration. (1 iteration = 1 validation freq)
174 | * @param loss previous loss
175 | * @param patience current patience, i.e. loop until at least this epoch.
176 | */
177 | protected final def saveParams(iteration: Int = 0,
178 | loss: Scalar = Float.MaxValue,
179 | patience: Int = validationPeriod * 5) = {
180 | bestParam = net.W.copy
181 | bestIter = iteration
182 | }
183 |
184 | /**
185 | * Restore best parameters
186 | */
187 | protected final def restoreParams() = {
188 |     // Wait for the update to finish, to prevent a race condition.
189 | if (isUpdateFinished != null) {
190 | try {
191 | Await.ready(isUpdateFinished, 5.minutes)
192 | } catch {
193 | case _: Throwable ⇒
194 | }
195 | }
196 |
197 | net.W := bestParam
198 | }
199 |
200 | /**
201 | * Tail Recursive : Train each batch
202 | *
203 | * @param epoch current iteration epoch. (1 iteration = 1 validation freq)
204 | * @param prevEloss previous loss (Evaluation)
205 | * @param prevWloss previous loss (Weight)
206 | * @param patience current patience, i.e. loop until at least this epoch.
207 | * @return (Evaluation, Weight, Total) Loss when train is finished
208 | */
209 | @tailrec
210 | protected final def trainBatch(epoch: Int = 0,
211 | prevEloss: Scalar = Float.MaxValue,
212 | prevWloss: Scalar = Float.MaxValue,
213 | patience: Int = 5): (Scalar, Scalar, Scalar) = {
214 | fetch(epoch)
215 | batch()
216 | update(epoch)
217 |
218 | var nPatience = patience
219 | val iter = epoch / validationPeriod + 1
220 |
221 | val prevloss = prevEloss + prevWloss
222 | val nLoss = if ((epoch + 1) % validationPeriod == 0) {
223 | // Pending until batch finished
224 | stopUntilBatchFinished()
225 |
226 | val train = validationError()
227 | val weight = algorithm loss net.W
228 | val loss = train + weight
229 | val improvement = if (prevloss > 0f) loss / prevloss else stops.improveThreshold
230 | if (improvement < stops.improveThreshold) {
231 | nPatience = Math.min(Math.max(patience, iter * (stops.waitAfterUpdate + 1)), stops.maxIter)
232 | saveParams(iter, loss, nPatience)
233 |
234 | printProgress(iter, nPatience, train, weight, improved = true)
235 | (train, weight, loss)
236 | } else {
237 | printProgress(iter, nPatience, prevEloss, prevWloss, improved = false)
238 | (prevEloss, prevWloss, prevloss)
239 | }
240 | } else {
241 | (prevEloss, prevWloss, prevloss)
242 | }
243 |
244 | if (iter <= nPatience && (nLoss._3 >= stops.lossThreshold || iter < 5)) {
245 | trainBatch(epoch + 1, nLoss._1, nLoss._2, nPatience)
246 | } else {
247 | if (nLoss._3 < stops.lossThreshold)
248 | logger info f"($name) # $iter%4d/$nPatience%4d, " +
249 | f"FINISHED with E + W = ${nLoss._3}%.5f [Loss < ${stops.lossThreshold}%.5f]"
250 | else if (iter > stops.maxIter)
251 | logger info f"($name) # $iter%4d/$nPatience%4d, " +
252 | f"FINISHED with E + W = ${nLoss._3}%.5f [Iteration > ${stops.maxIter}%6d]"
253 | else if (nPatience < iter)
254 | logger info f"($name) # $iter%4d/$nPatience%4d, " +
255 | f"FINISHED with E + W = ${nLoss._3}%.5f [NoUpdate after $bestIter%6d]"
256 |
257 | nLoss
258 | }
259 | }
260 |
261 | private def printProgress(iter: Int, patience: Int, eLoss: Float, wLoss: Float, improved: Boolean) = {
262 | val wait = patience / stops.maxIter.toFloat
263 | val header = f"\033[4m$name\033[24m $iter%4d/$patience%4d \033[0m["
264 | val impr = if (improved) "IMPROVED" else f" @ $bestIter%4d "
265 | val footer = f" E + W = $eLoss%7.5f + $wLoss%7.5f $impr"
266 |
267 | val buf = new StringBuilder(s"\033[2A\033[${columns}D\033[2K \033[1;33m$header\033[46;36m")
268 | val total = columns - header.length - footer.length + 10
269 | val len = Math.floor(wait * total).toInt
270 | val step = Math.floor(iter / stops.maxIter.toFloat * total).toInt
271 | buf.append(" " * step)
272 | buf.append("\033[49m")
273 | buf.append(" " * (len - step))
274 | buf.append("\033[0m]\033[34m")
275 | if (total > len) buf.append(s"\033[${total - len}C")
276 | buf.append(s"$footer\033[0m")
277 |
278 | val now = System.currentTimeMillis()
279 | val remainA = (now - startAt) / iter * patience
280 | val etaA = startAt + remainA
281 | val calA = dateFormatter.format(new Date(etaA))
282 | val remainB = (now - startAt) / iter * stops.maxIter
283 | val etaB = startAt + remainB
284 | val calB = dateFormatter.format(new Date(etaB))
285 |
286 | buf.append(f"\n\033[2K Estimated Finish Time : $calA \t ~ $calB")
287 |
288 | println(buf.result())
289 | }
290 |
291 | /**
292 | * Do actual training process
293 | * @return MSE of the training process
294 | */
295 | private def lossOfTraining: (Scalar, Scalar, Scalar) =
296 | if (param.miniBatchFraction > 0) {
297 | println("Start training...\n Estimated Time: NONE")
298 | startAt = System.currentTimeMillis()
299 | trainBatch()
300 | } else {
301 | fetch(0)
302 | batch()
303 | update(0)
304 |
305 | val train = validationError()
306 | val weight = algorithm loss net.W
307 | val loss = train + weight
308 | saveParams(0, loss, 0)
309 |
310 | logger info f"($name) PASSONCE, E + W = $train%.5f + $weight%.5f = $loss%.5f"
311 | (train, weight, loss)
312 | }
313 |
314 | }
315 |
--------------------------------------------------------------------------------
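A minimal usage sketch for the `Trainer` above. It assumes a `TrainStyle` instance (`style`), a training sequence (`data`) and a held-out sequence (`heldOut`) were built elsewhere; those names are placeholders, not part of this file.

```scala
import kr.ac.kaist.ir.deep.train._

// Sketch only: `style`, `data` and `heldOut` are assumed to exist already.
val stops   = StoppingCriteria(maxIter = 10000, improveThreshold = 0.995f, lossThreshold = 1e-4f)
val trainer = new Trainer(style = style, stops = stops, name = "ExampleTrainer")

// Returns (evaluation loss, weight loss, total loss). Validation is submitted every
// `stops.validationFreq * validationEpoch` mini-batches, as computed in train() above.
val (eLoss, wLoss, totalLoss) = trainer.train(data, heldOut)
```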
/src/main/scala/kr/ac/kaist/ir/deep/train/TrainingCriteria.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | /**
4 | * __Trait__ that describes Training Criteria
5 | */
6 | trait TrainingCriteria extends Serializable {
7 |   /** Fraction of the training set used as one mini-batch.
8 |     * If less than or equal to zero, no batch training is done (i.e. the whole set is passed through once). */
9 | val miniBatchFraction: Float
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/TreeType.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.network.Network
5 | import kr.ac.kaist.ir.deep.rec._
6 |
7 | /**
8 | * __Trait of Input Operation__ : VectorTree as Input. This is an '''Abstract Implementation'''
9 | */
10 | trait TreeType extends ManipulationType[BinaryTree, Null] {
11 |
12 | /**
13 | * Corrupt input
14 | *
15 | * @param x input to be corrupted
16 | * @return corrupted input
17 | */
18 | override def corrupted(x: BinaryTree): BinaryTree = (x through corrupt).asInstanceOf[BinaryTree]
19 |
20 | /**
21 | * Apply given single input as one-way forward trip.
22 | *
23 | * @param net A network that gets input
24 | * @param x input to be computed
25 | * @return output of the network.
26 | */
27 | override def onewayTrip(net: Network, x: BinaryTree): ScalarMatrix =
28 | x forward net.of
29 | }
30 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/train/URAEType.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.network.{AutoEncoder, Network}
5 | import kr.ac.kaist.ir.deep.rec.BinaryTree
6 | import org.apache.spark.annotation.Experimental
7 |
8 | /**
9 | * __Input Operation__ : VectorTree as Input & Unfolding Recursive Auto Encoder Training (no output type)
10 | *
11 | * ::Experimental::
12 |  * @note This cannot be applied to non-AutoEncoder tasks
13 | * @note This is designed for Unfolding RAE, in
14 | * [[http://ai.stanford.edu/~ang/papers/nips11-DynamicPoolingUnfoldingRecursiveAutoencoders.pdf this paper]]
15 | *
16 | * @param corrupt Corruption that supervises how to corrupt the input matrix. `(Default : [[kr.ac.kaist.ir.deep.train.NoCorruption]])`
17 | * @param error An objective function `(Default: [[kr.ac.kaist.ir.deep.fn.SquaredErr]])`
18 | *
19 | * @example
20 | * {{{var make = new URAEType(error = CrossEntropyErr)
21 | * var corruptedIn = make corrupted in
22 | * var out = make onewayTrip (net, corruptedIn)}}}
23 | */
24 | @Experimental
25 | class URAEType(override val corrupt: Corruption = NoCorruption,
26 | override val error: Objective = SquaredErr)
27 | extends TreeType {
28 |
29 | /**
30 | * Apply & Back-prop given single input
31 | *
32 | * @param net A network that gets input
33 | * @param delta Sequence of delta updates
34 | */
35 | def roundTrip(net: Network, delta: Seq[ScalarMatrix]) = (in: BinaryTree, real: Null) ⇒
36 | net match {
37 | case net: AutoEncoder ⇒
38 | val out = in forward net.encode
39 |
40 |         // Decode phase of reconstruction
41 | var terminals = in.backward(out, net.decode)
42 | while (terminals.nonEmpty) {
43 | val leaf = terminals.head
44 | terminals = terminals.tail
45 |
46 | leaf.out = error.derivative(leaf.out, leaf.x)
47 | }
48 |
49 | // Error propagation for decoder
50 | val err = in forward net.decode_!(delta.take(2).toIterator)
51 |
52 | // Error propagation for encoder
53 | in backward(err, net.encode_!(delta.takeRight(2).toIterator))
54 | }
55 |
56 |
57 | /**
58 | * Apply given input and compute the error
59 | *
60 | * @param net A network that gets input
61 | * @param pair (Input, Real output) for error computation.
62 | * @return error of this network
63 | */
64 | def lossOf(net: Network)(pair: (BinaryTree, Null)): Scalar =
65 | net match {
66 | case net: AutoEncoder ⇒
67 | var sum = 0.0f
68 | val in = pair._1
69 |         // Encode phase of reconstruction
70 | val out = in forward net.apply
71 |
72 |         // Decode phase of reconstruction
73 | var terminals = in.backward(out, net.reconstruct)
74 | val size = terminals.size
75 | while (terminals.nonEmpty) {
76 | val leaf = terminals.head
77 | terminals = terminals.tail
78 | sum += error(leaf.out, leaf.x)
79 | }
80 | sum
81 | case _ ⇒ 0.0f
82 | }
83 |
84 |
85 | /**
86 | * Make validation output
87 | *
88 |    * @return a string describing the validation result
89 | */
90 | def stringOf(net: Network, pair: (BinaryTree, Null)): String =
91 | net match {
92 | case net: AutoEncoder ⇒
93 | val string = StringBuilder.newBuilder
94 | val in = pair._1
95 |         // Encode phase of reconstruction
96 | val out = in forward net.apply
97 |
98 |         // Decode phase of reconstruction
99 | var terminals = in.backward(out, net.reconstruct)
100 | while (terminals.nonEmpty) {
101 | val leaf = terminals.head
102 | terminals = terminals.tail
103 |
104 | string append s"IN: ${leaf.x.mkString} URAE → OUT: ${leaf.out.mkString};"
105 | }
106 | string.mkString
107 | case _ ⇒ "NOT AN AUTOENCODER"
108 | }
109 | }
110 |
--------------------------------------------------------------------------------
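A hedged sketch of using `URAEType` directly for reconstruction scoring. The `encoder: AutoEncoder` and `tree: BinaryTree` values are assumed to have been built elsewhere; they are placeholders.

```scala
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.train._

// Sketch only: `encoder` and `tree` are assumed to exist.
val make = new URAEType(error = SquaredErr)

// Summed reconstruction error over all leaves of the tree (the OUT type is Null).
val reconLoss: Scalar = make.lossOf(encoder)((tree, null))

// Human-readable "IN ... URAE → OUT ..." summary, as printed by Trainer during validation.
val summary: String = make.stringOf(encoder, (tree, null))
```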
/src/main/scala/kr/ac/kaist/ir/deep/train/VectorType.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.train
2 |
3 | import breeze.linalg.any
4 | import kr.ac.kaist.ir.deep.fn._
5 | import kr.ac.kaist.ir.deep.network.Network
6 |
7 | /**
8 | * __Input Operation__ : Vector as Input and output
9 | *
10 | * @param corrupt Corruption that supervises how to corrupt the input matrix. (Default : [[NoCorruption]])
11 | * @param error An objective function (Default: [[kr.ac.kaist.ir.deep.fn.SquaredErr]])
12 | *
13 | * @example
14 | * {{{var make = new VectorType(error = CrossEntropyErr)
15 | * var corruptedIn = make corrupted in
16 | * var out = make onewayTrip (net, corruptedIn)}}}
17 | */
18 | class VectorType(override val corrupt: Corruption = NoCorruption,
19 | override val error: Objective = SquaredErr)
20 | extends ManipulationType[ScalarMatrix, ScalarMatrix] {
21 |
22 | /**
23 | * Corrupt input
24 | *
25 | * @param x input to be corrupted
26 | * @return corrupted input
27 | */
28 | override def corrupted(x: ScalarMatrix): ScalarMatrix = corrupt(x)
29 |
30 | /**
31 | * Apply & Back-prop given single input
32 | *
33 | * @param net A network that gets input
34 | * @param delta Sequence of delta updates
35 | */
36 | def roundTrip(net: Network, delta: Seq[ScalarMatrix]) = (in: ScalarMatrix, real: ScalarMatrix) ⇒ {
37 | val out = net passedBy in
38 | val err: ScalarMatrix = error.derivative(real, out)
39 | net updateBy(delta.toIterator, err)
40 | }
41 |
42 | /**
43 | * Apply given input and compute the error
44 | *
45 | * @param net A network that gets input
46 | * @param pair (Input, Real output) for error computation.
47 | * @return error of this network
48 | */
49 | override def lossOf(net: Network)(pair: (ScalarMatrix, ScalarMatrix)): Scalar = {
50 | val in = pair._1
51 | val real = pair._2
52 | val out = net of in
53 | error(real, out)
54 | }
55 |
56 | /**
57 | * Apply given single input as one-way forward trip.
58 | *
59 | * @param net A network that gets input
60 | * @param x input to be computed
61 | * @return output of the network.
62 | */
63 | override def onewayTrip(net: Network, x: ScalarMatrix): ScalarMatrix = net of x
64 |
65 | /**
66 | * Make validation output
67 | *
68 |    * @return a string describing the validation result
69 | */
70 | def stringOf(net: Network, pair: (ScalarMatrix, ScalarMatrix)): String = {
71 | val in = pair._1
72 | val real = pair._2
73 | val out = net of in
74 | s"IN: ${in.mkString} EXP: ${real.mkString} → OUT: ${out.mkString}"
75 | }
76 |
77 | /**
78 |    * Check whether the given two outputs differ.
79 | * @param x Out-type object
80 | * @param y Out-type object
81 | * @return True if they are different.
82 | */
83 | override def different(x: ScalarMatrix, y: ScalarMatrix): Boolean = any(x :!= y)
84 | }
85 |
--------------------------------------------------------------------------------
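A short sketch of the `VectorType` operations in isolation, assuming a trained `net: Network` and a single `(input, target)` pair of `ScalarMatrix` values (placeholder names).

```scala
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.train._

// Sketch only: `net`, `input` and `target` are assumed to exist.
val make = new VectorType(corrupt = DroppingCorruption(presence = 0.95f), error = SquaredErr)

val noisy  = make corrupted input                 // drop roughly 5% of the input entries
val out    = make onewayTrip (net, noisy)         // forward pass only, no weight update
val loss   = make.lossOf(net)((input, target))    // scalar error of the clean pair
val report = make.stringOf(net, (input, target))  // "IN ... EXP ... → OUT ..." line
```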
/src/main/scala/kr/ac/kaist/ir/deep/train/package.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep
2 |
3 | import breeze.stats.distributions.Gaussian
4 | import kr.ac.kaist.ir.deep.fn._
5 | import org.apache.spark.AccumulatorParam
6 | import org.apache.spark.storage.StorageLevel
7 |
8 | import scala.annotation.tailrec
9 | import scala.concurrent.duration._
10 |
11 | /**
12 | * Package for training.
13 | */
14 | package object train {
15 |
16 | /** Type of Corruption */
17 | trait Corruption extends (ScalarMatrix ⇒ ScalarMatrix) with Serializable
18 |
19 | /**
20 | * __Input Corruption__: Drop input as zero.
21 | *
22 | * If network uses drop-out training, we recommend that you do not use this.
23 | *
24 | * @note If the presence probability is `P%`, then this corruption leaves `P%` entries of the matrix
25 | *
26 | * @param presence probability of __not-dropped__. `(default 95% = 0.95)`
27 | *
28 | * @example
29 | * {{{var corrupt = DroppingCorruption(presence = 0.99)
30 | * var corrupted = corrupt(vector)}}}
31 | */
32 | case class DroppingCorruption(presence: Float = 0.95f) extends Corruption {
33 | /**
34 | * Do corruption
35 | *
36 | * @param v1 Matrix to be corrupted
37 | * @return corrupted vector
38 | */
39 | override def apply(v1: ScalarMatrix): ScalarMatrix =
40 | v1 mapValues { x ⇒ if (Math.random() > presence) 0.0f else x}
41 | }
42 |
43 | /**
44 | * __Input Corruption__: Gaussian
45 | *
46 | * @param mean __Mean__ of noise `(default 0.0)`
47 | * @param variance __Variance__ of noise `(default 0.1)`
48 | *
49 | * @example
50 | * {{{var corrupt = GaussianCorruption(variance = 0.1)
51 | * var corrupted = corrupt(vector)}}}
52 | */
53 | case class GaussianCorruption(mean: Double = 0.0, variance: Double = 0.1) extends Corruption {
54 | /**
55 | * Gaussian Distribution
56 | */
57 | private lazy val distro = Gaussian distribution(mean, variance)
58 |
59 | /**
60 | * Do corruption
61 | *
62 | * @param v1 Matrix to be corrupted
63 | * @return corrupted vector
64 | */
65 | override def apply(v1: ScalarMatrix): ScalarMatrix =
66 | v1 mapValues { x ⇒ x + distro.draw().toFloat}
67 | }
68 |
69 | /**
70 | * __Criteria__: When to stop training
71 | *
72 |    * This case class defines when to stop training. Training stops if one of the following conditions is satisfied.
73 |    *
74 |    * - #Iteration ≥ maxIter
75 |    * - #Iteration ≥ current patience value, calculated as `min(max(patience, bestIteration * (waitAfterUpdate + 1)), maxIter)`
76 |    * - Amount of loss < lossThreshold
77 |    *
78 |    * Validation is done every `validationFreq` iterations,
79 |    * and whenever the current/best loss ratio is below improveThreshold,
80 |    * that iteration is marked as the best iteration.
81 | *
82 | * @param maxIter __maximum mini-batch__ iteration count `(default 100,000)`
83 | * @param waitAfterUpdate __multiplier__ for calculating patience `(default 1 := Wait lastupdate# * 1 after update)`
84 | * @param improveThreshold __threshold__ that iteration is marked as "improved" `(default 99.5% = 0.995)`
85 | * @param lossThreshold __maximum-tolerant__ loss value. `(default 0.0001)`
86 |    * @param validationFreq __multiplier__ used to compute the validation period. `(default 1.0f)`
87 |    *                       Validation runs every (validationFreq) * (#epochs per training batch) mini-batches,
88 |    *                       where #epochs per iteration = round(1 / miniBatchFraction).
89 | */
90 | case class StoppingCriteria(maxIter: Int = 100000,
91 | waitAfterUpdate: Int = 1,
92 | improveThreshold: Float = 0.995f,
93 | lossThreshold: Float = 0.0001f,
94 | validationFreq: Float = 1.0f)
95 | extends Serializable
96 |
97 | /**
98 | * __Criteria__: How to train (for [[SingleThreadTrainStyle]])
99 | *
100 | * This case class defines how to train the network. Training parameter is defined in this class.
101 | *
102 |    * @param miniBatchFraction fraction of the training set used as one __mini-batch__ `(default 0.01 = 1%)`
103 | * If below or equal to zero, then this indicates no batch training (i.e. just go through once.)
104 | */
105 | case class SimpleTrainingCriteria(override val miniBatchFraction: Float = 0.01f) extends TrainingCriteria
106 |
107 | /**
108 | * __Criteria__: How to train (for [[DistBeliefTrainStyle]])
109 | *
110 | * This case class defines how to train the network. Training parameter is defined in this class.
111 | *
112 |    * @param miniBatchFraction fraction of the training set used as one __mini-batch__ `(default 0.01 = 1%)`
113 | * If below or equal to zero, then this indicates no batch training (i.e. just go through once.)
114 |    * @param submitInterval Time interval between batch submissions. `(default 30.seconds)`
115 | * @param updateStep number of __mini-batches__ between update `(default 2)`
116 | * @param fetchStep number of __mini-batches__ between fetching `(default 10)`
117 | * @param numCores number of __v-cores__ in the spark cluster. `(default 1)`
118 |    * @param repartitionOnStart true to repartition when defining training/testing RDD instances. `(default true)`
119 |    * @param storageLevel StorageLevel that will be used in Spark. `(default MEMORY_ONLY)`
120 | *
121 |    * @note We recommend setting numCores as close as possible to the number of allocated Spark v-cores.
122 | */
123 | case class DistBeliefCriteria(override val miniBatchFraction: Float = 0.01f,
124 | submitInterval: Duration = 30.seconds,
125 | updateStep: Int = 2,
126 | fetchStep: Int = 10,
127 | numCores: Int = 1,
128 | repartitionOnStart: Boolean = true,
129 | storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) extends TrainingCriteria
130 |
131 | /**
132 | * Accumulator Param object for DistBelief Train Style.
133 | */
134 | implicit object WeightAccumulator extends AccumulatorParam[IndexedSeq[ScalarMatrix]] {
135 | /**
136 | * Add in place function
137 | * @param r1 left hand side
138 | * @param r2 right hand side
139 | * @return r1 + r2 in r1
140 | */
141 | override def addInPlace(r1: IndexedSeq[ScalarMatrix], r2: IndexedSeq[ScalarMatrix]): IndexedSeq[ScalarMatrix] = {
142 | r1 :+= r2
143 | }
144 |
145 | /**
146 | * Zero value
147 | * @param initialValue initial value
148 | * @return initial zero value.
149 | */
150 | override def zero(initialValue: IndexedSeq[ScalarMatrix]): IndexedSeq[ScalarMatrix] =
151 | initialValue.map {
152 | matx ⇒
153 | ScalarMatrix $0(matx.rows, matx.cols)
154 | }
155 | }
156 |
157 | /**
158 | * Non-blocking await
159 | */
160 | object AsyncAwait extends Serializable {
161 |
162 | import scala.concurrent.ExecutionContext.Implicits.global
163 | import scala.concurrent._
164 |
165 | /**
166 | * Tail-recursive version of non-block pending
167 | * @param f Future object to wait
168 | * @param interval Duration object specifying waiting time.
169 | */
170 | @tailrec
171 | final def ready(f: Future[_], interval: Duration): Unit = try {
172 | Await.ready(f, interval)
173 | } catch {
174 | case _: TimeoutException ⇒ ready(f, interval)
175 | }
176 |
177 | /**
178 | * Tail-recursive version of non-block pending
179 | * @param interval Duration object specifying waiting time.
180 | * @param f Future objects to wait
181 | */
182 | final def readyAll(interval: Duration, f: Future[Any]*): Unit =
183 | ready(Future.sequence(f.seq), interval)
184 | }
185 |
186 | /**
187 | * __Input Corruption__: Never corrupts input
188 | *
189 | * @example
190 |    * {{{var corrupt = NoCorruption
191 |    *    var corrupted = corrupt(vector)}}}
192 | */
193 | case object NoCorruption extends Corruption {
194 |
195 | /**
196 | * Identity.
197 | * @param v1 to be corrupted
198 | * @return the vector
199 | */
200 | override def apply(v1: ScalarMatrix) = v1
201 | }
202 |
203 | }
204 |
--------------------------------------------------------------------------------
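The corruption and criteria helpers defined above are plain case classes and compose directly; the following sketch uses illustrative values only.

```scala
import scala.concurrent.duration._
import kr.ac.kaist.ir.deep.train._

// Input corruptions for denoising-style training.
val drop  = DroppingCorruption(presence = 0.95f)            // keeps ~95% of the entries
val noise = GaussianCorruption(mean = 0.0, variance = 0.1)

// Stopping criteria: with waitAfterUpdate = 1, an improvement at iteration k
// extends patience to at least k * (waitAfterUpdate + 1) = 2k, capped by maxIter.
val stops = StoppingCriteria(maxIter = 50000, waitAfterUpdate = 1,
  improveThreshold = 0.995f, lossThreshold = 1e-4f, validationFreq = 1.0f)

// Batching criteria for the single-threaded and DistBelief train styles.
val simple = SimpleTrainingCriteria(miniBatchFraction = 0.01f)
val dist   = DistBeliefCriteria(miniBatchFraction = 0.01f, submitInterval = 30.seconds,
  updateStep = 2, fetchStep = 10, numCores = 4)
```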
/src/main/scala/kr/ac/kaist/ir/deep/wordvec/PrepareCorpus.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.wordvec
2 |
3 | import java.util
4 |
5 | import org.apache.spark.broadcast.Broadcast
6 | import org.apache.spark.rdd.RDD
7 | import org.apache.spark.storage.StorageLevel
8 | import org.apache.spark.{Logging, SparkConf, SparkContext}
9 | import org.apache.log4j._
10 |
11 | import scala.collection.JavaConversions._
12 |
13 | /**
14 | * Train Word2Vec and save the model.
15 | */
16 | object PrepareCorpus extends Logging {
17 | {
18 | // Initialize Network Logging
19 | val PATTERN = "%d{yy/MM/dd HH:mm:ss} %p %C{2}: %m%n"
20 | val orgFile = new RollingFileAppender(new PatternLayout(PATTERN), "spark.log")
21 | orgFile.setMaxFileSize("1MB")
22 | orgFile.setMaxBackupIndex(5)
23 | val root = Logger.getRootLogger
24 | root.addAppender(orgFile)
25 | root.setLevel(Level.WARN)
26 | root.setAdditivity(false)
27 | val krFile = new RollingFileAppender(new PatternLayout(PATTERN), "trainer.log")
28 | krFile.setMaxFileSize("1MB")
29 | krFile.setMaxBackupIndex(10)
30 | val kr = Logger.getLogger("kr.ac")
31 | kr.addAppender(krFile)
32 | kr.setLevel(Level.INFO)
33 | }
34 |
35 | /**
36 | * Main thread.
37 | * @param args CLI arguments
38 | */
39 | def main(args: Array[String]) =
40 | if (args.length == 0 || args.contains("--help") || args.contains("-h")) {
41 | println(
42 | """Tokenize sentences, and Collect several types of unknown words.
43 | |
44 | |== Arguments without default ==
45 | | -i Path of input corpora file.
46 | | -o Path of tokenized output text file.
47 | |
48 | |== Arguments with default ==
49 | | --srlz Local Path of Serialized Language Filter file. (Default: filter.dat)
50 |           | --thre   Minimum count to include a word. (Default: 3)
51 |           | --part   Number of partitions. (Default: organized by Spark)
52 | | --lang Accepted Language Area of Unicode. (Default: \\\\u0000-\\\\u007f)
53 | | For Korean: 가-힣|\\\\u0000-\\\\u007f
54 | |
55 | |== Additional Arguments ==
56 | | --help Display this help message.
57 | | """.stripMargin)
58 | } else {
59 | // Set spark context
60 | val conf = new SparkConf()
61 | .setAppName("Normalize Infrequent words")
62 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
63 | .set("spark.scheduler.mode", "FAIR")
64 | .set("spark.shuffle.memoryFraction", "0.05")
65 | .set("spark.storage.unrollFraction", "0.05")
66 | .set("spark.storage.memoryFraction", "0.9")
67 | .set("spark.broadcast.blockSize", "40960")
68 | .set("spark.akka.frameSize", "50")
69 | .set("spark.locality.wait", "10000")
70 | val sc = new SparkContext(conf)
71 | sc.setLocalProperty("spark.scheduler.pool", "production")
72 |
73 | val langArea = getArgument(args, "--lang", "\\u0000-\\u007f")
74 | val langFilter = LangFilter(langArea)
75 | val bcFilter = sc.broadcast(langFilter)
76 | langFilter.saveAs(getArgument(args, "--srlz", "filter.dat"))
77 | logInfo(s"Language filter created : $langArea")
78 |
79 | // read file
80 | val in = getArgument(args, "-i", "article.txt")
81 | val parts = getArgument(args, "--part", "1").toInt
82 | val lines = sc.textFile(in, parts).filter(_.trim.nonEmpty)
83 | val tokens = tokenize(lines, bcFilter)
84 |
85 | val threshold = getArgument(args, "--thre", "3").toInt
86 | val infreqWords = infrequentWords(tokens.flatMap(x ⇒ x), threshold)
87 | val infreqSet = sc.broadcast(infreqWords)
88 |
89 | val out = getArgument(args, "-o", "article-preproc.txt")
90 | normalizedTokens(tokens, infreqSet).saveAsTextFile(out)
91 |
92 | // Stop the context
93 | sc.stop()
94 | }
95 |
96 | /**
97 | * Read argument
98 | * @param args Argument Array
99 | * @param key Argument Key
100 | * @param default Default value of this argument
101 | * @return Value of this key.
102 | */
103 | def getArgument(args: Array[String], key: String, default: String) = {
104 | val idx = args.indexOf(key)
105 |     if (idx < 0 || idx >= args.length - 1) default
106 | else args(idx + 1)
107 | }
108 |
109 | /**
110 |    * Collect infrequent words, i.e. words whose count is below the threshold.
111 |    * @param words Word seq.
112 |    * @return HashSet of infrequent words.
113 | */
114 | def infrequentWords(words: RDD[String], threshold: Int) = {
115 | val counts = words.countByValue()
116 | val above = counts.count(_._2 >= threshold)
117 | val set = counts.filter(_._2 < threshold).keySet
118 | val value = new util.HashSet[String]()
119 | value ++= set
120 |
121 | val all = above + set.size
122 | val ratio = Math.round(set.size.toFloat / all * 100)
123 | logInfo(s"Total $all distinct words, ${set.size} words($ratio%) will be discarded.")
124 |
125 | value
126 | }
127 |
128 | /**
129 |    * Convert input lines into token sequences, using the broadcast WordFilter.
130 | * @param lines Input lines
131 | * @return tokenized & normalized lines.
132 | */
133 | def tokenize(lines: RDD[String], bcFilter: Broadcast[_ <: WordFilter]) =
134 | lines.map(bcFilter.value.tokenize).persist(StorageLevel.DISK_ONLY_2)
135 |
136 | /**
137 |    * Convert tokenized strings back into sentences, replacing words below the count threshold with the unknown token.
138 | * @param input Tokenized input sentence
139 | * @param infreqSet Less Frequent words
140 | * @return Tokenized converted sentence
141 | */
142 | def normalizedTokens(input: RDD[_ <: Seq[String]], infreqSet: Broadcast[util.HashSet[String]]) =
143 | input.mapPartitions {
144 | lazy val set = infreqSet.value
145 |
146 | _.map {
147 | seq ⇒
148 | val it = seq.iterator
149 | val buf = StringBuilder.newBuilder
150 |
151 | while(it.hasNext){
152 | val word = it.next()
153 | if (set contains word){
154 | buf.append(WordModel.OTHER_UNK)
155 | }else{
156 | buf.append(word)
157 | }
158 | buf.append(' ')
159 | }
160 |
161 | buf.result()
162 | }
163 | }
164 | }
165 |
--------------------------------------------------------------------------------
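For reference, a hypothetical driver invocation of `PrepareCorpus` with the arguments documented in its help text; in practice this is packaged and launched through `spark-submit`, and the file names here are placeholders.

```scala
// Hypothetical invocation; normally submitted to a Spark cluster rather than called directly.
kr.ac.kaist.ir.deep.wordvec.PrepareCorpus.main(Array(
  "-i", "article.txt",           // input corpus file
  "-o", "article-preproc.txt",   // tokenized, normalized output
  "--thre", "3",                 // words seen fewer than 3 times become the unknown token
  "--lang", "\\u0000-\\u007f"    // accepted Unicode range (ASCII only)
))
```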
/src/main/scala/kr/ac/kaist/ir/deep/wordvec/StringToVectorType.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.wordvec
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.network.Network
5 | import org.apache.spark.broadcast.Broadcast
6 |
7 | /**
8 |  * __Input Operation__ : String as Input & ScalarMatrix as Output __(Spark ONLY)__
9 | *
10 | * @param model Broadcast of WordEmbedding model that contains all meaningful words.
11 | * @param error An objective function `(Default: [[kr.ac.kaist.ir.deep.fn.SquaredErr]])`
12 | *
13 | * @example
14 | * {{{var make = new StringToVectorType(model = wordModel, error = CrossEntropyErr)
15 | * var out = make onewayTrip (net, in)}}}
16 | */
17 | class StringToVectorType(protected override val model: Broadcast[WordModel],
18 | override val error: Objective) extends StringType[ScalarMatrix] {
19 | /**
20 | * Apply & Back-prop given single input
21 | *
22 | * @param net A network that gets input
23 | * @param delta Sequence of delta updates
24 | */
25 | def roundTrip(net: Network, delta: Seq[ScalarMatrix]) = (in: String, real: ScalarMatrix) ⇒ {
26 | val out = net.passedBy(model.value(in))
27 | val err: ScalarMatrix = error.derivative(real, out)
28 | net updateBy(delta.toIterator, err)
29 | }
30 |
31 | /**
32 | * Make validation output
33 | *
34 | * @param net A network that gets input
35 | * @param pair (Input, Real output) pair for computation
36 |    * @return a string describing the validation result
37 | */
38 | override def stringOf(net: Network, pair: (String, ScalarMatrix)): String = {
39 | val in = pair._1
40 | val real = pair._2
41 | val out = net of model.value(in)
42 | s"IN: $in EXP: ${real.mkString} → OUT: ${out.mkString}"
43 | }
44 |
45 | /**
46 | * Apply given input and compute the error
47 | *
48 | * @param net A network that gets input
49 | * @param pair (Input, Real output) for error computation.
50 | * @return error of this network
51 | */
52 | override def lossOf(net: Network)(pair: (String, ScalarMatrix)): Scalar = {
53 | val in = pair._1
54 | val real = pair._2
55 | val out = net of model.value(in)
56 | error(real, out)
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
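A sketch of wiring `StringToVectorType` to a broadcast `WordModel`; the `sc: SparkContext`, the trained `net: Network` and the model file path are placeholders assumed to exist.

```scala
import scala.reflect.io.Path
import kr.ac.kaist.ir.deep.fn._
import kr.ac.kaist.ir.deep.wordvec._

// Sketch only: `sc` and `net` are assumed to have been created elsewhere.
val model   = WordModel(Path("wordvec.txt"), normalize = true)
val bcModel = sc.broadcast(model)

val make = new StringToVectorType(model = bcModel, error = SquaredErr)
val out  = make onewayTrip (net, "word")   // looks up the word's embedding and feeds it forward
```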
/src/main/scala/kr/ac/kaist/ir/deep/wordvec/StringType.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep.wordvec
2 |
3 | import kr.ac.kaist.ir.deep.fn._
4 | import kr.ac.kaist.ir.deep.network.Network
5 | import kr.ac.kaist.ir.deep.train.{Corruption, ManipulationType}
6 | import org.apache.spark.broadcast.Broadcast
7 |
8 | /**
9 | * __Trait of Input Operation__ : String as Input. This is an '''Abstract Implementation'''
10 | *
11 | * @tparam OUT Output type.
12 | */
13 | trait StringType[OUT] extends ManipulationType[String, OUT] {
14 | override val corrupt: Corruption = null
15 | protected val model: Broadcast[WordModel]
16 |
17 | /**
18 | * Corrupt input : No corruption for string.
19 | *
20 | * @param x input to be corrupted
21 | * @return corrupted input
22 | */
23 | override def corrupted(x: String): String = x
24 |
25 | /**
26 | * Apply given single input as one-way forward trip.
27 | *
28 | * @param net A network that gets input
29 | * @param x input to be computed
30 | * @return output of the network.
31 | */
32 | override def onewayTrip(net: Network, x: String): ScalarMatrix =
33 | net of model.value(x)
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/scala/kr/ac/kaist/ir/deep/wordvec/package.scala:
--------------------------------------------------------------------------------
1 | package kr.ac.kaist.ir.deep
2 |
3 | import java.io.{ObjectInputStream, ObjectOutputStream}
4 | import java.util
5 | import java.util.regex.Pattern
6 |
7 | import kr.ac.kaist.ir.deep.fn._
8 | import org.apache.log4j.Logger
9 |
10 | import scala.collection.JavaConversions._
11 | import scala.collection.mutable
12 | import scala.io.Codec
13 | import scala.reflect.io.{File, Path}
14 |
15 | /**
16 | * Package for WordEmbedding training __(Unstable)__
17 | */
18 | package object wordvec {
19 |
20 | /** Pattern for real number **/
21 | final val PATTERN_REAL = Pattern.compile("^[0-9]+\\.[0-9]+$", Pattern.UNICODE_CHARACTER_CLASS)
22 | final val PATTERN_REAL_WITHIN = Pattern.compile("\\s+[0-9]+\\.[0-9]+\\s+", Pattern.UNICODE_CHARACTER_CLASS)
23 | /** Pattern for integer **/
24 | final val PATTERN_INTEGER = Pattern.compile("^[0-9]+$", Pattern.UNICODE_CHARACTER_CLASS)
25 | /** Pattern for Punctuation **/
26 | final val PATTERN_PUNCT = Pattern.compile("(\\p{Punct})", Pattern.UNICODE_CHARACTER_CLASS)
27 | /** Pattern for Special Range **/
28 | final val PATTERN_SPECIAL = Pattern.compile("^≪[A-Z]+≫$", Pattern.UNICODE_CHARACTER_CLASS)
29 |
30 | /**
31 | * Word Filter type
32 | */
33 | trait WordFilter extends (String ⇒ String) with Serializable {
34 | /**
35 | * Tokenize given string using this filter
36 | * @param str String for tokenize
37 | * @return Array of tokens
38 | */
39 | def tokenize(str: String): mutable.WrappedArray[String]
40 |
41 | /**
42 | * Save this filter into given path
43 | * @param path Path to save.
44 | */
45 | def saveAs(path: Path): this.type = saveAs(File(path))
46 |
47 | /**
48 | * Save this filter into given file
49 | * @param file File to save.
50 | */
51 | def saveAs(file: File): this.type = {
52 | val oos = new ObjectOutputStream(file.outputStream())
53 | oos.writeObject(this)
54 | oos.close()
55 | this
56 | }
57 | }
58 |
59 | /**
60 |    * __WordFilter__ : Filter class that keeps only a specific language (Unicode) area.
61 | * @param langFilter Regular Expression String indicating accepted Unicode area.
62 | */
63 | case class LangFilter(langFilter: String) extends WordFilter{
64 | val langPattern = Pattern.compile(s"[^$langFilter\\p{Punct}]+", Pattern.UNICODE_CHARACTER_CLASS)
65 |
66 | def tokenize(str: String): mutable.WrappedArray[String] = {
67 | val withReal = PATTERN_REAL_WITHIN.matcher(s" $str ")
68 | .replaceAll(" ≪REALNUM≫ ").trim()
69 | PATTERN_PUNCT.matcher(withReal).replaceAll(" $1 ").split("\\s+")
70 | .transform(apply)
71 | }
72 |
73 | /**
74 | * Normalize words
75 | * @param word Word String to be normalized
76 | * @return Normalized word string.
77 | */
78 | def apply(word: String) =
79 | if (PATTERN_SPECIAL.matcher(word).find()){
80 | // Remain those functional words.
81 | word
82 | } else if (PATTERN_REAL.matcher(word).find()) {
83 | "≪REALNUM≫"
84 | } else if (PATTERN_INTEGER.matcher(word).find()) {
85 | "≪NUMBERS≫"
86 | } else if (langPattern.matcher(word).find()) {
87 | "≪FOREIGN≫"
88 | } else
89 | word
90 | }
91 |
92 | /**
93 | * Word2Vec model class.
94 |    * @param map Mapping from String to Array[Scalar]
95 | */
96 | class WordModel(val map: util.HashMap[String, Array[Scalar]]) extends Serializable with (String ⇒ ScalarMatrix) {
97 | private final val OTHER_VEC = map(WordModel.OTHER_UNK)
98 | lazy val vectorSize = map.head._2.length
99 | private var filter: WordFilter = LangFilter("\\u0000-\\u007f")
100 |
101 | /**
102 | * Load Word Filter
103 | * @param path Path where Serialized Filter saved
104 | */
105 | def loadFilter(path: Path): this.type = loadFilter(File(path))
106 |
107 | /**
108 | * Load Word Filter
109 | * @param file File where Serialized Filter saved
110 | */
111 | def loadFilter(file: File): this.type = {
112 | if (file.exists && file.isFile) {
113 | val ois = new ObjectInputStream(file.inputStream())
114 | val filter = ois.readObject().asInstanceOf[WordFilter]
115 | ois.close()
116 | setFilter(filter)
117 | }
118 |
119 | this
120 | }
121 |
122 | /**
123 | * Set Word Filter
124 | * @param newFilter Filter to be set
125 | */
126 | def setFilter(newFilter: WordFilter) = {
127 | filter = newFilter
128 | }
129 |
130 | /**
131 | * Get Matrix(Vector) of given word
132 | * @param word Word string for search
133 | * @return Column Vector of given word
134 | */
135 | def apply(word: String) = {
136 | val vec = map.getOrDefault(filter(word), OTHER_VEC)
137 | ScalarMatrix(vec:_*)
138 | }
139 |
140 | /**
141 | * Tokenize given string using word filter
142 | * @param str String to tokenize
143 | * @return Tokenized string (WrappedArray)
144 | */
145 | def tokenize(str: String) = filter.tokenize(str)
146 |
147 | /**
148 |      * Tokenize given string and take the average of the token vectors
149 | * @param str String to compute
150 | * @return Average word embedding of given string.
151 | */
152 | def tokenizeAndApply(str: String):ScalarMatrix = {
153 | val array = filter.tokenize(str)
154 | val len = array.length
155 | val res = ScalarMatrix $0 (vectorSize, 1)
156 | var i = len
157 | while(i > 0){
158 | i -= 1
159 | val vec = map.getOrDefault(array(i), OTHER_VEC)
160 | var d = vectorSize
161 | while(d > 0){
162 | d -= 1
163 | res(d, 0) += vec(d) / len.toFloat
164 | }
165 | }
166 |
167 | res
168 | }
169 |
170 | /**
171 |      * Check existence of given word
172 | * @param word Word string for search
173 | * @return True if it is in the list
174 | */
175 | def contains(word: String) = map.containsKey(filter(word))
176 |
177 | /**
178 | * Write model into given path.
179 | * @param path Path where to store.
180 | */
181 | def saveAs(path: Path): Unit = saveAs(File(path))
182 |
183 | /**
184 | * Write model into given file.
185 | * @param file File where to store
186 | */
187 | def saveAs(file: File): Unit = {
188 | val bw = file.bufferedWriter(append = false, codec = Codec.UTF8)
189 | map.foreach {
190 | case (word, vec) ⇒
191 | bw.write(s"$word\t")
192 | val str = vec.map {
193 | v ⇒ f"$v%.8f"
194 | }.mkString(" ")
195 |           bw.write(str + "\n") // newline so each word is written on its own line
196 | }
197 | bw.close()
198 | }
199 | }
200 |
201 | /**
202 | * Companion object of [[WordModel]]
203 | */
204 | object WordModel extends Serializable {
205 | final val OTHER_UNK = "≪UNKNOWN≫"
206 | val logger = Logger.getLogger(this.getClass)
207 |
208 | /**
209 | * Restore Word Model from Path.
210 | * @param path Path of word model file.
211 |      * @param normalize True if vectors should be normalized by the largest absolute entry value.
212 | * @return WordModel restored from file.
213 | */
214 | def apply(path: Path, normalize: Boolean): WordModel = apply(File(path), normalize)
215 |
216 | /**
217 | * Restore WordModel from File.
218 | * @param file File where to read
219 |      * @param normalize True if vectors should be normalized by the largest absolute entry value.
220 | * @return WordModel restored from file.
221 | */
222 | def apply(file: File, normalize: Boolean = false): WordModel = {
223 | val path = file.path + (if (normalize) ".norm.obj" else ".orig.obj")
224 | if (File(path).exists) {
225 | val in = new ObjectInputStream(File(path).inputStream())
226 | val model = in.readObject().asInstanceOf[WordModel]
227 | in.close()
228 |
229 | logger info "READ Word2Vec finished."
230 | model
231 | } else {
232 | val br = file.bufferedReader(Codec.UTF8)
233 | val firstLine = br.readLine().split("\\s+")
234 | val mapSize = firstLine(0).toInt
235 | val vectorSize = firstLine(1).toInt
236 |
237 | val buffer = new util.HashMap[String, Array[Scalar]]()
238 | var lineNo = mapSize
239 | var maxlen = 0.0f
240 |
241 | while (lineNo > 0) {
242 | lineNo -= 1
243 | if (lineNo % 10000 == 0)
244 | logger info f"READ Word2Vec file : $lineNo%9d/$mapSize%9d"
245 |
246 | val line = br.readLine()
247 | val splits = line.split("\\s+")
248 | val word = splits(0)
249 | val vector = splits.view.slice(1, vectorSize + 1).map(_.toFloat).force
250 | val len = vector.map(Math.abs).max
251 | require(vector.length == vectorSize, s"'$word' Vector is broken! Read size ${vector.length}, but expected $vectorSize")
252 |
253 | if (maxlen < len)
254 | maxlen = len
255 | buffer += word → vector
256 | }
257 |
258 | br.close()
259 |
260 | if (normalize && maxlen > 0f) {
261 | logger info f"READ Word2Vec file : Maximum absolute value of entry in vector matrix = $maxlen%.4f"
262 | buffer.foreach {
263 | case (_, vec) ⇒
264 | var i = vec.length
265 | while (i > 0) {
266 | i -= 1
267 | vec.update(i, vec(i) / maxlen)
268 | }
269 | }
270 | }
271 |
272 | val model = new WordModel(buffer)
273 | val stream = new ObjectOutputStream(File(path).outputStream())
274 | stream.writeObject(model)
275 | stream.close()
276 |
277 | logger info "READ Word2Vec finished."
278 | model
279 | }
280 | }
281 | }
282 | }
283 |
--------------------------------------------------------------------------------
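Finally, a sketch of direct `WordModel` and `LangFilter` usage based on the methods above; the file names are placeholders.

```scala
import scala.reflect.io.Path
import kr.ac.kaist.ir.deep.wordvec._

// Load a word2vec text model; a serialized ".norm.obj"/".orig.obj" cache is written beside it.
val model = WordModel(Path("wordvec.txt"), normalize = true)
model.loadFilter(Path("filter.dat"))    // optional: reuse a LangFilter saved by PrepareCorpus

val tokens = model.tokenize("Prices rose 3.5 percent")         // "3.5" is rewritten to the real-number token
val vector = model("percent")                                  // column vector; OOV words map to the unknown vector
val mean   = model.tokenizeAndApply("Prices rose 3.5 percent") // average embedding of all tokens
```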