14 | getFeatureList();
15 | /**
16 | * repeated .tensorflow.Feature feature = 1;
17 | */
18 | org.tensorflow.example.Feature getFeature(int index);
19 | /**
20 | * repeated .tensorflow.Feature feature = 1;
21 | */
22 | int getFeatureCount();
23 | /**
24 | * repeated .tensorflow.Feature feature = 1;
25 | */
26 | java.util.List<? extends org.tensorflow.example.FeatureOrBuilder>
27 | getFeatureOrBuilderList();
28 | /**
29 | * repeated .tensorflow.Feature feature = 1;
30 | */
31 | org.tensorflow.example.FeatureOrBuilder getFeatureOrBuilder(
32 | int index);
33 | }
34 |
--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/example/FeatureOrBuilder.java:
--------------------------------------------------------------------------------
1 | // Generated by the protocol buffer compiler. DO NOT EDIT!
2 | // source: feature.proto
3 |
4 | package org.tensorflow.example;
5 |
6 | public interface FeatureOrBuilder extends
7 | // @@protoc_insertion_point(interface_extends:tensorflow.Feature)
8 | com.google.protobuf.MessageOrBuilder {
9 |
10 | /**
11 | * .tensorflow.BytesList bytes_list = 1;
12 | */
13 | boolean hasBytesList();
14 | /**
15 | * .tensorflow.BytesList bytes_list = 1;
16 | */
17 | org.tensorflow.example.BytesList getBytesList();
18 | /**
19 | * .tensorflow.BytesList bytes_list = 1;
20 | */
21 | org.tensorflow.example.BytesListOrBuilder getBytesListOrBuilder();
22 |
23 | /**
24 | * .tensorflow.FloatList float_list = 2;
25 | */
26 | boolean hasFloatList();
27 | /**
28 | * .tensorflow.FloatList float_list = 2;
29 | */
30 | org.tensorflow.example.FloatList getFloatList();
31 | /**
32 | * .tensorflow.FloatList float_list = 2;
33 | */
34 | org.tensorflow.example.FloatListOrBuilder getFloatListOrBuilder();
35 |
36 | /**
37 | * .tensorflow.Int64List int64_list = 3;
38 | */
39 | boolean hasInt64List();
40 | /**
41 | * .tensorflow.Int64List int64_list = 3;
42 | */
43 | org.tensorflow.example.Int64List getInt64List();
44 | /**
45 | * .tensorflow.Int64List int64_list = 3;
46 | */
47 | org.tensorflow.example.Int64ListOrBuilder getInt64ListOrBuilder();
48 |
49 | public org.tensorflow.example.Feature.KindCase getKindCase();
50 | }
51 |
--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/example/FeaturesOrBuilder.java:
--------------------------------------------------------------------------------
1 | // Generated by the protocol buffer compiler. DO NOT EDIT!
2 | // source: feature.proto
3 |
4 | package org.tensorflow.example;
5 |
6 | public interface FeaturesOrBuilder extends
7 | // @@protoc_insertion_point(interface_extends:tensorflow.Features)
8 | com.google.protobuf.MessageOrBuilder {
9 |
10 | /**
11 | *
12 | * Map from feature name to feature.
13 | *
14 | *
15 | * map<string, .tensorflow.Feature> feature = 1;
16 | */
17 | int getFeatureCount();
18 | /**
19 | *
20 | * Map from feature name to feature.
21 | *
22 | *
23 | * map<string, .tensorflow.Feature> feature = 1;
24 | */
25 | boolean containsFeature(
26 | java.lang.String key);
27 | /**
28 | * Use {@link #getFeatureMap()} instead.
29 | */
30 | @java.lang.Deprecated
31 |   java.util.Map<java.lang.String, org.tensorflow.example.Feature>
32 |   getFeature();
33 | /**
34 | *
35 | * Map from feature name to feature.
36 | *
37 | *
38 | * map<string, .tensorflow.Feature> feature = 1;
39 | */
40 |   java.util.Map<java.lang.String, org.tensorflow.example.Feature>
41 |   getFeatureMap();
42 | /**
43 | *
44 | * Map from feature name to feature.
45 | *
46 | *
47 | * map<string, .tensorflow.Feature> feature = 1;
48 | */
49 |
50 | org.tensorflow.example.Feature getFeatureOrDefault(
51 | java.lang.String key,
52 | org.tensorflow.example.Feature defaultValue);
53 | /**
54 | *
55 | * Map from feature name to feature.
56 | *
57 | *
58 | * map<string, .tensorflow.Feature> feature = 1;
59 | */
60 |
61 | org.tensorflow.example.Feature getFeatureOrThrow(
62 | java.lang.String key);
63 | }
64 |
--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/example/FloatListOrBuilder.java:
--------------------------------------------------------------------------------
1 | // Generated by the protocol buffer compiler. DO NOT EDIT!
2 | // source: feature.proto
3 |
4 | package org.tensorflow.example;
5 |
6 | public interface FloatListOrBuilder extends
7 | // @@protoc_insertion_point(interface_extends:tensorflow.FloatList)
8 | com.google.protobuf.MessageOrBuilder {
9 |
10 | /**
11 | * repeated float value = 1 [packed = true];
12 | */
13 |   java.util.List<java.lang.Float> getValueList();
14 | /**
15 | * repeated float value = 1 [packed = true];
16 | */
17 | int getValueCount();
18 | /**
19 | * repeated float value = 1 [packed = true];
20 | */
21 | float getValue(int index);
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/example/Int64ListOrBuilder.java:
--------------------------------------------------------------------------------
1 | // Generated by the protocol buffer compiler. DO NOT EDIT!
2 | // source: feature.proto
3 |
4 | package org.tensorflow.example;
5 |
6 | public interface Int64ListOrBuilder extends
7 | // @@protoc_insertion_point(interface_extends:tensorflow.Int64List)
8 | com.google.protobuf.MessageOrBuilder {
9 |
10 | /**
11 | * repeated int64 value = 1 [packed = true];
12 | */
13 |   java.util.List<java.lang.Long> getValueList();
14 | /**
15 | * repeated int64 value = 1 [packed = true];
16 | */
17 | int getValueCount();
18 | /**
19 | * repeated int64 value = 1 [packed = true];
20 | */
21 | long getValue(int index);
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/example/RecordWriter.java:
--------------------------------------------------------------------------------
1 | package org.tensorflow.example;
2 | import java.io.*;
3 | import java.util.zip.*;
4 |
5 | // Empty placeholder; the functional TFRecord writer is org.tensorflow.io.RecordWriter.
6 | public class RecordWriter {
7 |   private static final long serialVersionUID = 0L;
8 |
9 |   public RecordWriter(DataInputStream ds) {
10 |   }
11 | }
12 |
--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/example/SequenceExampleOrBuilder.java:
--------------------------------------------------------------------------------
1 | // Generated by the protocol buffer compiler. DO NOT EDIT!
2 | // source: example.proto
3 |
4 | package org.tensorflow.example;
5 |
6 | public interface SequenceExampleOrBuilder extends
7 | // @@protoc_insertion_point(interface_extends:tensorflow.SequenceExample)
8 | com.google.protobuf.MessageOrBuilder {
9 |
10 | /**
11 | * .tensorflow.Features context = 1;
12 | */
13 | boolean hasContext();
14 | /**
15 | * .tensorflow.Features context = 1;
16 | */
17 | org.tensorflow.example.Features getContext();
18 | /**
19 | * .tensorflow.Features context = 1;
20 | */
21 | org.tensorflow.example.FeaturesOrBuilder getContextOrBuilder();
22 |
23 | /**
24 | * .tensorflow.FeatureLists feature_lists = 2;
25 | */
26 | boolean hasFeatureLists();
27 | /**
28 | * .tensorflow.FeatureLists feature_lists = 2;
29 | */
30 | org.tensorflow.example.FeatureLists getFeatureLists();
31 | /**
32 | * .tensorflow.FeatureLists feature_lists = 2;
33 | */
34 | org.tensorflow.example.FeatureListsOrBuilder getFeatureListsOrBuilder();
35 | }
36 |
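A quick sketch of how the generated interfaces above are typically used. The concrete message classes (Feature, Features, Int64List) and their builders are the standard protoc output for feature.proto; they are not shown in this listing, so treat the calls below as the stock protobuf-java builder API rather than code from this repository:

    import org.tensorflow.example.{Feature, Features, Int64List}

    object BuildFeatures {
      def main(args: Array[String]): Unit = {
        // Wrap a single int64 value as a Feature (the oneof kind becomes int64_list).
        val label: Feature = Feature.newBuilder()
          .setInt64List(Int64List.newBuilder().addValue(1L))
          .build()
        // Insert it into the map<string, Feature> field of Features.
        val features: Features = Features.newBuilder()
          .putFeature("label", label)
          .build()
        println(features.getFeatureCount) // 1
      }
    }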
--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/io/RecordWriter.java:
--------------------------------------------------------------------------------
1 | package org.tensorflow.io;
2 | import java.io.*;
3 | import java.util.zip.*;
4 |
5 | public class RecordWriter {
6 | private static final long serialVersionUID = 0L;
7 | private static final int DEFAULT_BUFSIZE = 64*1024;
8 |
9 | private BufferedOutputStream ds_;
10 |
11 | public RecordWriter(OutputStream ds) {
12 | ds_ = new BufferedOutputStream(ds, DEFAULT_BUFSIZE);
13 | }
14 |
15 | public RecordWriter(String fname) throws IOException {
16 | FileOutputStream fout = new FileOutputStream(fname);
17 | ds_ = new BufferedOutputStream(fout, DEFAULT_BUFSIZE);
18 | }
19 |
20 |   // Flush buffered records and release the underlying stream.
21 |   public void close() throws IOException {
22 |     ds_.flush();
23 |     ds_.close();
24 |   }
25 |
26 |   // Masked CRC32C over the first count bytes, as required by the TFRecord format.
27 |   public int maskedCRC(byte [] bytes, int count) {
28 |     return CRC32C.mask(CRC32C.getValue(bytes, 0, count));
29 |   }
30 |
31 |   // TFRecord framing: fixed64 payload length, fixed32 masked CRC of the length,
32 |   // the payload bytes, then a fixed32 masked CRC of the payload.
33 |   public int writeRecord(byte [] data) throws IOException {
34 |     byte [] header = new byte[12];
35 |     byte [] footer = new byte[4];
36 |     CRC32C.encodeFixed64(header, 0, data.length);
37 |     CRC32C.encodeFixed32(header, 8, maskedCRC(header, 8));
38 |
39 |     CRC32C.encodeFixed32(footer, 0, maskedCRC(data, data.length));
40 |
41 |     ds_.write(header, 0, 12);
42 |     ds_.write(data, 0, data.length);
43 |     ds_.write(footer, 0, 4);
44 |
45 |     return 0;
46 |   }
47 | }
48 |
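A minimal usage sketch for the writer above (hypothetical path and payload; the byte array stands in for something like a serialized Example message, whose class is not part of this listing):

    import org.tensorflow.io.RecordWriter

    object WriteRecords {
      def main(args: Array[String]): Unit = {
        val writer = new RecordWriter("data.tfrecord")  // hypothetical output path
        val payload: Array[Byte] = Array[Byte](1, 2, 3) // stand-in for Example.toByteArray
        writer.writeRecord(payload)
        writer.close() // flush the BufferedOutputStream so the record actually lands on disk
      }
    }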
--------------------------------------------------------------------------------
/src/main/resources/application.conf:
--------------------------------------------------------------------------------
1 | akka {
2 | actor {
3 | provider = cluster
4 | }
5 | remote {
6 | log-remote-lifecycle-events = off
7 | netty.tcp {
8 | hostname = "127.0.0.1"
9 | port = 0
10 | }
11 | }
12 |
13 | cluster {
14 | seed-nodes = [
15 | "akka.tcp://ClusterSystem@127.0.0.1:2551",
16 | "akka.tcp://ClusterSystem@127.0.0.1:2552"]
17 |
18 | # auto downing is NOT safe for production deployments.
19 | # you may want to use it during development, read more about it in the docs.
20 | auto-down-unreachable-after = 10s
21 | }
22 | log-dead-letters = 0
23 | log-dead-letters-during-shutdown = off
24 | }
25 |
26 | # Disable legacy metrics in akka-cluster.
27 | akka.cluster.metrics.enabled=off
28 |
29 | # Metrics extension in akka-cluster-metrics (left disabled here).
30 | # akka.extensions=["akka.cluster.metrics.ClusterMetricsExtension"]
31 |
32 | # Sigar native library extract location during tests.
33 | # Note: use per-jvm-instance folder when running multiple jvm on one host.
34 | akka.cluster.metrics.native-library-extract-folder=${user.dir}/target/native
35 |
--------------------------------------------------------------------------------
/src/main/resources/lib/touch.txt:
--------------------------------------------------------------------------------
1 | touch
2 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/Copyright.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2012, Regents of the University of California
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 | * Redistributions of source code must retain the above copyright
7 | notice, this list of conditions and the following disclaimer.
8 | * Redistributions in binary form must reproduce the above copyright
9 | notice, this list of conditions and the following disclaimer in the
10 | documentation and/or other materials provided with the distribution.
11 |     * Neither the name of the University of California nor the
12 | names of its contributors may be used to endorse or promote products
13 | derived from this software without specific prior written permission.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY
19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |
26 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/Logging.scala:
--------------------------------------------------------------------------------
1 | package BIDMach
2 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,HMat,GDMat,GLMat,GMat,GIMat,GSDMat,GSMat,LMat,SMat,SDMat,TMat}
3 | import BIDMat.MatFunctions._
4 | import BIDMat.SciFunctions._
5 | import BIDMat.Plotting._
6 | import BIDMach.models._
7 | import BIDMach.datasinks._
8 |
9 |
10 | object Logging{
11 | def logGradientL2Norm(model:Model,data:Array[Mat]):Array[Mat] = {
12 | val m = model.modelmats
13 | val res = new Array[Float](m.length)
14 | for(i<-0 until m.length){
15 | res(i) = sum(snorm(m(i))).dv.toFloat
16 | }
17 | Array(new FMat(m.length,1,res))
18 | }
19 |
20 | def logGradientL1Norm(model:Model,data:Array[Mat]):Array[Mat] = {
21 | val m = model.modelmats
22 | val res = new Array[Float](m.length)
23 | for(i<-0 until m.length){
24 | res(i) = sum(sum(abs(m(i)))).dv.toFloat
25 | }
26 | Array(new FMat(m.length,1,res))
27 | }
28 |
29 | def getResults(model:Model): Array[Mat] = {
30 | model.opts.logDataSink match {
31 | case f:FileSink=>{println("Found results at "+f.opts.ofnames.head(0));null}
32 | case m:MatSink=>m.mats
33 | case null=>{println("No logDataSink found");null}
34 | }
35 | }
36 |
37 | def getResults(l:Learner): Array[Mat] = getResults(l.model)
38 | }
39 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/allreduce/AllreduceDummyLearner.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.allreduce
2 |
3 | import BIDMach.Learner
4 | import BIDMach.networks.Net
5 |
6 | /**
7 |   * A dummy learner for ease of testing. Can be dropped or refactored out if necessary.
8 |   * @param learner learner whose datasource, mixins, updater, sink and options are reused
9 |   * @param dummy_model stand-in model that fakes some work each pass
10 |   */
11 | class AllreduceDummyLearner(learner:Learner, dummy_model:AllreduceDummyModel)
12 | extends Learner(learner.datasource,dummy_model,learner.mixins, learner.updater, learner.datasink ,learner.opts) {
13 |
14 | def this(){
15 | this(Net.learner("dummy learner")._1, new AllreduceDummyModel())
16 | }
17 |
18 |
19 | override def train: Unit = {
20 | println("dummy model is training!")
21 | while(true){
22 | this.ipass+=1
23 | myLogger.info("pass=%2d" format ipass)
24 | this.dummy_model.showSomeWork()
25 | }
26 |
27 | }
28 |
29 | }
30 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/allreduce/AllreduceDummyModel.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.allreduce
2 |
3 | import BIDMach.models.Model
4 | import BIDMat.{FMat, Mat}
5 |
6 | class AllreduceDummyModel(val _modelmat: Array[Mat]) extends Model {
7 | def this(){
8 | this(Array[Mat](FMat.ones(30,100),FMat.ones(100,30)))
9 | }
10 |
11 |
12 | override def modelmats:Array[Mat] = {
13 | _modelmat
14 | }
15 | override def init()={}
16 | override def dobatch(mats:Array[Mat], ipass:Int, here:Long)={}
17 | override def evalbatch(mats: Array[Mat], ipass: Int, here:Long):FMat = {
18 | FMat.zeros(0,0)
19 | }
20 | def showSomeWork(){
21 | println("I'm learning something")
22 | Thread.sleep(1000)
23 | }
24 |
25 | }
26 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/allreduce/AllreduceMessage.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.allreduce
2 |
3 | import akka.actor.ActorRef
4 | import scala.collection.mutable.ArrayBuffer
5 |
6 |
7 | // worker messages
8 | final case class StartAllreduce(config : RoundConfig)
9 | final case class CompleteAllreduce(srcId : Int, config : RoundConfig)
10 |
11 | final case class ScatterBlock(value : Array[Float], srcId : Int, destId : Int, chunkId : Int, config : RoundConfig)
12 | final case class ReduceBlock(value: Array[Float], srcId : Int, destId : Int, chunkId : Int, config : RoundConfig, count: Int)
13 |
14 | final case class AllreduceStats(outgoingFloats: Long, incomingFloats: Long)
15 |
16 | /**
17 |   * Comparison overloads providing a (line master version, round) ordering, for a smooth transition when nodes are added or removed.
18 |   */
19 | final case class RoundConfig(lineMasterVersion : Int, round: Int, lineMaster : ActorRef, peerWorkers: Map[Int, ActorRef], workerId: Int) {
20 |   def < (other : RoundConfig): Boolean = {
21 |     lineMasterVersion < other.lineMasterVersion ||
22 |       (lineMasterVersion == other.lineMasterVersion && round < other.round)
23 |   }
24 |
25 |   def == (other : RoundConfig): Boolean = {
26 |     lineMasterVersion == other.lineMasterVersion && round == other.round
27 |   }
28 |
29 |   def > (other : RoundConfig): Boolean = {
30 |     !(this < other || this == other)
31 |   }
32 |
33 | }
34 |
35 | /*
36 | * Following message used by Line Master
37 | */
38 | final case class StartAllreduceTask(peerNodes: ArrayBuffer[ActorRef], lineMasterVersion : Int)
39 | final case class StopAllreduceTask(lineMasterVersion : Int)
40 |
41 | /*
42 | * For grid master in case we want to kill the node
43 | */
44 | final case class StopAllreduceNode()
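A small illustration of the RoundConfig ordering (hypothetical values; the ActorRef and worker fields do not participate in the comparison, so nulls and empty maps are used purely for the sketch):

    import BIDMach.allreduce.RoundConfig

    object RoundConfigOrdering {
      def main(args: Array[String]): Unit = {
        val a = RoundConfig(lineMasterVersion = 1, round = 5, lineMaster = null, peerWorkers = Map.empty, workerId = 0)
        val b = RoundConfig(lineMasterVersion = 1, round = 6, lineMaster = null, peerWorkers = Map.empty, workerId = 0)
        val c = RoundConfig(lineMasterVersion = 2, round = 0, lineMaster = null, peerWorkers = Map.empty, workerId = 0)
        println(a < b) // true: same line master version, earlier round
        println(b < c) // true: a newer line master version dominates any round
        println(c > a) // true
      }
    }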
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/allreduce/binder/AllreduceBinder.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.allreduce.binder
2 |
3 | import BIDMach.allreduce.binder.AllreduceBinder.{DataSink, DataSource}
4 |
5 | /**
6 | * Trait to specify source and sink, allowing binding data input/output to the all-reduce process.
7 | */
8 | trait AllreduceBinder {
9 |
10 | def totalDataSize: Int
11 |
12 | def dataSource: DataSource
13 |
14 | def dataSink: DataSink
15 |
16 | }
17 |
18 | object AllreduceBinder {
19 |
20 | type DataSink = AllReduceOutput => Unit
21 | type DataSource = AllReduceInputRequest => AllReduceInput
22 | var updateCounts = 100
23 |
24 | }
25 |
26 | case class AllReduceInputRequest(iteration: Int)
27 |
28 | case class AllReduceInput(data: Array[Float])
29 |
30 | case class AllReduceOutput(data: Array[Float], iteration: Int)
31 |
32 |
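To make the trait concrete, here is a hypothetical minimal binder (not part of the repository): it feeds a constant vector into the allreduce and prints the first reduced element.

    import BIDMach.allreduce.binder._
    import BIDMach.allreduce.binder.AllreduceBinder.{DataSink, DataSource}

    class ConstantBinder(dataSize: Int) extends AllreduceBinder {

      override def totalDataSize: Int = dataSize

      // Serve the same vector of ones on every iteration.
      override def dataSource: DataSource = request =>
        AllReduceInput(Array.fill(dataSize)(1f))

      // Observe the reduced result.
      override def dataSink: DataSink = output =>
        println(s"iteration ${output.iteration}: data(0) = ${output.data(0)}")
    }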
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/allreduce/binder/AssertCorrectnessBinder.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.allreduce.binder
2 |
3 | import BIDMach.allreduce.binder.AllreduceBinder.{DataSink, DataSource}
4 |
5 |
6 | class AssertCorrectnessBinder(dataSize: Int, checkpoint: Int) extends AllreduceBinder {
7 |
8 | val random = new scala.util.Random(100)
9 | val totalInputSample = 8
10 |
11 | lazy val randomFloats = {
12 | val nestedArray = new Array[Array[Float]](totalInputSample)
13 | for (i <- 0 until totalInputSample) {
14 |       nestedArray(i) = Array.fill(dataSize)(random.nextFloat())
15 | }
16 | nestedArray
17 | }
18 |
19 | private def ~=(x: Double, y: Double, precision: Double = 1e-5) = {
20 | if ((x - y).abs < precision) true else false
21 | }
22 |
23 | override def dataSource: DataSource = r => {
24 | AllReduceInput(randomFloats(r.iteration % totalInputSample))
25 | }
26 |
27 | override def dataSink: DataSink = r => {
28 |
29 | if (r.iteration % checkpoint == 0) {
30 | val inputUsed = randomFloats(r.iteration % totalInputSample)
31 | println(s"\n----Asserting #${r.iteration} output...")
32 | for (i <- 0 until dataSize) {
33 | val meanActual = r.data(i)
34 | val expected = inputUsed(i)
35 |         assert(~=(expected, meanActual), s"Expected [$expected], but actual [$meanActual] at pos $i for iteration #${r.iteration}")
36 | }
37 | println("OK: Means match the expected value!")
38 | }
39 |
40 | }
41 |
42 | override def totalDataSize: Int = dataSize
43 | }
44 |
45 |
46 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/allreduce/binder/NoOpBinder.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.allreduce.binder
2 | import BIDMach.allreduce.binder.AllreduceBinder.{DataSink, DataSource}
3 |
4 | /**
5 |   * No-op binder, just for experiments. Can be removed or refactored.
6 |   */
7 | class NoOpBinder(dataSize: Int, printFrequency: Int = 10) extends AllreduceBinder {
8 |
9 |
10 | val random = new scala.util.Random(100)
11 | val totalInputSample = 4
12 |
13 | lazy val randomFloats = {
14 | val nestedArray: Array[Array[Float]] = Array.ofDim(totalInputSample, dataSize)
15 | for (i <- 0 until totalInputSample) {
16 | for (j <- 0 until dataSize)
17 | nestedArray(i)(j) = random.nextFloat()
18 | }
19 | nestedArray
20 | }
21 |
22 |
23 | override def dataSource: DataSource = { inputRequest =>
24 | if (inputRequest.iteration % printFrequency == 0) {
25 |       println(s"--NoOpBinder: dump model data at ${inputRequest.iteration}--")
26 | }
27 |
28 | AllReduceInput(randomFloats(inputRequest.iteration % totalInputSample))
29 | }
30 |
31 | override def dataSink: DataSink = { output =>
32 | if (output.iteration % printFrequency == 0) {
33 |       println(s"--NoOpBinder: reduced data at ${output.iteration}--")
34 | }
35 |
36 | }
37 |
38 | override def totalDataSize: Int = dataSize
39 | }
40 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/allreduce/buffer/AllReduceBuffer.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.allreduce.buffer
2 |
3 |
4 | abstract class AllReduceBuffer(dataSize: Int,
5 | peerSize: Int,
6 | maxChunkSize: Int) {
7 |
8 | type Buffer = Array[Array[Float]]
9 |
10 | val peerBuffer: Buffer = Array.ofDim(peerSize, dataSize)
11 |
12 | val numChunks = getNumChunk(dataSize)
13 |
14 | protected def store(data: Array[Float], srcId: Int, chunkId: Int) = {
15 |
16 | val array = peerBuffer(srcId)
17 | System.arraycopy(
18 | data, 0,
19 | array, chunkId * maxChunkSize,
20 | data.size)
21 | }
22 |
23 | protected def getNumChunk(size: Int) = {
24 | math.ceil(1f * size / maxChunkSize).toInt
25 | }
26 | }
27 |
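A worked example of the chunk arithmetic above (hypothetical sizes): with dataSize = 10 and maxChunkSize = 4, getNumChunk returns ceil(10 / 4.0) = 3; a store for chunkId = 2 copies its floats into peerBuffer(srcId) starting at offset 2 * 4 = 8, so the last chunk holds the remaining 2 values.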
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/caffe/Classifier.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.caffe
2 | import BIDMat.{Mat,SBMat,CMat,CSMat,DMat,FMat,GMat,GIMat,GSMat,HMat,Image,IMat,ND,SMat,SDMat}
3 | import BIDMat.MatFunctions._
4 | import BIDMat.SciFunctions._
5 | import BIDMach.datasources._
6 | import edu.berkeley.bvlc.SGDSOLVER
7 | import edu.berkeley.bvlc.NET
8 | import edu.berkeley.bvlc.CAFFE
9 |
10 | class Classifier {
11 |
12 | val net = new Net
13 |
14 | def init(model_file:String, pretrained_file:String, image_dims:Array[Int] = Array(256, 256),
15 | gpu:Boolean = false, mean_file:String = null, input_scale:Float = 1f, channel_swap:IMat = 2\1\0) = {
16 |
17 | net.init(model_file, pretrained_file);
18 |
19 | CAFFE.set_phase(1);
20 |
21 | CAFFE.set_mode(if (gpu) 1 else 0)
22 |
23 | if (image_dims != null) {
24 | net.set_image_dims(image_dims)
25 | } else {
26 | net.set_image_dims(Array(net.inwidth, net.inheight))
27 | }
28 |
29 | if (mean_file != null) net.set_mean(mean_file)
30 |
31 | if (input_scale != 1f) net.set_input_scale(input_scale)
32 |
33 | if (channel_swap.asInstanceOf[AnyRef] != null) net.set_channel_swap(channel_swap)
34 |
35 | }
36 |
37 | def classify(im:Image):FMat = {
38 | val fnd = net.preprocess(im)
39 | net.clear_inputs
40 | net.add_input(fnd, 0, 0)
41 | net.forward
42 | net.output_data(0)(?,?,?,0)
43 | }
44 |
45 |
46 | }
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/caffe/SGDSolver.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.caffe
2 | import BIDMat.{Mat,SBMat,CMat,CSMat,DMat,FMat,GMat,GIMat,GSMat,HMat,Image,IMat,ND,SMat,SDMat}
3 | import BIDMat.MatFunctions._
4 | import BIDMat.SciFunctions._
5 | import BIDMach.datasources._
6 | import edu.berkeley.bvlc.SGDSOLVER
7 | import edu.berkeley.bvlc.NET
8 | import edu.berkeley.bvlc.CAFFE
9 |
10 | class SGDSolver (val sgd:SGDSOLVER) {
11 | val net = sgd.net
12 |
13 | def Solve = sgd.Solve
14 |
15 | def SolveResume(fname:String) = sgd.SolveResume(fname)
16 |
17 | }
18 |
19 | object SGDSolver {
20 | def apply(paramFile:String):SGDSolver = new SGDSolver(new SGDSOLVER(paramFile))
21 | }
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/datasinks/DataSink.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.datasinks
2 | import BIDMat.{Mat,SBMat,CMat,CSMat,DMat,FMat,IMat,HMat,GMat,GIMat,GSMat,SMat,SDMat}
3 | import BIDMat.MatFunctions._
4 | import BIDMat.SciFunctions._
5 | import java.io._
6 |
7 | @SerialVersionUID(100L)
8 | abstract class DataSink(val opts:DataSink.Opts = new DataSink.Options) extends Serializable {
9 | private var _GUID = Mat.myrand.nextLong
10 | def setGUID(v:Long):Unit = {_GUID = v}
11 | def GUID:Long = _GUID
12 | def put;
13 | def init:Unit = {}
14 | def close = {}
15 | private var _nmats = 0;
16 | def nmats = _nmats;
17 | def setnmats(k:Int) = {_nmats = k;}
18 | var omats:Array[Mat] = null
19 | }
20 |
21 | @SerialVersionUID(100L)
22 | object DataSink {
23 | trait Opts extends BIDMat.Opts {
24 | }
25 |
26 | class Options extends Opts {}
27 | }
28 |
29 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/datasinks/FileSink.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.datasinks
2 | import BIDMat.{Mat,SBMat,CMat,CSMat,DMat,FMat,IMat,HMat,GMat,GDMat,GIMat,GLMat,GSMat,GSDMat,LMat,SMat,SDMat}
3 | import BIDMat.MatFunctions._
4 | import BIDMat.SciFunctions._
5 | import BIDMach.datasources._
6 | import scala.collection.mutable.ListBuffer
7 |
8 | @SerialVersionUID(100L)
9 | class FileSink(override val opts:FileSink.Opts = new FileSink.Options) extends MatSink(opts) {
10 | var ifile = 0;
11 | var colsdone = 0;
12 |
13 | override def init = {
14 | blocks = new ListBuffer[Array[Mat]]();
15 | setnmats(opts.ofnames.length);
16 | omats = new Array[Mat](nmats);
17 | ifile = 0;
18 | opts match {
19 | case fopts:FileSource.Opts => {
20 | ifile = fopts.nstart;
21 | }
22 | }
23 | colsdone = 0;
24 | }
25 |
26 | override def put = {
27 | blocks += omats.map(MatSink.copyCPUmat);
28 | colsdone += omats(0).ncols;
29 | if (colsdone >= opts.ofcols) {
30 | mergeSaveBlocks;
31 | colsdone = 0;
32 | ifile += 1;
33 | blocks = new ListBuffer[Array[Mat]]();
34 | }
35 | }
36 |
37 | override def close () = {
38 | mergeSaveBlocks;
39 | }
40 |
41 | def mergeSaveBlocks = {
42 | mergeBlocks
43 | if (blocks.size > 0) {
44 | for (i <- 0 until opts.ofnames.length) {
45 | saveMat(opts.ofnames(i)(ifile), mats(i));
46 | }
47 | }
48 | }
49 | }
50 |
51 | @SerialVersionUID(100L)
52 | object FileSink {
53 | trait Opts extends MatSink.Opts {
54 | var ofnames:List[(Int)=>String] = null;
55 | var ofcols = 100000;
56 | }
57 |
58 | class Options extends Opts {
59 |
60 | }
61 | }
62 |
63 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/datasources/ArraySource.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.datasources
2 | import BIDMat.{Mat,SBMat,CMat,CSMat,DMat,FMat,IMat,HMat,GMat,GIMat,GSMat,SMat,SDMat}
3 | import BIDMat.MatFunctions._
4 | import BIDMat.SciFunctions._
5 | import BIDMat.MatIOtrait
6 | import scala.concurrent.Future
7 | import scala.concurrent.ExecutionContextExecutor
8 | import java.io._
9 |
10 | @SerialVersionUID(100L)
11 | class ArraySource(override val opts:ArraySource.Opts = new ArraySource.Options) extends IteratorSource(opts) {
12 | @transient var dataArray:Array[_ <: AnyRef] = null
13 |
14 | override def init = {
15 | dataArray = opts.dataArray
16 | super.init
17 | }
18 |
19 | override def iterHasNext:Boolean = {
20 | iblock += 1
21 | iblock < dataArray.length
22 | }
23 |
24 | override def hasNext:Boolean = {
25 | val matq = inMats(0)
26 | val matqnr = if (opts.dorows) matq.nrows else matq.ncols
27 | val ihn = iblock < dataArray.length
28 | if (! ihn && iblock > 0) {
29 | nblocks = iblock
30 | }
31 | (ihn || (matqnr - samplesDone) == 0);
32 | }
33 |
34 | override def iterNext() = {
35 | val marr = dataArray(iblock)
36 | marr match {
37 | case (key:AnyRef,v:MatIOtrait) => {inMats = v.get}
38 | case m:Mat => {
39 |         if (inMats == null) inMats = new Array[Mat](1);
40 | inMats(0) = m;
41 | }
42 | case ma:Array[Mat] => inMats = ma;
43 | }
44 | }
45 |
46 | override def close = {
47 | iblock = 0
48 | }
49 | }
50 |
51 | @SerialVersionUID(100L)
52 | object ArraySource {
53 | def apply(opts:ArraySource.Opts):ArraySource = {
54 | new ArraySource(opts);
55 | }
56 |
57 | trait Opts extends IteratorSource.Opts {
58 | @transient var dataArray:Array[_ <: AnyRef] = null
59 | }
60 |
61 | @SerialVersionUID(100L)
62 | class Options extends Opts {}
63 | }
64 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/datasources/DataSource.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.datasources
2 | import BIDMat.{Mat,SBMat,CMat,CSMat,DMat,FMat,IMat,HMat,GMat,GIMat,GSMat,SMat,SDMat}
3 | import BIDMat.MatFunctions._
4 | import BIDMat.SciFunctions._
5 | import java.io._
6 |
7 | @SerialVersionUID(100L)
8 | abstract class DataSource(val opts:DataSource.Opts = new DataSource.Options) extends Serializable {
9 | private var _GUID = Mat.myrand.nextLong
10 | def setGUID(v:Long):Unit = {_GUID = v}
11 | def GUID:Long = _GUID
12 | def next:Array[Mat]
13 | def hasNext:Boolean
14 | def reset:Unit
15 | def putBack(mats:Array[Mat],i:Int):Unit = {throw new RuntimeException("putBack not implemented")}
16 | def setupPutBack(n:Int,dim:Int):Unit = {throw new RuntimeException("putBack not implemented")}
17 | def nmats:Int
18 | def init:Unit
19 | def progress:Float
20 | def close = {}
21 | var omats:Array[Mat] = null
22 | var endmats:Array[Mat] = null
23 | var fullmats:Array[Mat] = null
24 | }
25 |
26 | @SerialVersionUID(100L)
27 | object DataSource {
28 | trait Opts extends BIDMat.Opts {
29 | var batchSize = 10000
30 | var sizeMargin = 3f
31 | var sample = 1f
32 | var addConstFeat:Boolean = false
33 | var featType:Int = 1 // 0 = binary features, 1 = linear features, 2 = threshold features
34 | var featThreshold:Mat = null
35 | var putBack = -1
36 | }
37 |
38 | class Options extends Opts {}
39 | }
40 |
41 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/mixins/Mixin.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.mixins
2 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,HMat,GMat,GIMat,GSMat,SMat,SDMat}
3 | import BIDMat.MatFunctions._
4 | import BIDMat.SciFunctions._
5 | import BIDMach.models._
6 |
7 | @SerialVersionUID(100L)
8 | abstract class Mixin(val opts:Mixin.Opts = new Mixin.Options) extends Serializable {
9 | val options = opts
10 | var modelmats:Array[Mat] = null
11 | var updatemats:Array[Mat] = null
12 | var counter = 0
13 |
14 | def compute(mats:Array[Mat], step:Float)
15 |
16 | def score(mats:Array[Mat], step:Float):FMat
17 |
18 | def init(model:Model) = {
19 | modelmats = model.modelmats
20 | updatemats = model.updatemats
21 | }
22 | }
23 |
24 | object Mixin {
25 | trait Opts extends BIDMat.Opts {
26 | var mixinInterval = 1
27 | }
28 |
29 | class Options extends Opts {}
30 | }
31 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/networks/layers/ForwardLayer.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.networks.layers
2 |
3 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,LMat,HMat,GMat,GDMat,GIMat,GLMat,GSMat,GSDMat,SMat,SDMat}
4 | import BIDMat.MatFunctions._
5 | import BIDMat.SciFunctions._
6 | import BIDMach.datasources._
7 | import BIDMach.updaters._
8 | import BIDMach.mixins._
9 | import BIDMach.models._
10 | import BIDMach._
11 | import edu.berkeley.bid.CPUMACH
12 | import edu.berkeley.bid.CUMACH
13 | import scala.util.hashing.MurmurHash3;
14 | import java.util.HashMap;
15 | import BIDMach.networks._
16 |
17 |
18 | @SerialVersionUID(100L)
19 | class ForwardLayer(override val net:Net, override val opts:ForwardNodeOpts = new ForwardNode) extends Layer(net, opts) {
20 |
21 | override def forward = {
22 | val start = toc;
23 | inplaceNoConnectGetOutput();
24 |
25 | output <-- inputData;
26 | // clearDeriv;
27 | forwardtime += toc - start;
28 | }
29 |
30 | override def backward = {
31 | }
32 |
33 | override def toString = {
34 | "forward@"+Integer.toHexString(hashCode % 0x10000).toString
35 | }
36 | }
37 |
38 | trait ForwardNodeOpts extends NodeOpts {
39 | }
40 |
41 | @SerialVersionUID(100L)
42 | class ForwardNode extends Node with ForwardNodeOpts {
43 |
44 | override def clone:ForwardNode = {copyTo(new ForwardNode).asInstanceOf[ForwardNode];}
45 |
46 | override def create(net:Net):ForwardLayer = {ForwardLayer(net, this);}
47 |
48 | override def toString = {
49 | "forward@"+Integer.toHexString(hashCode % 0x10000).toString
50 | }
51 | }
52 |
53 | @SerialVersionUID(100L)
54 | object ForwardLayer {
55 |
56 | def apply(net:Net) = new ForwardLayer(net, new ForwardNode);
57 |
58 | def apply(net:Net, opts:ForwardNode) = new ForwardLayer(net, opts);
59 | }
60 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/networks/layers/MaxIndexLayer.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.networks.layers
2 |
3 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,LMat,HMat,GMat,GDMat,GIMat,GLMat,GSMat,GSDMat,SMat,SDMat}
4 | import BIDMat.MatFunctions._
5 | import BIDMat.SciFunctions._
6 | import BIDMach.datasources._
7 | import BIDMach.updaters._
8 | import BIDMach.mixins._
9 | import BIDMach.models._
10 | import BIDMach._
11 | import edu.berkeley.bid.CPUMACH
12 | import edu.berkeley.bid.CUMACH
13 | import scala.util.hashing.MurmurHash3;
14 | import java.util.HashMap;
15 | import BIDMach.networks._
16 |
17 | @SerialVersionUID(100L)
18 | class MaxIndexLayer(override val net:Net, override val opts:MaxIndexNodeOpts = new MaxIndexNode) extends Layer(net, opts) {
19 |
20 | override def forward = {
21 | val start = toc;
22 | output = maxi2(inputData, 1)._2;
23 | forwardtime += toc - start;
24 | }
25 |
26 | override def backward = {
27 | val start = toc;
28 | backwardtime += toc - start;
29 | }
30 |
31 | override def toString = {
32 |     "maxidx@"+Integer.toHexString(hashCode % 0x10000).toString
33 | }
34 | }
35 |
36 | trait MaxIndexNodeOpts extends NodeOpts {
37 | }
38 |
39 | @SerialVersionUID(100L)
40 | class MaxIndexNode extends Node with MaxIndexNodeOpts {
41 |
42 | override def clone:MaxIndexNode = {copyTo(new MaxIndexNode).asInstanceOf[MaxIndexNode];}
43 |
44 | override def create(net:Net):MaxIndexLayer = {MaxIndexLayer(net, this);}
45 |
46 | override def toString = {
47 | "maxidx@"+Integer.toHexString(hashCode % 0x10000).toString
48 | }
49 | }
50 |
51 | @SerialVersionUID(100L)
52 | object MaxIndexLayer {
53 |
54 | def apply(net:Net) = new MaxIndexLayer(net, new MaxIndexNode);
55 |
56 | def apply(net:Net, opts:MaxIndexNode) = new MaxIndexLayer(net, opts);
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/networks/layers/NodeSet.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.networks.layers
2 |
3 | @SerialVersionUID(100L)
4 | class NodeSet(val nnodes:Int, val nodes:Array[Node]) extends Serializable {
5 |
6 | def this(nnodes:Int) = this(nnodes, new Array[Node](nnodes));
7 |
8 | def this(nodes:Array[Node]) = this(nodes.length, nodes);
9 |
10 | def apply(i:Int):Node = nodes(i);
11 |
12 | def update(i:Int, lopts:Node) = {nodes(i) = lopts; this}
13 |
14 | def size = nnodes;
15 |
16 | def length = nnodes;
17 |
18 | override def clone = copyTo(new NodeSet(nnodes));
19 |
20 | def copyTo(lopts:NodeSet):NodeSet = {
21 | for (i <- 0 until nnodes) {
22 | lopts.nodes(i) = nodes(i).clone;
23 | nodes(i).myGhost = lopts.nodes(i);
24 | }
25 | for (i <- 0 until nnodes) {
26 | for (j <- 0 until nodes(i).inputs.length) {
27 | if (nodes(i).inputs(j) != null) lopts.nodes(i).inputs(j) = nodes(i).inputs(j).node.myGhost;
28 | }
29 | }
30 | lopts;
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/networks/layers/SignLayer.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.networks.layers
2 |
3 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,LMat,HMat,GMat,GDMat,GIMat,GLMat,GSMat,GSDMat,SMat,SDMat}
4 | import BIDMat.MatFunctions._
5 | import BIDMat.SciFunctions._
6 | import BIDMach.datasources._
7 | import BIDMach.updaters._
8 | import BIDMach.mixins._
9 | import BIDMach.models._
10 | import BIDMach._
11 | import edu.berkeley.bid.CPUMACH
12 | import edu.berkeley.bid.CUMACH
13 | import scala.util.hashing.MurmurHash3;
14 | import java.util.HashMap;
15 | import BIDMach.networks._
16 |
17 |
18 | /**
19 | * Sign layer.
20 | */
21 |
22 | @SerialVersionUID(100L)
23 | class SignLayer(override val net:Net, override val opts:SignNodeOpts = new SignNode) extends Layer(net, opts) {
24 |
25 | override def forward = {
26 | val start = toc;
27 | inplaceNoConnectGetOutput();
28 |
29 | sign(inputData, output);
30 |
31 | forwardtime += toc - start;
32 | }
33 |
34 | override def backward = {
35 | val start = toc;
36 |
37 | backwardtime += toc - start;
38 | }
39 |
40 | override def toString = {
41 |     "sign@"+Integer.toHexString(hashCode % 0x10000).toString
42 | }
43 | }
44 |
45 |
46 | trait SignNodeOpts extends NodeOpts {
47 | }
48 |
49 | @SerialVersionUID(100L)
50 | class SignNode extends Node with SignNodeOpts {
51 |
52 | override def clone:SignNode = {copyTo(new SignNode).asInstanceOf[SignNode];}
53 |
54 | override def create(net:Net):SignLayer = {SignLayer(net, this);}
55 |
56 | override def toString = {
57 |     "sign@"+Integer.toHexString(hashCode % 0x10000).toString
58 | }
59 | }
60 |
61 | @SerialVersionUID(100L)
62 | object SignLayer {
63 |
64 | def apply(net:Net) = new SignLayer(net, new SignNode);
65 |
66 | def apply(net:Net, opts:SignNode) = new SignLayer(net, opts);
67 | }
68 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/updaters/Batch.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.updaters
2 |
3 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,HMat,GMat,GIMat,GSMat,SMat,SDMat}
4 | import BIDMat.MatFunctions._
5 | import BIDMat.SciFunctions._
6 | import BIDMach.models._
7 |
8 | @SerialVersionUID(100L)
9 | class Batch(override val opts:Batch.Opts = new Batch.Options) extends Updater {
10 |
11 | override def init(model0:Model) = {
12 | super.init(model0)
13 | }
14 |
15 | override def update(ipass:Int, step:Long) = {}
16 | }
17 |
18 | @SerialVersionUID(100L)
19 | object Batch {
20 | trait Opts extends Updater.Opts {
21 | var beps = 1e-5f
22 | }
23 |
24 | class Options extends Opts {}
25 | }
26 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/updaters/BatchNorm.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.updaters
2 |
3 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,HMat,GMat,GIMat,GSMat,SMat,SDMat}
4 | import BIDMat.MatFunctions._
5 | import BIDMat.SciFunctions._
6 | import BIDMach.models._
7 |
8 | @SerialVersionUID(100L)
9 | class BatchNorm(override val opts:BatchNorm.Opts = new BatchNorm.Options) extends Updater {
10 | var accumulators:Array[Mat] = null
11 |
12 | override def init(model0:Model) = {
13 | super.init(model0)
14 | val modelmats = model.modelmats
15 | val updatemats = model.updatemats
16 | accumulators = new Array[Mat](updatemats.length)
17 | for (i <- 0 until accumulators.length) {
18 | accumulators(i) = updatemats(i).zeros(updatemats(i).nrows, updatemats(i).ncols)
19 | }
20 | }
21 |
22 | override def update(ipass:Int, step:Long) = {
23 | val updatemats = model.updatemats
24 | for (i <- 0 until accumulators.length) {
25 | accumulators(i) ~ accumulators(i) + updatemats(i)
26 | }
27 | }
28 |
29 | override def clear() = {
30 | for (i <- 0 until accumulators.length) {
31 | accumulators(i).clear
32 | }
33 | }
34 |
35 | override def updateM(ipass:Int):Unit = {
36 | val mm = model.modelmats(0)
37 | mm ~ accumulators(0) / accumulators(1)
38 | mm ~ mm / sum(mm,2)
39 | clear
40 | }
41 | }
42 |
43 | @SerialVersionUID(100L)
44 | object BatchNorm {
45 | trait Opts extends Updater.Opts {
46 | }
47 |
48 | class Options extends Opts {}
49 | }
50 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/updaters/IncMult.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.updaters
2 |
3 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,HMat,GMat,GIMat,GSMat,SMat,SDMat}
4 | import BIDMat.MatFunctions._
5 | import BIDMat.SciFunctions._
6 | import BIDMach.models._
7 |
8 | @SerialVersionUID(100L)
9 | class IncMult(override val opts:IncMult.Opts = new IncMult.Options) extends Updater {
10 |
11 | var firstStep = 0f
12 | var rm:Mat = null
13 |
14 | override def init(model0:Model) = {
15 | super.init(model0)
16 | rm = model0.modelmats(0).zeros(1,1)
17 | }
18 |
19 | override def update(ipass:Int, step:Long) = {
20 | val modelmats = model.modelmats
21 | val updatemats = model.updatemats
22 | val mm = modelmats(0)
23 | val ms = modelmats(1)
24 | val um = updatemats(0)
25 | val ums = updatemats(1)
26 | val rr = if (step == 0) 1f else {
27 | if (firstStep == 0f) {
28 | firstStep = step
29 | 1f
30 | } else {
31 | (math.pow(firstStep / step, opts.power)).toFloat
32 | }
33 | }
34 |
35 | um ~ um *@ rm.set(rr)
36 | ln(mm, mm)
37 | mm ~ mm *@ rm.set(1-rr)
38 | mm ~ mm + um
39 | exp(mm, mm)
40 | if (opts.isprob) mm ~ mm / sum(mm,2)
41 | }
42 |
43 | override def clear() = {
44 | firstStep = 0f
45 | }
46 | }
47 |
48 | @SerialVersionUID(100L)
49 | object IncMult {
50 | trait Opts extends Updater.Opts {
51 | var warmup = 0L
52 | var power = 0.3f
53 | var isprob = true
54 | }
55 |
56 | class Options extends Opts {}
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/updaters/Telescoping.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.updaters
2 |
3 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,HMat,GMat,GIMat,GSMat,SMat,SDMat}
4 | import BIDMat.MatFunctions._
5 | import BIDMat.SciFunctions._
6 | import BIDMach.models._
7 |
8 | @SerialVersionUID(100L)
9 | class Telescoping(override val opts:Telescoping.Opts = new Telescoping.Options) extends Updater {
10 | var accumulators:Array[Mat] = null
11 | var firstStep = 0L
12 | var nextStep = 10L
13 | var nextCount = 0L
14 | var rm:Mat = null
15 |
16 | override def init(model0:Model) = {
17 | super.init(model0)
18 | val modelmats = model0.modelmats
19 | val updatemats = model0.updatemats
20 | rm = model0.modelmats(0).zeros(1,1)
21 | accumulators = new Array[Mat](updatemats.length)
22 | for (i <- 0 until updatemats.length) yield {
23 | accumulators(i) = updatemats(i).zeros(updatemats(i).nrows, updatemats(i).ncols)
24 | }
25 | firstStep = 0L
26 | nextStep = 10L
27 | nextCount = 0L
28 | }
29 |
30 | override def update(ipass:Int, step:Long) = {
31 | if (firstStep == 0 && step > 0) {
32 | firstStep = step
33 | }
34 | val updatemats = model.updatemats
35 | for (i <- 0 until updatemats.length) {
36 | accumulators(i) ~ accumulators(i) + updatemats(i)
37 | }
38 | if (step >= nextCount) {
39 | model.modelmats(0) ~ accumulators(0) / accumulators(1)
40 | nextStep = (nextStep * opts.factor).toLong
41 | nextCount = step + nextStep
42 | }
43 | }
44 |
45 | override def clear() = {
46 | for (i <- 0 until accumulators.length) {
47 | accumulators(i).clear
48 | }
49 | }
50 | }
51 |
52 | @SerialVersionUID(100L)
53 | object Telescoping {
54 | trait Opts extends Updater.Opts {
55 |     var factor = 1.5f
56 | }
57 |
58 | class Options extends Opts {}
59 | }
60 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/updaters/Updater.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.updaters
2 |
3 | import BIDMat.{Mat,SBMat,CMat,DMat,FMat,IMat,HMat,GMat,GIMat,GSMat,SMat,SDMat}
4 | import BIDMat.MatFunctions._
5 | import BIDMat.SciFunctions._
6 | import BIDMach.models._
7 |
8 |
9 | abstract class Updater(val opts:Updater.Opts = new Updater.Options) extends Serializable {
10 | var model:Model = null;
11 | var runningtime = 0.0;
12 |
13 | def init(model0:Model) = {
14 | model = model0
15 | }
16 |
17 | def clear():Unit = {}
18 |
19 | def update(ipass:Int, step:Long):Unit = {}
20 |
21 | def update(ipass:Int, step:Long, gprogress:Float):Unit = update(ipass, step)
22 |
23 | def updateM(ipass:Int):Unit = {
24 | model.updatePass(ipass)
25 | }
26 |
27 | def preupdate(ipass:Int, step:Long, gprogress:Float):Unit = {}
28 | }
29 |
30 | @SerialVersionUID(100L)
31 | object Updater {
32 | trait Opts extends BIDMat.Opts {
33 | }
34 |
35 | class Options extends Opts {}
36 | }
37 |
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/viz/LogViz.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.viz;
2 | import BIDMat.{BMat,Mat,SBMat,CMat,DMat,FMat,FFilter,IMat,HMat,GDMat,GFilter,GLMat,GMat,GIMat,GSDMat,GSMat,LMat,SMat,SDMat,TMat}
3 | import BIDMat.MatFunctions._
4 | import BIDMat.SciFunctions._
5 | import BIDMach.models.Model;
6 | import BIDMach.networks.Net;
7 | import BIDMach.networks.layers._;
8 | import BIDMach.Learner;
9 | import scala.collection.mutable.ListBuffer;
10 |
11 | /**
12 |  * Collect and visualize some logged values.
13 |  */
14 |
15 | class LogViz(val name: String = "varName") extends Visualization{
16 | val data:ListBuffer[FMat] = new ListBuffer[FMat];
17 | interval = 1;
18 |
19 | // Override one of these to collect some log data
20 | def collect(model:Model, mats:Array[Mat], ipass:Int, pos:Long):FMat = {
21 | collect(model);
22 | }
23 |
24 | def collect(model:Model):FMat = {
25 | collect();
26 | }
27 |
28 | def collect():FMat = {
29 | row(0);
30 | }
31 |
32 | override def doUpdate(model:Model, mats:Array[Mat], ipass:Int, pos:Long) = {
33 | data.synchronized {
34 | data += FMat(collect(model, mats, ipass, pos));
35 | }
36 | }
37 |
38 | def snapshot = {
39 | Learner.scores2FMat(data);
40 | }
41 |
42 | def fromto(n0:Int, n1:Int) = {
43 | data.synchronized {
44 | val len = data.length;
45 | val na = math.min(n0, len);
46 | val nb = math.min(n1, len);
47 | val out = zeros(data(0).nrows, nb - na);
48 | var i = 0;
49 | data.foreach(f => {
50 | if (i >= na && i < nb) out(?, i - na) = f;
51 | i += 1;
52 | })
53 | out
54 | }
55 | }
56 |
57 | def lastn(n0:Int) = {
58 | val len = data.synchronized {data.length};
59 | fromto(math.max(0, len - n0), len);
60 | }
61 | }
62 |
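As a usage sketch, a hypothetical subclass that logs the L1 norm of the first model matrix on every update (NormLogger is illustrative, not part of the repository; it mirrors the sum(sum(abs(...))) pattern from Logging.scala):

    import BIDMach.viz.LogViz
    import BIDMach.models.Model
    import BIDMat.FMat
    import BIDMat.MatFunctions._
    import BIDMat.SciFunctions._

    class NormLogger extends LogViz("l1norm") {
      // Collect one scalar per update; LogViz appends it to its buffer.
      override def collect(model: Model): FMat = {
        row(sum(sum(abs(model.modelmats(0)))).dv.toFloat)
      }
    }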
--------------------------------------------------------------------------------
/src/main/scala/BIDMach/viz/Visualization.scala:
--------------------------------------------------------------------------------
1 | package BIDMach.viz
2 | import BIDMach.models.Model;
3 | import BIDMat.Mat
4 |
5 |
6 | /**
7 |  * Abstract class for visualizations. Extend this class and override doUpdate to implement a visualization.
8 |  */
9 |
10 | abstract class Visualization {
11 | var interval = 10;
12 | var cnt = 0
13 | var checkStatus = -1
14 |
15 | def doUpdate(model:Model,mats:Array[Mat],ipass:Int, pos:Long)
16 |
17 | //Perform some initial check to make sure data type is correct
18 | def check(model:Model,mats:Array[Mat]):Int = 0
19 |
20 | //Initialize variables and states during the first update.
21 | def init(model:Model,mats:Array[Mat]) {}
22 |
23 | //Update the visualization per cnt batches
24 | def update(model:Model,mats:Array[Mat],ipass:Int, pos:Long){
25 | if (checkStatus == -1){
26 | checkStatus = check(model, mats)
27 | if (checkStatus == 0) init(model, mats)
28 | }
29 | if (checkStatus == 0) {
30 | if (cnt == 0) {
31 | //doUpdate(model, mats, ipass, pos)
32 | try {
33 | doUpdate(model, mats, ipass, pos)
34 | }
35 | catch {
36 | case e:Exception=> {
37 | checkStatus = 2
38 | println(e.toString)
39 | println(e.getStackTrace.mkString("\n"))
40 | }
41 | }
42 | }
43 | cnt = (cnt + 1) % interval
44 | }
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/src/test/scala/BIDMach/BIDMachSpec.scala:
--------------------------------------------------------------------------------
1 | package BIDMach
2 |
3 | import org.scalatest._
4 |
5 | abstract class BIDMachSpec extends FlatSpec
6 | with Matchers
7 | with BeforeAndAfterAll {
8 |
9 | override def beforeAll {
10 | BIDMat.Mat.checkMKL(false);
11 | }
12 |
13 | def assert_approx_eq(a: Array[Float], b: Array[Float], eps: Float = 1e-4f) = {
14 | (a, b).zipped foreach {
15 | case (x, y) => {
16 | val scale = (math.abs(x) + math.abs(y) + eps).toFloat;
17 | x / scale should equal ((y / scale) +- eps)
18 | }
19 | }
20 | }
21 |
22 |   def assert_approx_eq_double(a: Array[Double], b: Array[Double], eps: Double = 1e-6) = {
23 | (a, b).zipped foreach {
24 | case (x, y) => {
25 | val scale = (math.abs(x) + math.abs(y) + eps);
26 | x / scale should equal ((y / scale) +- eps)
27 | }
28 | }
29 | }
30 |
31 | }
32 |
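A tiny illustrative spec using the helper above (hypothetical, not in the repository):

    package BIDMach

    class ApproxEqExampleSpec extends BIDMachSpec {
      "assert_approx_eq" should "accept arrays that differ by less than eps" in {
        assert_approx_eq(Array(1.0f, 2.0f), Array(1.0f, 2.00001f))
      }
    }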
--------------------------------------------------------------------------------