├── LICENSE ├── README.md ├── docs ├── 1.DeepDriver全解密-Introduction_of_ANN_in_DeepDriver.pdf ├── 2.DeepDriver全解密-Introduction_of_DNN_in_DeepDriver.pdf └── 3.DeepDriver全解密-Introduction_of_LSTM_in_DeepDriver.pdf ├── lib ├── JCublas-windows-x86_64.dll ├── JCublas2-windows-x86_64.dll ├── JCudaDriver-windows-x86_64.dll ├── JCudaRuntime-windows-x86_64.dll ├── JCufft-windows-x86_64.dll ├── JCurand-windows-x86_64.dll ├── JCusolver-windows-x86_64.dll ├── JCusparse-windows-x86_64.dll ├── jblas-1.2.4.jar ├── jcublas-0.7.5b.jar └── jcuda-0.7.5b.jar └── src └── deepDriver └── dl └── aml ├── ann ├── ANN.java ├── ANNCfg.java ├── ArtifactNeuroNetwork.java ├── ArtifactNeuroNetworkV2.java ├── BlasANN.java ├── IActivationFunction.java ├── ILayer.java ├── INeuroUnit.java ├── ISparseAutoEncoderCfg.java ├── InputParameters.java ├── Normalizer.java ├── SparseAutoEncoder.java ├── SparseAutoEncoderCfgFromANNCfg.java ├── imp │ ├── BlasLayerImp.java │ ├── LayerImp.java │ ├── LayerImpV2.java │ ├── LogicsticsActivationFunction.java │ ├── NeuroUnitImp.java │ ├── NeuroUnitImpV2.java │ ├── NeuroUnitImpV3.java │ ├── SparseAutoEncoderLayer.java │ ├── SparseAutoEncoderNeuro.java │ └── VONeuro.java └── test │ ├── ANNAreaSplittingTest1.java │ ├── ANNTest1.java │ ├── Test1.java │ └── TimeSeriesTest.java ├── attention └── SoftAttention.java ├── bn └── IBN.java ├── cart ├── Cart.java ├── DataSet.java ├── DataSetV2.java ├── DecisionNode.java ├── DecisionTree.java ├── Gbdt.java ├── GbdtParameter.java ├── RandomForest.java ├── TestCart.java └── TestGbdt.java ├── cnn ├── ActivationFactory.java ├── BlasCNNBpVisitor.java ├── BlasCNNFdVisitor.java ├── BlasCNNWwUpdateVisitor.java ├── CNNArchitecture.java ├── CNNBP.java ├── CNNBP4MaxPooling.java ├── CNNConfigurator.java ├── CNNFV4MaxPooling.java ├── CNNForwardVisitor.java ├── CNNLayer.java ├── CNNLayer2ANNAdapter.java ├── CNNParaMerger.java ├── CNNReconstructionFeatureMap.java ├── CNNReconstructionLayer.java ├── CNNUtils.java ├── CNNWwUpdateVisitor.java ├── 
CacheAbleDataStream.java ├── ConvAeBP.java ├── ConvAeFV.java ├── ConvAeWwUpdater.java ├── ConvAutoEncoder.java ├── ConvDeepNet.java ├── ConvolutionKernal.java ├── ConvolutionNeuroNetwork.java ├── DataMatrix.java ├── FeatureMap.java ├── FeatureMapTag.java ├── FlatAcf.java ├── FractalBlock.java ├── ICNNBP.java ├── ICNNLayer.java ├── ICNNLayerVisitor.java ├── IConvolutionKernal.java ├── IDataMatrix.java ├── IDataStream.java ├── IDataStreamPiples.java ├── IFeatureMap.java ├── IFractalBlock.java ├── LayerConfigurator.java ├── LeakyReLU.java ├── MaxOut.java ├── ReLU.java ├── SamplingFeatureMap.java ├── SamplingLayer.java ├── SamplingReconstructionFeatureMap.java ├── SamplingReconstructionLayer.java ├── SubSamplingKernal.java ├── cae │ └── IConvAutoEncoderLayerVisitor.java ├── distribution │ ├── CNNMaster.java │ ├── CNNSlave.java │ ├── DataStreamDistUtil.java │ └── test │ │ └── TestCNNSlave.java ├── img │ ├── CsvImgLoader.java │ ├── Img2Matrix.java │ ├── ImgDataStream.java │ ├── W2VDataStream.java │ ├── W2VDataStreamV2.java │ ├── W2VDataStreamV24Test.java │ └── W2VDirectStream.java ├── nets │ ├── LeNet.java │ └── VggNet.java ├── test │ ├── DataMetrics.java │ ├── HelloVo.java │ ├── SingleResult.java │ ├── TestHello.java │ ├── TestHwrCNN.java │ ├── TestHwrVGG.java │ ├── TestTxtClassification.java │ ├── TestTxtClassificationV2.java │ └── TxtAnalyzer.java └── txt │ ├── DataPreparation.java │ └── Word2vecUtil.java ├── cnn2lstm ├── CNN2LSTMBPTT.java ├── CNN2LSTMTeacher.java └── test │ └── TestCNN2LSTM.java ├── common ├── CommonArch.java ├── ICommonLayer.java ├── ICommonLayerConfigurator.java ├── ICommonModel.java ├── distribution │ ├── CommonSlave.java │ ├── DistributionMaster.java │ ├── Job.java │ ├── Linkable.java │ └── LinkableDataStream.java └── test │ └── TestCommonSlave.java ├── contrib └── MNIST │ ├── MnistCNN.java │ ├── MnistDataStream.java │ └── MnistLoader.java ├── costFunction ├── CostFunctionFactory.java ├── DummyCostFunction.java ├── ICostFunction.java ├── 
MTLCostFunction.java ├── PositiveTask.java ├── SoftMax4ANN.java └── Task.java ├── distribution ├── AsycMaster.java ├── AsycSlave.java ├── AsycSlaveServeThread.java ├── ClientVo.java ├── CommandFilter.java ├── CommandFilterManager.java ├── DistributionEnvCfg.java ├── Error.java ├── Fs.java ├── ITask.java ├── Master.java ├── P2PBase.java ├── P2PClient.java ├── P2PServer.java ├── ResourceMaster.java ├── Slave.java ├── modelParallel │ ├── PartialCallback.java │ └── ThreadParallel.java └── test │ ├── HelloCnt.java │ ├── HelloWrapper.java │ ├── TestMaster.java │ ├── TestSlave.java │ └── cl │ ├── Client.java │ ├── Employee.java │ └── Server.java ├── dnc ├── DNC.java ├── DNCBPTT.java ├── DNCChecker.java ├── DNCConfigurator.java ├── DNCController.java ├── DNCMemory.java ├── DNCReadHead.java ├── DNCWriteHead.java ├── ITxtStream.java └── test │ └── babi │ ├── BabiStream.java │ ├── FullPBabiStream.java │ ├── Paragraph.java │ └── TestBabi.java ├── dnn ├── DNN.java ├── DNN4Stream.java ├── distribute │ ├── ANNMaster.java │ ├── ANNSlave.java │ ├── DNNDistUtils.java │ ├── DNNMaster.java │ ├── DNNSlave.java │ └── test │ │ └── DNNSlaveTest.java └── test │ ├── StreamAdapter.java │ └── TestDNN4Hwr.java ├── fnn ├── FractalNet.java └── test │ └── TestHwrFractalNet.java ├── linearReg ├── GradientDecentOptimizer.java ├── ISubject2Optimized.java ├── LinearExpression.java ├── LinearFunctionSubject.java ├── LinearRegression.java ├── ParameterScaler.java ├── TestLinearReg.java └── test │ └── TestSelfProduced.java ├── lrate ├── BoldDriverLearningRateManager.java ├── LearningRateManager.java └── StepReductionLR.java ├── lstm ├── BPTT.java ├── BPTT4MultThreads.java ├── BiBPTT.java ├── BiCell.java ├── BiLstmLayer.java ├── BiProjectionLayer2.java ├── BiRNNLayer.java ├── BiRNNNeuroVo.java ├── Context.java ├── ContextLayer.java ├── CxtLeverager.java ├── CxtLeverager4S2sTraining.java ├── GradientNormalizer.java ├── IBPTT.java ├── IBlock.java ├── ICell.java ├── ICxtConsumer.java ├── IForgetGate.java 
├── IInputGate.java ├── ILSTMNeuro.java ├── IOutputGate.java ├── IPreCxtProvider.java ├── IRNNLayer.java ├── IRNNLayerVisitor.java ├── IRNNNeuroVo.java ├── IStream.java ├── ITest.java ├── ITimePeriod.java ├── LSTM.java ├── LSTMCfgCleaner.java ├── LSTMConfigurator.java ├── LSTMDataSet.java ├── LSTMDeltaWwFromWwUpdater.java ├── LSTMDeltaWwUpdater.java ├── LSTMLayer.java ├── LSTMLayerV2.java ├── LSTMWwArrayTranslator.java ├── LSTMWwFresher.java ├── LSTMWwUpdater.java ├── LSTMXwWUpdater.java ├── LayerCfg.java ├── LstmAttention.java ├── NeuroNetworkArchitecture.java ├── PosValue.java ├── PreCxtProvider.java ├── ProjectionLayer.java ├── RNNLayer.java ├── RNNNeuroVo.java ├── RecurrentNeuroNetwork.java ├── Seq2SeqLSTM.java ├── Seq2SeqLSTMConfigurator.java ├── SimpleNeuroVo.java ├── apps │ ├── ner │ │ ├── NerDataLoader.java │ │ ├── NerStream.java │ │ ├── NerTagger.java │ │ └── test │ │ │ └── NerTaggerVerify.java │ ├── pos │ │ ├── PosDataLoader.java │ │ ├── PosStream.java │ │ ├── PosTagger.java │ │ └── test │ │ │ ├── PosTaggerTest.java │ │ │ └── PosTaggerVerify.java │ ├── util │ │ ├── Embedding.java │ │ ├── FeatureFactory.java │ │ ├── GZIPUtil.java │ │ ├── StringUtils.java │ │ └── TaggedWord.java │ └── wordSegmentation │ │ ├── WordSegSet.java │ │ ├── WordSegSetV2.java │ │ ├── WordSegment.java │ │ ├── WordSegmentationStream.java │ │ └── test │ │ ├── EvaluateWsOnFlatDs.java │ │ ├── TestWordSegment.java │ │ └── VerifyWordSegment.java ├── attentionEnDecoder │ ├── AttentionCfg.java │ ├── AttentionEnDecoderBPTT.java │ ├── AttentionEnDecoderLSTM.java │ └── test │ │ ├── AttEn2DeSetup.java │ │ └── TestAttentionEnDecoderQABbSystem.java ├── beamSearch │ ├── BeamLayer.java │ ├── BeamNode.java │ └── BeamSearch.java ├── bidirection │ └── test │ │ └── TestBiLstmWS.java ├── conversation │ ├── Encoder2DecoderConversation.java │ └── Seq2SeqConversation.java ├── data │ ├── CfgDataCleaner.java │ ├── CfgDataTransfer.java │ └── LSTMCfgData.java ├── distribution │ ├── LSTMMaster.java │ ├── 
LSTMSlave.java │ ├── Seq2SeqAsycMasterV6.java │ ├── Seq2SeqAsycSlaveV6.java │ ├── Seq2SeqAsycSlaveV6Thread.java │ ├── Seq2SeqLSTMBoostrapper.java │ ├── Seq2SeqLSTMSetup.java │ ├── Seq2SeqMaster.java │ ├── Seq2SeqMasterV2.java │ ├── Seq2SeqMasterV3.java │ ├── Seq2SeqMasterV5.java │ ├── Seq2SeqSlave.java │ ├── Seq2SeqSlaveV2.java │ ├── Seq2SeqSlaveV3.java │ ├── Seq2SeqSlaveV5.java │ ├── SimpleTask.java │ ├── Verifier.java │ └── test │ │ ├── Heoo.java │ │ ├── Seq2SeqLSTMSetup.java │ │ ├── Test.java │ │ ├── TestLSTMSlave.java │ │ ├── TestMultipleSeq2SeqAsycSlave.java │ │ ├── TestPath.java │ │ ├── TestQAMaster.java │ │ ├── TestQAWorker.java │ │ ├── TestSeq2SeqAsycMaster.java │ │ ├── TestSeq2SeqAsycSlave.java │ │ ├── TestSeq2SeqMaster.java │ │ ├── TestSeq2SeqMaster4Srv48.java │ │ ├── TestSeq2SeqSlave.java │ │ ├── TestSeq2SeqSlave4Srv48.java │ │ └── TestSerializedS2S.java ├── enDecoder │ ├── EncoderDecoderBPTT.java │ ├── EncoderDecoderLSTM.java │ └── test │ │ ├── Encoder2DecoderSetup.java │ │ ├── TestEnDeQA.java │ │ └── TestQAEnDecoderBabySystem.java ├── hred │ └── HredBPTT.java ├── imp │ ├── Block.java │ ├── Cell.java │ ├── ForgetGate.java │ ├── InputGate.java │ ├── LSTMNeuro.java │ ├── OutputGate.java │ ├── TanhAf.java │ └── TimePreriod.java ├── lstm2Ann │ ├── Lstm2AnnBPTT.java │ ├── Lstm2AnnTeacher.java │ └── test │ │ ├── TestLstm2Ann.java │ │ ├── VerifyLstm2Ann.java │ │ └── WordSegWindowStream.java └── test │ ├── Seq2SeqBabySysSetup.java │ ├── Test1.java │ ├── TestQASeq2SeqBabySystem.java │ ├── TestQASeq2SeqBabySystemWithRNN.java │ ├── TestQAe2eSystem.java │ ├── TestQa.java │ ├── TestS2S.java │ ├── TestSongLSTM.java │ ├── TestSongSeq2Seq.java │ ├── TestSongSeq2SeqFull.java │ └── TestTsLSTM.java ├── math ├── BlasMathFunction.java ├── ContentBasedWeighting.java ├── IExp.java ├── IExp4Function.java ├── IMathFunction.java ├── IMatrixExp.java ├── JCudaBlasMathFunction.java ├── LinearExp.java ├── LinearMatrixExp.java ├── LinearRegression.java ├── MathUtil.java ├── 
MathUtil4MThreads.java ├── MathUtilBase.java ├── OnePlusExp.java ├── SigmodExp.java ├── SigmodMatrixExp.java ├── SoftMaxExp.java └── test │ ├── Test.java │ ├── TestCBW.java │ ├── TestCos.java │ ├── TestData.java │ ├── TestIExp4Function.java │ ├── TestJcuBlas.java │ ├── TestMathUtils.java │ ├── TestMatrix.java │ ├── TestNormalDistribution.java │ ├── TestSoftMaxExp.java │ └── TestSummary.java ├── random └── RandomFactory.java ├── resNet ├── ResNet.java └── test │ └── TestResNet.java ├── rn ├── RN4DNN.java ├── RelationConnCostFunction.java ├── RelationObject.java ├── RelationObjectSet.java └── test │ └── TestDrama4RNDNNMTL.java ├── sa └── SA.java ├── stream └── IWordStream.java ├── string ├── ANFixedStreamImpV2.java ├── Dictionary.java ├── NFixedStreamImp.java ├── NFixedStreamImpV2.java ├── RandomQNFixedStreamImpV2.java ├── StreamImp.java ├── ThinRandomANFixedStreamImpV2.java └── ThinRandomQNFixedStreamImpV2.java ├── utils └── AccuracyCaculator.java └── w2v ├── KeyCntPair.java ├── NegtiveSampling.java ├── W2V.java ├── Window4WordSegStream.java └── test ├── TestNegativeSampling.java └── VerifyW2v.java /README.md: -------------------------------------------------------------------------------- 1 | # DeepDriver 2 | This is a deep learning framework project written in Java, including: 3 | - ANN, the feed-forward neural network 4 | - DNN, a deep ANN 5 | - CNN, the ConvNet, including LeNet, VGG, ResNet, FNN, and so on. 6 | - RNN/LSTM, used for handling sequential strings 7 | - Bi-LSTM 8 | - W2V (word2vec) support has also been added 9 | - Encoder-Decoder framework and so on. 10 | - added support for DNC (Differentiable Neural Computer), see DeepDriver/src/deepDriver/dl/aml/dnc/ 11 | - added MTL (Multi-Task Learning) support for ANN/DNN, so it is easy to extend to CNN and other neural networks as well. 12 | - added RN (Relation Network) support. 13 | - added GRL support. 
14 | 15 | This framework provides some examples: 16 | - NLP: Chinese word segmentation 17 | - CNN txt classification 18 | - QA samples 19 | - Babi Testing 20 | -------------------------------------------------------------------------------- /docs/1.DeepDriver全解密-Introduction_of_ANN_in_DeepDriver.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/docs/1.DeepDriver全解密-Introduction_of_ANN_in_DeepDriver.pdf -------------------------------------------------------------------------------- /docs/2.DeepDriver全解密-Introduction_of_DNN_in_DeepDriver.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/docs/2.DeepDriver全解密-Introduction_of_DNN_in_DeepDriver.pdf -------------------------------------------------------------------------------- /docs/3.DeepDriver全解密-Introduction_of_LSTM_in_DeepDriver.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/docs/3.DeepDriver全解密-Introduction_of_LSTM_in_DeepDriver.pdf -------------------------------------------------------------------------------- /lib/JCublas-windows-x86_64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/lib/JCublas-windows-x86_64.dll -------------------------------------------------------------------------------- /lib/JCublas2-windows-x86_64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/lib/JCublas2-windows-x86_64.dll 
-------------------------------------------------------------------------------- /lib/JCudaDriver-windows-x86_64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/lib/JCudaDriver-windows-x86_64.dll -------------------------------------------------------------------------------- /lib/JCudaRuntime-windows-x86_64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/lib/JCudaRuntime-windows-x86_64.dll -------------------------------------------------------------------------------- /lib/JCufft-windows-x86_64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/lib/JCufft-windows-x86_64.dll -------------------------------------------------------------------------------- /lib/JCurand-windows-x86_64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/lib/JCurand-windows-x86_64.dll -------------------------------------------------------------------------------- /lib/JCusolver-windows-x86_64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/lib/JCusolver-windows-x86_64.dll -------------------------------------------------------------------------------- /lib/JCusparse-windows-x86_64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LongJunCai/DeepDriver/9007e10aae183f960f009fa1dc2ab6fb716bd438/lib/JCusparse-windows-x86_64.dll 
/**
 * Serializable bag of settings for an ANN: a drop-out value, a
 * training/testing switch and a thread count.
 *
 * <p>The fields are package-private, so classes in the same package may read
 * them directly as well as through the accessors below.
 */
public class ANNCfg implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Drop-out setting; 0 by default. Presumably a rate in [0, 1] - confirm against the layer code. */
    double dropOut = 0;

    /** Flag telling the network whether it is being run in testing mode; false by default. */
    boolean isTesting;

    /** Number of threads to use; 1 by default. */
    int threadsNum = 1;

    /** @return the configured drop-out value */
    public double getDropOut() {
        return this.dropOut;
    }

    /** @param dropOut the drop-out value to store (not validated) */
    public void setDropOut(double dropOut) {
        this.dropOut = dropOut;
    }

    /** @return {@code true} when the testing flag is set */
    public boolean isTesting() {
        return this.isTesting;
    }

    /** @param isTesting new value of the testing flag */
    public void setTesting(boolean isTesting) {
        this.isTesting = isTesting;
    }

    /** @return the configured thread count */
    public int getThreadsNum() {
        return this.threadsNum;
    }

    /** @param threadsNum the thread count to store (not validated) */
    public void setThreadsNum(int threadsNum) {
        this.threadsNum = threadsNum;
    }
}
/src/deepDriver/dl/aml/ann/BlasANN.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.ann; 2 | 3 | public class BlasANN { 4 | ArtifactNeuroNetwork ann; 5 | double [][] aAs = null; 6 | double [][][] wWs = null; 7 | double [][][] dwWs = null; 8 | public void trainModel(InputParameters parameters) { 9 | int ln = getLayerNum(); 10 | if (aAs == null) { 11 | aAs = new double[ln][]; 12 | wWs = new double[ln][][]; 13 | dwWs = new double[ln][][]; 14 | } 15 | 16 | 17 | } 18 | 19 | public int getLayerNum() { 20 | ILayer layer = ann.getFirstLayer(); 21 | return 0; 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/ann/IActivationFunction.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.ann; 2 | 3 | public interface IActivationFunction { 4 | 5 | public double activate(double x); 6 | 7 | public double deActivate(double x); 8 | 9 | } 10 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/ann/ILayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.ann; 2 | 3 | import java.util.List; 4 | 5 | public interface ILayer { 6 | 7 | public int getPos(); 8 | 9 | public void setPos(int pos); 10 | 11 | public void setNextLayer(ILayer iLayer); 12 | 13 | public void setPreviousLayer(ILayer iLayer); 14 | 15 | public ILayer getNextLayer(); 16 | 17 | public ILayer getPreviousLayer(); 18 | 19 | public void addNeuro(INeuroUnit neuro); 20 | 21 | public List getNeuros(); 22 | 23 | public void buildup(ILayer previousLayer, double [][] input, IActivationFunction acf 24 | , boolean isLastLayer, int neuroCount); 25 | 26 | public void forwardPropagation(double [][] input); 27 | 28 | public void backPropagation(double [][] finalResult, InputParameters parameters); 29 | 30 | public 
/**
 * Settings of the sparsity penalty used by the sparse auto-encoder classes.
 */
public interface ISparseAutoEncoderCfg {

    /** @return the sparsity parameter {@code p} */
    double getP();

    /** @param p the sparsity parameter to store */
    void setP(double p);

    /** @return the weight {@code beta} applied to the sparsity penalty */
    double getBeta();

    /** @param beta the weight to apply to the sparsity penalty */
    void setBeta(double beta);
}
extends ArtifactNeuroNetworkV2 implements Serializable { 8 | 9 | /** 10 | * 11 | */ 12 | private static final long serialVersionUID = 1L; 13 | public SparseAutoEncoder() { 14 | this.aNNCfg = new SparseAutoEncoderCfgFromANNCfg(); 15 | } 16 | 17 | public ILayer createLayer() { 18 | SparseAutoEncoderLayer layer = new SparseAutoEncoderLayer(); 19 | layer.setaNNCfg(aNNCfg); 20 | if (cf != null) { 21 | layer.setCostFunction(cf); 22 | } 23 | return layer; 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/ann/SparseAutoEncoderCfgFromANNCfg.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.ann; 2 | 3 | import java.io.Serializable; 4 | 5 | public class SparseAutoEncoderCfgFromANNCfg extends ANNCfg implements ISparseAutoEncoderCfg, Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | 12 | double p = 0.05; 13 | 14 | double beta = 0.001; 15 | 16 | public double getP() { 17 | return p; 18 | } 19 | 20 | public void setP(double p) { 21 | this.p = p; 22 | } 23 | 24 | public double getBeta() { 25 | return beta; 26 | } 27 | 28 | public void setBeta(double beta) { 29 | this.beta = beta; 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/ann/imp/LogicsticsActivationFunction.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.ann.imp; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.ann.IActivationFunction; 6 | 7 | public class LogicsticsActivationFunction implements IActivationFunction , Serializable { 8 | 9 | private static final long serialVersionUID = -331809861294942272L; 10 | 11 | //y = 1/(1+exp(-z)) 12 | @Override 13 | public double activate(double x) { 14 | return 1.0/(1.0+Math.exp(-x)); 15 | } 16 | 17 | @Override 18 | public 
double deActivate(double x) { 19 | return activate(x) * (1.0 - activate(x)) ; 20 | } 21 | 22 | public static void main(String[] args) { 23 | LogicsticsActivationFunction tanhAf = new LogicsticsActivationFunction(); 24 | System.out.println(tanhAf.activate(-1199999990)); 25 | System.out.println(tanhAf.deActivate(1199999990)); 26 | System.out.println(tanhAf.activate(1199999990)); 27 | System.out.println(tanhAf.deActivate(-1190000)); 28 | System.out.println(tanhAf.deActivate(0)); 29 | System.out.println(tanhAf.activate(0)); 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/ann/imp/NeuroUnitImpV2.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.ann.imp; 2 | 3 | public class NeuroUnitImpV2 extends NeuroUnitImp { 4 | 5 | /** 6 | * 7 | */ 8 | private static final long serialVersionUID = 1L; 9 | LayerImp layer; 10 | public NeuroUnitImpV2(LayerImp layer) { 11 | super(); 12 | this.layer = layer; 13 | } 14 | 15 | protected void initTheta() { 16 | double b = Math.pow(6.0/(double)(layer.getNeuros().size() + 17 | layer.getPreviousLayer().getNeuros().size()), 0.5); 18 | length = 2*b; 19 | min = -b; 20 | max = b; 21 | if (randomize) { 22 | for (int i = 0; i < thetas.length; i++) { 23 | thetas[i] = length * random.nextDouble() 24 | + min; 25 | } 26 | } 27 | } 28 | 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/ann/imp/SparseAutoEncoderLayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.ann.imp; 2 | 3 | import java.io.Serializable; 4 | import java.util.List; 5 | 6 | import deepDriver.dl.aml.ann.INeuroUnit; 7 | import deepDriver.dl.aml.ann.ISparseAutoEncoderCfg; 8 | 9 | public class SparseAutoEncoderLayer extends LayerImpV2 implements Serializable { 10 | 11 | /** 12 | * 13 | */ 14 | private static 
final long serialVersionUID = 1L; 15 | 16 | public ISparseAutoEncoderCfg getSparseAutoEncoderCfg() { 17 | if (getaNNCfg() instanceof ISparseAutoEncoderCfg) { 18 | return (ISparseAutoEncoderCfg) getaNNCfg(); 19 | } 20 | return null; 21 | } 22 | 23 | public NeuroUnitImp createNeuroUnitImp() { 24 | return new SparseAutoEncoderNeuro(this); 25 | } 26 | 27 | int zZIndex = 0; 28 | @Override 29 | public double getStdError(double[][] result) { 30 | //since this is a sparse auto encoder, so assume it is a 3-layer one 31 | ISparseAutoEncoderCfg cfg = getSparseAutoEncoderCfg(); 32 | if (cfg == null || getPreviousLayer().getPreviousLayer() == null) { 33 | return super.getStdError(result); 34 | } 35 | List list = getPreviousLayer().getNeuros(); 36 | double kl = 0; 37 | double p = cfg.getP(); 38 | for (int i = 0; i < list.size(); i++) { 39 | INeuroUnit nu = list.get(i); 40 | 41 | // double [] aAs = nu.getAas(); 42 | // for (int j = 0; j < aAs.length; j++) { 43 | // 44 | // } 45 | double p1 = nu.getAaz(zZIndex); 46 | kl = kl + p * Math.log(p/p1) + (1 - p) * Math.log((1-p)/(1- p1)); 47 | } 48 | return super.getStdError(result) + cfg.getBeta() * kl; 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/ann/imp/SparseAutoEncoderNeuro.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.ann.imp; 2 | 3 | import java.util.List; 4 | 5 | import deepDriver.dl.aml.ann.INeuroUnit; 6 | import deepDriver.dl.aml.ann.ISparseAutoEncoderCfg; 7 | import deepDriver.dl.aml.ann.InputParameters; 8 | 9 | public class SparseAutoEncoderNeuro extends NeuroUnitImpV3 { 10 | 11 | public SparseAutoEncoderNeuro(LayerImp layer) { 12 | super(layer); 13 | } 14 | 15 | /** 16 | * 17 | */ 18 | private static final long serialVersionUID = 1L; 19 | 20 | public SparseAutoEncoderLayer getSparseAutoEncoderLayer() { 21 | return (SparseAutoEncoderLayer) layer; 22 | } 23 | 24 | public 
ISparseAutoEncoderCfg getSparseAutoEncoderCfg() { 25 | return getSparseAutoEncoderLayer().getSparseAutoEncoderCfg(); 26 | } 27 | 28 | @Override 29 | public void backPropagation(List previousNeuros, List nextNeuros, double [][] result, InputParameters parameters) { 30 | super.backPropagation(previousNeuros, nextNeuros, result, parameters); 31 | ISparseAutoEncoderCfg cfg = getSparseAutoEncoderCfg(); 32 | if (cfg == null) { 33 | return ; 34 | } 35 | if (nextNeuros == null) { 36 | } else { 37 | /*****/ 38 | if (layer.getPreviousLayer() == null) { 39 | return ; 40 | } 41 | for (int i = 0; i < deltaZ.length; i++) { 42 | double sumDelta = 0; 43 | double p = cfg.getP(); 44 | double p1 = aas[i]; 45 | sumDelta = - p/p1 + (1- p)/(1 - p1); 46 | deltaZ[i] = deltaZ[i] +cfg.getBeta() * (sumDelta) * activationFunction.deActivate(zzs[i]); 47 | } 48 | } 49 | 50 | } 51 | 52 | } 53 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/ann/imp/VONeuro.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.ann.imp; 2 | 3 | import java.io.Serializable; 4 | import java.util.List; 5 | 6 | import deepDriver.dl.aml.ann.IActivationFunction; 7 | import deepDriver.dl.aml.ann.INeuroUnit; 8 | import deepDriver.dl.aml.ann.InputParameters; 9 | 10 | public class VONeuro implements INeuroUnit ,Serializable { 11 | 12 | private static final long serialVersionUID = 2223264281984161799L; 13 | 14 | double [] results; 15 | public VONeuro(int inputSize) { 16 | results = new double[inputSize]; 17 | } 18 | public void setResult(int index, double result) { 19 | results[index] = result; 20 | } 21 | @Override 22 | public double getAaz(int dataIndex) { 23 | return results[dataIndex]; 24 | } 25 | 26 | @Override 27 | public double get4PropagationPreviousDelta(int dataIndex, 28 | int previouNeuroIndex) { 29 | return 0; 30 | } 31 | 32 | @Override 33 | public void setActivationFunction(IActivationFunction 
/**
 * Serializable holder for one table of data: a matrix of values, a vector of
 * values and an array of labels.
 *
 * <p>NOTE(review): despite the names, {@code dependentVars} is the
 * two-dimensional array and {@code independentVars} the one-dimensional one;
 * presumably the matrix holds the features and the vector the targets -
 * confirm against the callers.
 */
public class DataSet implements Serializable {

    /** Two-dimensional value matrix. */
    double[][] dependentVars;

    /** One-dimensional value vector. */
    double[] independentVars;

    /** Text labels. */
    String[] labels;

    /** @return the stored matrix (same reference, not a copy) */
    public double[][] getDependentVars() {
        return this.dependentVars;
    }

    /** @param dependentVars the matrix to store (kept by reference) */
    public void setDependentVars(double[][] dependentVars) {
        this.dependentVars = dependentVars;
    }

    /** @return the stored vector (same reference, not a copy) */
    public double[] getIndependentVars() {
        return this.independentVars;
    }

    /** @param independentVars the vector to store (kept by reference) */
    public void setIndependentVars(double[] independentVars) {
        this.independentVars = independentVars;
    }

    /** @return the stored labels (same reference, not a copy) */
    public String[] getLabels() {
        return this.labels;
    }

    /** @param labels the labels to store (kept by reference) */
    public void setLabels(String[] labels) {
        this.labels = labels;
    }
}
39 | public DecisionNode getParent() { 40 | return parent; 41 | } 42 | public void setParent(DecisionNode parent) { 43 | this.parent = parent; 44 | } 45 | // public double getLeftY() { 46 | // return leftY; 47 | // } 48 | // public void setLeftY(double leftY) { 49 | // this.leftY = leftY; 50 | // } 51 | // public double getRightY() { 52 | // return rightY; 53 | // } 54 | // public void setRightY(double rightY) { 55 | // this.rightY = rightY; 56 | // } 57 | public int getIndexOfVars() { 58 | return indexOfVars; 59 | } 60 | public void setIndexOfVars(int indexOfVars) { 61 | this.indexOfVars = indexOfVars; 62 | } 63 | public double getDecisionCondition() { 64 | return decisionCondition; 65 | } 66 | public void setDecisionCondition(double decisionCondition) { 67 | this.decisionCondition = decisionCondition; 68 | } 69 | public DecisionNode getLeftNode() { 70 | return leftNode; 71 | } 72 | public void setLeftNode(DecisionNode leftNode) { 73 | this.leftNode = leftNode; 74 | } 75 | public DecisionNode getRightNode() { 76 | return rightNode; 77 | } 78 | public void setRightNode(DecisionNode rightNode) { 79 | this.rightNode = rightNode; 80 | } 81 | 82 | } 83 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cart/DecisionTree.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cart; 2 | 3 | import java.io.Serializable; 4 | 5 | public class DecisionTree implements Serializable { 6 | DecisionNode root; 7 | 8 | public DecisionNode getRoot() { 9 | return root; 10 | } 11 | 12 | public void setRoot(DecisionNode root) { 13 | this.root = root; 14 | } 15 | 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cart/GbdtParameter.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cart; 2 | 3 | import java.io.Serializable; 4 | 5 | public 
class GbdtParameter implements Serializable { 6 | Cart cart; 7 | double r; 8 | double [] currentTrainingInVars; 9 | double [] currentTestInVars; 10 | public Cart getCart() { 11 | return cart; 12 | } 13 | public void setCart(Cart cart) { 14 | this.cart = cart; 15 | } 16 | public double getR() { 17 | return r; 18 | } 19 | public void setR(double r) { 20 | this.r = r; 21 | } 22 | public double[] getCurrentTrainingInVars() { 23 | return currentTrainingInVars; 24 | } 25 | public void setCurrentTrainingInVars(double[] currentTrainingInVars) { 26 | this.currentTrainingInVars = currentTrainingInVars; 27 | } 28 | public double[] getCurrentTestInVars() { 29 | return currentTestInVars; 30 | } 31 | public void setCurrentTestInVars(double[] currentTestInVars) { 32 | this.currentTestInVars = currentTestInVars; 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cart/RandomForest.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cart; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | 7 | public class RandomForest implements Serializable { 8 | List cartList = new ArrayList(); 9 | public void train(DataSet trainingDs, DataSet testDs) { 10 | 11 | } 12 | 13 | // public double [] test(DataSet dataSet) { 14 | // 15 | // } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cart/TestCart.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cart; 2 | 3 | public class TestCart { 4 | 5 | public static void main(String[] args) { 6 | DataSet ds = creatDs(1); 7 | Cart cart = new Cart(); 8 | cart.trainTree(ds); 9 | double [] ys = cart.predict(ds); 10 | for (int i = 0; i < ys.length; i++) { 11 | System.out.println(ds.getLabels()[i]+","+ds.getIndependentVars()[i]+","+ys[i]); 
12 | } 13 | DataSet ds1 = creatDs(0.75); 14 | cart.lookupBestTree(ds1); 15 | } 16 | 17 | public static DataSet creatDs(double coef) { 18 | int cnt = 5; 19 | int columns = 6; 20 | double [] [] vars = new double[cnt][columns]; 21 | double [] inVars = new double[cnt]; 22 | String [] lables = new String[cnt]; 23 | for (int i = 0; i < cnt; i++) { 24 | vars[i] = new double[columns]; 25 | for (int j = 0; j < columns; j++) { 26 | if (j == 0) { 27 | vars[i][j] = i ; 28 | } else { 29 | vars[i][j] = i + j; 30 | } 31 | 32 | } 33 | inVars[i] =coef * i+1 ; 34 | lables[i] = "la"+i; 35 | } 36 | DataSet ds = new DataSet(); 37 | ds.setDependentVars(vars); 38 | ds.setIndependentVars(inVars); 39 | ds.setLabels(lables); 40 | return ds; 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cart/TestGbdt.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cart; 2 | 3 | public class TestGbdt { 4 | 5 | public static void main(String[] args) { 6 | DataSet ds = creatDs(0.9); 7 | DataSet ds1 = creatDs(0.65); 8 | DataSet ds2 = creatDs(0.8); 9 | Gbdt gbdt = new Gbdt(); 10 | gbdt.train(ds, ds1); 11 | double [] ys = gbdt.test(ds2); 12 | for (int i = 0; i < ys.length; i++) { 13 | System.out.println(ds2.getLabels()[i]+","+ds2.getIndependentVars()[i]+","+ys[i]); 14 | } 15 | 16 | // cart.lookupBestTree(ds1); 17 | } 18 | 19 | public static DataSet creatDs(double coef) { 20 | int cnt = 5; 21 | int columns = 6; 22 | double [] [] vars = new double[cnt][columns]; 23 | double [] inVars = new double[cnt]; 24 | String [] lables = new String[cnt]; 25 | for (int i = 0; i < cnt; i++) { 26 | vars[i] = new double[columns]; 27 | for (int j = 0; j < columns; j++) { 28 | if (j == 0) { 29 | vars[i][j] = i ; 30 | } else { 31 | vars[i][j] = i + j; 32 | } 33 | 34 | } 35 | inVars[i] =coef * i+1 ; 36 | lables[i] = "la"+i; 37 | } 38 | DataSet ds = new DataSet(); 39 | 
ds.setDependentVars(vars); 40 | ds.setIndependentVars(inVars); 41 | ds.setLabels(lables); 42 | return ds; 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/ActivationFactory.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | 4 | import deepDriver.dl.aml.ann.IActivationFunction; 5 | import deepDriver.dl.aml.ann.imp.LogicsticsActivationFunction; 6 | import deepDriver.dl.aml.lstm.imp.TanhAf; 7 | 8 | public class ActivationFactory { 9 | 10 | IActivationFunction acf = new LogicsticsActivationFunction(); 11 | 12 | IActivationFunction flatAcf = new FlatAcf(); 13 | IActivationFunction reLU = new ReLU(); 14 | IActivationFunction tanh = new TanhAf(); 15 | 16 | static ActivationFactory af = new ActivationFactory(); 17 | public static ActivationFactory getAf() { 18 | return af; 19 | } 20 | public IActivationFunction getAcf() { 21 | return acf; 22 | } 23 | public void setAcf(IActivationFunction acf) { 24 | this.acf = acf; 25 | } 26 | public IActivationFunction getFlatAcf() { 27 | return flatAcf; 28 | } 29 | public void setFlatAcf(IActivationFunction flatAcf) { 30 | this.flatAcf = flatAcf; 31 | } 32 | public IActivationFunction getReLU() { 33 | return reLU; 34 | } 35 | public void setReLU(IActivationFunction reLU) { 36 | this.reLU = reLU; 37 | } 38 | public IActivationFunction getTanh() { 39 | return tanh; 40 | } 41 | public void setTanh(IActivationFunction tanh) { 42 | this.tanh = tanh; 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/BlasCNNWwUpdateVisitor.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import deepDriver.dl.aml.math.MathUtil; 4 | 5 | public class BlasCNNWwUpdateVisitor implements ICNNLayerVisitor { 6 | 7 | protected CNNBP bp; 8 | 9 | public 
BlasCNNWwUpdateVisitor(CNNBP bp) { 10 | super(); 11 | this.bp = bp; 12 | } 13 | 14 | @Override 15 | public void visitCNNLayer(CNNLayer layer) { 16 | IFeatureMap [] fms = layer.getFeatureMaps(); 17 | for (int i = 0; i < fms.length; i++) { 18 | updateGlobalWws(fms[i]); 19 | } 20 | if (layer.getCkM() != null) { 21 | MathUtil.plus(layer.getCkM(), layer.getDckM(), layer.getCkM()); 22 | 23 | float [][] ckm = layer.getCkM(); 24 | int [][][] ckid = layer.getCkIds(); 25 | for (int i = 0; i < ckm.length; i++) { 26 | for (int j = 0; j < ckm[i].length; j++) { 27 | int[] pt = ckid[i][j]; 28 | IConvolutionKernal [] cks1 = fms[j].getKernals(); 29 | ConvolutionKernal ck = (ConvolutionKernal)cks1[pt[0]]; 30 | ck.wWs[pt[1]][pt[2]] = ckm[i][j]; 31 | } 32 | } 33 | } 34 | } 35 | 36 | private void updateGlobalWws(IFeatureMap fms) { 37 | fms.setGema(fms.getGema() + fms.getDgamma()); 38 | fms.setBeta(fms.getBeta() + fms.getDbeta()); 39 | if (!bp.useGlobalWeight) { 40 | return; 41 | } 42 | fms.setbB(fms.getbB() + fms.getDeltaBb()); 43 | } 44 | 45 | @Override 46 | public void visitPoolingLayer(SamplingLayer layer) { 47 | 48 | } 49 | 50 | @Override 51 | public void visitANNLayer(CNNLayer2ANNAdapter layer) { 52 | 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/CNNArchitecture.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | public class CNNArchitecture { 7 | 8 | List layerCfgs = new ArrayList(); 9 | 10 | public List getLayerCfgs() { 11 | return layerCfgs; 12 | } 13 | 14 | public void setLayerCfgs(List layerCfgs) { 15 | this.layerCfgs = layerCfgs; 16 | } 17 | 18 | public void addLayerCfg(LayerConfigurator lc) { 19 | layerCfgs.add(lc); 20 | } 21 | 22 | } 23 | -------------------------------------------------------------------------------- 
/src/deepDriver/dl/aml/cnn/CNNFV4MaxPooling.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import deepDriver.dl.aml.ann.IActivationFunction; 4 | 5 | public class CNNFV4MaxPooling extends CNNForwardVisitor { 6 | 7 | public void sampling(SubSamplingKernal ck, double [][] ffms, IFeatureMap t2fm, boolean begin, IActivationFunction acf) { 8 | SamplingFeatureMap sfm = (SamplingFeatureMap) t2fm; 9 | FeatureMapTag [][] fmts = sfm.getFmts(); 10 | for (int i = 0; i < t2fm.getFeatures().length; i++) { 11 | for (int j = 0; j < t2fm.getFeatures()[i].length; j++) { 12 | boolean init = false; 13 | double cs = 0; 14 | for (int j2 = 0; j2 < ck.ckRows; j2++) { 15 | for (int k = 0; k < ck.ckColumns; k++) { 16 | int fr = i * ck.ckRows + j2; 17 | int fc = j * ck.ckColumns + k; 18 | /*auto padding 19 | * **/ 20 | if (fr >= ffms.length || fc >= ffms[0].length) { 21 | continue; 22 | }/*auto padding 23 | * **/ 24 | double t = ffms[fr][fc]; 25 | if (!init) { 26 | cs = t; 27 | init = true; 28 | fmts[i][j].r = j2; 29 | fmts[i][j].c = k; 30 | } else { 31 | if (cs < t) { 32 | cs = t; 33 | fmts[i][j].r = j2; 34 | fmts[i][j].c = k; 35 | } 36 | } 37 | } 38 | } 39 | cs = cs * ck.wW; 40 | if (!bp.useGlobalWeight) { 41 | cs = cs + ck.b; 42 | } 43 | // double acs = acf.activate(cs); 44 | if (begin) { 45 | // t2fm.getFeatures()[i][j] = acs; 46 | t2fm.getzZs()[i][j] = cs; 47 | } else { 48 | // t2fm.getFeatures()[i][j] = t2fm.getFeatures()[i][j] + acs; 49 | t2fm.getzZs()[i][j] = t2fm.getzZs()[i][j] + cs; 50 | } 51 | } 52 | } 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/CNNReconstructionFeatureMap.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import deepDriver.dl.aml.ann.IActivationFunction; 4 | 5 | public class CNNReconstructionFeatureMap extends 
FeatureMap { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | 12 | public CNNReconstructionFeatureMap(ICNNLayer currentLayer, IActivationFunction acf, 13 | ICNNLayer previousLayer, int ckRows, int ckColumns, 14 | boolean isFullConnection, int[] previouFeatureMapSeq, int fmIndex) { 15 | super(currentLayer, acf, previousLayer, ckRows, ckColumns, isFullConnection, 16 | previouFeatureMapSeq, fmIndex); 17 | } 18 | 19 | public void resizeFeatures() { 20 | // pfm = cfm - 1 - padding + ck 21 | double [][] featureOfPrevious = previousLayer.getFeatureMaps()[0].getFeatures(); 22 | int padding = 2 * previousLayer.getLc().getPadding(); 23 | //asume step = 1, and no need padding. 24 | int r = - padding + featureOfPrevious.length + ckRows - 1; 25 | int c = - padding + featureOfPrevious[0].length + ckColumns - 1; 26 | initFeatures(r, c); 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/CNNReconstructionLayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import deepDriver.dl.aml.cnn.cae.IConvAutoEncoderLayerVisitor; 4 | 5 | 6 | public class CNNReconstructionLayer extends CNNLayer { 7 | 8 | public CNNReconstructionLayer(LayerConfigurator lc, ICNNLayer previousLayer) { 9 | super(lc, previousLayer); 10 | } 11 | 12 | /** 13 | * 14 | */ 15 | private static final long serialVersionUID = 1L; 16 | 17 | public IFeatureMap createIFeatureMap() { 18 | int r = lc.getCkRows(); 19 | int c = lc.getCkColumns(); 20 | if (lc.getCks() != null) { 21 | r = lc.getCks()[fmIndex][0]; 22 | c = lc.getCks()[fmIndex][1]; 23 | } 24 | return new CNNReconstructionFeatureMap(this, 25 | lc.getAcf() == null? ActivationFactory.getAf().getReLU():lc.getAcf(), previousLayer, 26 | r, c, lc.isFullConnection, 27 | lc.isFullConnection ? 
null : lc.getFeatureMapAllocationMatrix()[fmIndex], fmIndex); 28 | } 29 | 30 | public void accept(ICNNLayerVisitor visitor) { 31 | IConvAutoEncoderLayerVisitor vi = (IConvAutoEncoderLayerVisitor)visitor; 32 | vi.visitCNNReconstructionLayer(this); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/CacheAbleDataStream.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import deepDriver.dl.aml.common.distribution.Linkable; 4 | 5 | public class CacheAbleDataStream implements IDataStream, Linkable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | IDataMatrix[][] idms; 12 | int cnt = 0; 13 | 14 | int pos; 15 | Linkable next; 16 | 17 | public CacheAbleDataStream(int capacity) { 18 | super(); 19 | this.idms = new IDataMatrix[capacity][]; 20 | } 21 | 22 | public int getCnt() { 23 | return cnt; 24 | } 25 | 26 | public void setCnt(int cnt) { 27 | this.cnt = cnt; 28 | } 29 | 30 | public Linkable getNext() { 31 | return next; 32 | } 33 | 34 | public void setNext(Linkable next) { 35 | this.next = next; 36 | } 37 | 38 | public void add(IDataMatrix[] idm) { 39 | idms[cnt ++] = idm; 40 | } 41 | 42 | @Override 43 | public IDataMatrix[] next() { 44 | return idms[pos++]; 45 | } 46 | 47 | @Override 48 | public IDataMatrix[] next(Object pos) { 49 | int ip = ((Integer)pos).intValue(); 50 | return idms[ip]; 51 | } 52 | 53 | @Override 54 | public boolean hasNext() { 55 | return pos < cnt; 56 | } 57 | 58 | @Override 59 | public boolean reset() { 60 | pos = 0; 61 | return true; 62 | } 63 | 64 | @Override 65 | public IDataStream[] splitStream(int segments) { 66 | return null; 67 | } 68 | 69 | @Override 70 | public int splitCnt(int segments) { 71 | return 0; 72 | } 73 | 74 | 75 | } 76 | -------------------------------------------------------------------------------- 
/src/deepDriver/dl/aml/cnn/ConvAeWwUpdater.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import deepDriver.dl.aml.cnn.cae.IConvAutoEncoderLayerVisitor; 4 | 5 | public class ConvAeWwUpdater extends CNNWwUpdateVisitor implements 6 | IConvAutoEncoderLayerVisitor { 7 | 8 | @Override 9 | public void visitCNNReconstructionLayer(CNNReconstructionLayer layer) { 10 | visitCNNLayer(layer); 11 | } 12 | 13 | @Override 14 | public void visitPoolingReconstructionLayer( 15 | SamplingReconstructionLayer layer) { 16 | visitPoolingLayer(layer); 17 | } 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/ConvAutoEncoder.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | 4 | public class ConvAutoEncoder { 5 | 6 | CNNConfigurator cfg; 7 | int iterations; 8 | 9 | public void construct(LayerConfigurator [] lcs) { 10 | for (int i = 0; i < lcs.length; i++) { 11 | LayerConfigurator lc = lcs[i]; 12 | lc.setLast(i == lcs.length - 1); 13 | if (i == 0) { 14 | cfg.getLayers()[i] = createCNNLayer(lc, null); 15 | } else { 16 | cfg.getLayers()[i] = createCNNLayer(lc, cfg.getLayers()[i - 1]); 17 | } 18 | } 19 | } 20 | 21 | public ICNNLayer createCNNLayer(LayerConfigurator lc, ICNNLayer previous) { 22 | ICNNLayer layer = null; 23 | if (LayerConfigurator.CONVOLUTION_LAYER == lc.getType()) { 24 | layer = new CNNLayer(lc, previous); 25 | } else if (LayerConfigurator.POOLING_LAYER == lc.getType()) { 26 | layer = new SamplingLayer(lc, previous); 27 | } else if (LayerConfigurator.ANN_LAYER == lc.getType()) { 28 | layer = new CNNLayer2ANNAdapter(lc, previous); 29 | } else if (LayerConfigurator.CONV_RECONSTRUCTION_LAYER == lc.getType()) { 30 | layer = new CNNReconstructionLayer(lc, previous); 31 | } else if (LayerConfigurator.SAMPLING_RECONSTRUCTION_LAYER == lc.getType()) { 32 
| layer = new SamplingReconstructionLayer(lc, previous); 33 | } 34 | return layer; 35 | } 36 | 37 | public void train(IDataStreamPiples idp, int iterations) { 38 | ConvAeBP convAeBP = new ConvAeBP(this.cfg); 39 | this.iterations = iterations; 40 | double error = 0; 41 | int cnt = 0; 42 | for (int i = 0; i < iterations; i++) { 43 | while (idp.hasNext()) { 44 | IDataMatrix [] ms = idp.next(); 45 | error = error + convAeBP.runTrainEpich(ms, ms[0].getTarget()); 46 | cnt ++; 47 | if (cnt % 4000 == 0) { 48 | System.out.println(""+cnt+" samples, the avg error is "+(error/(double)cnt)); 49 | } 50 | } 51 | } 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/ConvDeepNet.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.util.List; 4 | 5 | 6 | public class ConvDeepNet extends ConvolutionNeuroNetwork { 7 | 8 | CNNArchitecture architecture; 9 | public void construct(CNNArchitecture architecture, CNNConfigurator cfg) { 10 | this.architecture = architecture; 11 | this.cfg = cfg; 12 | } 13 | 14 | public void train(IDataStream is, IDataStream tis) throws Exception { 15 | learnSelf(is); 16 | tuneFine(is); 17 | } 18 | int iterations = 10; 19 | int index = 0; 20 | public void learnSelf(IDataStream is) { 21 | List cfgs = architecture.getLayerCfgs(); 22 | cfg.layers = new ICNNLayer[cfgs.size()]; 23 | 24 | for (int i = 0; i < cfgs.size(); i++) { 25 | index = i; 26 | LayerConfigurator lc = cfgs.get(i); 27 | lc.setLast(i == cfgs.size() - 1); 28 | if (i == 0) { 29 | cfg.layers[i] = createCNNLayer(lc, null); 30 | } else { 31 | // cfg.layers[i] = createCNNLayer(lc, cfg.layers[i - 1]); 32 | ConvAutoEncoder cae = new ConvAutoEncoder(); 33 | cae.construct(new LayerConfigurator[] { 34 | replicate(cfgs.get(i - 1), false), 35 | replicate(cfgs.get(i), false), 36 | replicate(cfgs.get(i - 1), true)}); 37 | cae.train(new 
DataStreamPiples(), iterations); 38 | } 39 | 40 | } 41 | } 42 | 43 | public void forwardCNN() { 44 | 45 | } 46 | 47 | class DataStreamPiples implements IDataStreamPiples { 48 | 49 | @Override 50 | public IDataMatrix[] next() { 51 | forwardCNN(); 52 | 53 | // IDataMatrix[] 54 | return null; 55 | } 56 | 57 | @Override 58 | public boolean hasNext() { 59 | return false; 60 | } 61 | 62 | @Override 63 | public boolean reset() { 64 | return false; 65 | } 66 | 67 | } 68 | 69 | public LayerConfigurator replicate(LayerConfigurator lc, boolean reconstructed) { 70 | return null; 71 | } 72 | 73 | 74 | public void tuneFine(IDataStream is) { 75 | 76 | } 77 | 78 | 79 | 80 | public void test(IDataStream tis) { 81 | 82 | } 83 | 84 | 85 | } 86 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/ConvolutionKernal.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | public class ConvolutionKernal implements IConvolutionKernal, Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | int fmapOfPreviousLayer; 12 | double [][] wWs; 13 | double b; 14 | 15 | double [][] detalwWs; 16 | double deltab; 17 | boolean [][] initDeltaZzs; 18 | boolean initB; 19 | 20 | 21 | public int getFmapOfPreviousLayer() { 22 | return fmapOfPreviousLayer; 23 | } 24 | public void setFmapOfPreviousLayer(int fmapOfPreviousLayer) { 25 | this.fmapOfPreviousLayer = fmapOfPreviousLayer; 26 | } 27 | public boolean[][] getInitDeltaZzs() { 28 | return initDeltaZzs; 29 | } 30 | public void setInitDeltaZzs(boolean[][] initDeltaZzs) { 31 | this.initDeltaZzs = initDeltaZzs; 32 | } 33 | public boolean isInitB() { 34 | return initB; 35 | } 36 | public void setInitB(boolean initB) { 37 | this.initB = initB; 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- 
/src/deepDriver/dl/aml/cnn/DataMatrix.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | 4 | public class DataMatrix implements IDataMatrix { 5 | 6 | /** 7 | * 8 | */ 9 | private static final long serialVersionUID = 1L; 10 | double [] target; 11 | double result; 12 | 13 | double [][] matrix; 14 | 15 | public double[][] getMatrix() { 16 | return matrix; 17 | } 18 | 19 | public void setMatrix(double[][] matrix) { 20 | this.matrix = matrix; 21 | } 22 | 23 | public double[] getTarget() { 24 | return target; 25 | } 26 | 27 | public void setTarget(double[] target) { 28 | this.target = target; 29 | } 30 | 31 | public double getResult() { 32 | return result; 33 | } 34 | 35 | public void setResult(double result) { 36 | this.result = result; 37 | } 38 | 39 | 40 | 41 | } 42 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/FeatureMapTag.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | public class FeatureMapTag implements Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | int r; 12 | int c; 13 | public int getR() { 14 | return r; 15 | } 16 | public void setR(int r) { 17 | this.r = r; 18 | } 19 | public int getC() { 20 | return c; 21 | } 22 | public void setC(int c) { 23 | this.c = c; 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/FlatAcf.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.ann.IActivationFunction; 6 | 7 | public class FlatAcf implements IActivationFunction, Serializable { 8 | 9 | /** 10 | * 11 | */ 12 | private static final long 
serialVersionUID = 1L; 13 | 14 | @Override 15 | public double activate(double x) { 16 | return x; 17 | } 18 | 19 | @Override 20 | public double deActivate(double x) { 21 | return 1; 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/FractalBlock.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | 4 | public class FractalBlock extends CNNLayer implements IFractalBlock { 5 | 6 | /** 7 | * 8 | */ 9 | private static final long serialVersionUID = 1L; 10 | public FractalBlock(LayerConfigurator lc, ICNNLayer previousLayer) { 11 | this(lc, previousLayer, lc.getFblockDepth()); 12 | } 13 | public FractalBlock(LayerConfigurator lc, ICNNLayer previousLayer, int cDepth) { 14 | super(lc, previousLayer); 15 | this.currentDepth = cDepth; 16 | if (cDepth > 1) { 17 | this.resNet = lc.isResNetLayer(); 18 | if (!resNet) { 19 | directLayer = new CNNLayer(lc, previousLayer); 20 | } 21 | 22 | fbs = new FractalBlock[lc.getFblockLayerNum()]; 23 | for (int i = 0; i < fbs.length; i++) { 24 | if (i == 0) { 25 | fbs[i] = new FractalBlock(lc, previousLayer, cDepth - 1); 26 | } else { 27 | fbs[i] = new FractalBlock(lc, fbs[i - 1], cDepth - 1); 28 | } 29 | } 30 | } 31 | } 32 | boolean resNet; 33 | 34 | int currentDepth; 35 | CNNLayer directLayer; 36 | FractalBlock [] fbs; 37 | 38 | public boolean isResNet() { 39 | return resNet; 40 | } 41 | public void setResNet(boolean resNet) { 42 | this.resNet = resNet; 43 | } 44 | public int getCurrentDepth() { 45 | return currentDepth; 46 | } 47 | public void setCurrentDepth(int currentDepth) { 48 | this.currentDepth = currentDepth; 49 | } 50 | public CNNLayer getDirectLayer() { 51 | return directLayer; 52 | } 53 | public void setDirectLayer(CNNLayer directLayer) { 54 | this.directLayer = directLayer; 55 | } 56 | public FractalBlock[] getFbs() { 57 | return fbs; 58 | } 59 | public void 
setFbs(FractalBlock[] fbs) { 60 | this.fbs = fbs; 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/ICNNBP.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | public interface ICNNBP extends ICNNLayerVisitor { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/ICNNLayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.costFunction.ICostFunction; 6 | 7 | public interface ICNNLayer extends Serializable { 8 | 9 | public IFeatureMap[] getFeatureMaps(); 10 | 11 | public double [] featureMaps2Vector(); 12 | 13 | public void accept(ICNNLayerVisitor visitor); 14 | 15 | public ICNNLayer getPreviousLayer(); 16 | 17 | public ICostFunction getCostFunction(); 18 | 19 | public LayerConfigurator getLc(); 20 | 21 | public void setLc(LayerConfigurator lc); 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/ICNNLayerVisitor.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | public interface ICNNLayerVisitor { 4 | 5 | public void visitCNNLayer(CNNLayer layer); 6 | 7 | public void visitPoolingLayer(SamplingLayer layer); 8 | 9 | public void visitANNLayer(CNNLayer2ANNAdapter layer); 10 | 11 | } 12 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/IConvolutionKernal.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | public interface IConvolutionKernal extends Serializable { 6 | 7 | public int 
getFmapOfPreviousLayer(); 8 | 9 | public void setFmapOfPreviousLayer(int fmapOfPreviousLayer); 10 | 11 | } 12 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/IDataMatrix.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | public interface IDataMatrix extends Serializable { 6 | 7 | public double[][] getMatrix(); 8 | 9 | public void setMatrix(double[][] matrix); 10 | 11 | public double[] getTarget(); 12 | 13 | public void setTarget(double[] target); 14 | 15 | public double getResult(); 16 | 17 | public void setResult(double target); 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/IDataStream.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | public interface IDataStream extends Serializable { 6 | 7 | public IDataMatrix [] next(); 8 | 9 | public IDataMatrix [] next(Object pos); 10 | 11 | public boolean hasNext(); 12 | 13 | public boolean reset(); 14 | 15 | public IDataStream [] splitStream(int segments); 16 | 17 | public int splitCnt(int segments); 18 | } 19 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/IDataStreamPiples.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | public interface IDataStreamPiples { 4 | 5 | public IDataMatrix [] next(); 6 | 7 | public boolean hasNext(); 8 | 9 | public boolean reset(); 10 | 11 | } 12 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/IFeatureMap.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 
| 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.ann.IActivationFunction; 6 | 7 | public interface IFeatureMap extends Serializable { 8 | 9 | public double[][] getFeatures(); 10 | 11 | public double[][] getDeltaZzs(); 12 | 13 | public boolean[][] getInitDeltaZzs(); 14 | 15 | public void setInitDeltaZzs(boolean[][] initDeltaZzs); 16 | 17 | public double[][] getzZs(); 18 | 19 | public IActivationFunction getAcf(); 20 | 21 | public IConvolutionKernal[] getKernals(); 22 | 23 | public double getbB(); 24 | 25 | public void setbB(double bB); 26 | 27 | public double getDeltaBb(); 28 | 29 | public void setDeltaBb(double deltaBb); 30 | 31 | public boolean isInitBb(); 32 | 33 | public void setInitBb(boolean initBb); 34 | 35 | public void initData(IDataMatrix dm); 36 | 37 | public void resizeFeatures(); 38 | 39 | public void reset(); 40 | 41 | public double getU(); 42 | 43 | public void setU(double u); 44 | 45 | public double getVar2(); 46 | 47 | public void setVar2(double var2); 48 | 49 | public double getGema(); 50 | 51 | public void setGema(double gema); 52 | 53 | public double getBeta(); 54 | 55 | public void setBeta(double beta); 56 | 57 | public double getE(); 58 | 59 | public void setE(double e); 60 | 61 | public double[][] getoZzs(); 62 | 63 | public void setoZzs(double[][] oZzs); 64 | 65 | public double getDgamma(); 66 | 67 | public void setDgamma(double dgamma); 68 | 69 | public double getDbeta(); 70 | 71 | public void setDbeta(double dbeta); 72 | 73 | public double getSumU(); 74 | 75 | public void setSumU(double sumU); 76 | 77 | public double getSumVar2(); 78 | 79 | public void setSumVar2(double sumVar2); 80 | 81 | public int getSamplesCnt(); 82 | 83 | public void setSamplesCnt(int samplesCnt); 84 | 85 | public Object[][] getLockObjs(); 86 | 87 | public void setLockObjs(Object[][] lockObjs); 88 | 89 | public int[] getfMckIdMap(); 90 | 91 | public void setfMckIdMap(int[] fMckIdMap); 92 | 93 | } 94 | 
-------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/IFractalBlock.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | public interface IFractalBlock extends ICNNLayer { 4 | 5 | public int getCurrentDepth() ; 6 | 7 | public void setCurrentDepth(int currentDepth); 8 | 9 | public CNNLayer getDirectLayer() ; 10 | 11 | public void setDirectLayer(CNNLayer directLayer); 12 | 13 | public FractalBlock[] getFbs(); 14 | 15 | public void setFbs(FractalBlock[] fbs); 16 | 17 | 18 | } 19 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/LeakyReLU.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.ann.IActivationFunction; 6 | 7 | public class LeakyReLU implements IActivationFunction, Serializable { 8 | 9 | /** 10 | * 11 | */ 12 | private static final long serialVersionUID = 1L; 13 | double a = 0.001; 14 | 15 | 16 | public double getA() { 17 | return a; 18 | } 19 | 20 | public void setA(double a) { 21 | this.a = a; 22 | } 23 | 24 | @Override 25 | public double activate(double x) { 26 | if (x < 0) { 27 | return x * a; 28 | } 29 | return x; 30 | } 31 | 32 | @Override 33 | public double deActivate(double x) { 34 | if (x < 0) { 35 | return a; 36 | } 37 | return 1; 38 | } 39 | 40 | } -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/MaxOut.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import deepDriver.dl.aml.ann.IActivationFunction; 4 | 5 | public class MaxOut implements IActivationFunction { 6 | 7 | @Override 8 | public double activate(double x) { 9 | return 0; 10 | } 11 | 12 | @Override 13 | public double 
deActivate(double x) { 14 | return 0; 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/ReLU.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.ann.IActivationFunction; 6 | 7 | public class ReLU implements IActivationFunction, Serializable { 8 | 9 | /** 10 | * 11 | */ 12 | private static final long serialVersionUID = 1L; 13 | 14 | @Override 15 | public double activate(double x) { 16 | if (x < 0) { 17 | return 0; 18 | } 19 | return x; 20 | } 21 | 22 | @Override 23 | public double deActivate(double x) { 24 | if (x < 0) { 25 | return 0; 26 | } 27 | return 1; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/SamplingLayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | public class SamplingLayer extends CNNLayer implements Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | 12 | public SamplingLayer(LayerConfigurator lc, ICNNLayer previousLayer) { 13 | super(lc, previousLayer); 14 | } 15 | 16 | public IFeatureMap createIFeatureMap() { 17 | int r = lc.getCkRows(); 18 | int c = lc.getCkColumns(); 19 | if (lc.getCks() != null) { 20 | r = lc.getCks()[fmIndex][0]; 21 | c = lc.getCks()[fmIndex][1]; 22 | } 23 | return new SamplingFeatureMap(this, 24 | lc.getAcf() == null? ActivationFactory.getAf().getTanh(): lc.getAcf() 25 | // lc.getAcf() == null? ActivationFactory.getAf().getReLU(): lc.getAcf() 26 | 27 | , previousLayer, 28 | r, c, lc.isFullConnection, 29 | lc.isFullConnection ? 
null : lc.getFeatureMapAllocationMatrix()[fmIndex], fmIndex); 30 | } 31 | 32 | @Override 33 | public void accept(ICNNLayerVisitor visitor) { 34 | visitor.visitPoolingLayer(this); 35 | } 36 | 37 | } 38 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/SamplingReconstructionFeatureMap.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import deepDriver.dl.aml.ann.IActivationFunction; 4 | 5 | public class SamplingReconstructionFeatureMap extends SamplingFeatureMap { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | 12 | public SamplingReconstructionFeatureMap(ICNNLayer currentLayer, IActivationFunction acf, 13 | ICNNLayer previousLayer, int ckRows, int ckColumns, 14 | boolean isFullConnection, int[] previouFeatureMapSeq, int fmIndex) { 15 | super(currentLayer, acf, previousLayer, ckRows, ckColumns, isFullConnection, 16 | previouFeatureMapSeq, fmIndex); 17 | } 18 | 19 | public void resizeFeatures() { 20 | double [][] featureOfPrevious = previousLayer.getFeatureMaps()[0].getFeatures(); 21 | //asume step = 1, and no need pending. 
22 | int r = featureOfPrevious.length * ckRows; 23 | int c = featureOfPrevious[0].length * ckColumns; 24 | 25 | fmts = new FeatureMapTag[r][c]; 26 | for (int i = 0; i < fmts.length; i++) { 27 | for (int j = 0; j < fmts[i].length; j++) { 28 | fmts[i][j] = new FeatureMapTag(); 29 | } 30 | } 31 | initFeatures(r, c); 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/SamplingReconstructionLayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import deepDriver.dl.aml.cnn.cae.IConvAutoEncoderLayerVisitor; 4 | 5 | public class SamplingReconstructionLayer extends SamplingLayer { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | 12 | public SamplingReconstructionLayer(LayerConfigurator lc, 13 | ICNNLayer previousLayer) { 14 | super(lc, previousLayer); 15 | } 16 | 17 | public IFeatureMap createIFeatureMap() { 18 | int r = lc.getCkRows(); 19 | int c = lc.getCkColumns(); 20 | if (lc.getCks() != null) { 21 | r = lc.getCks()[fmIndex][0]; 22 | c = lc.getCks()[fmIndex][1]; 23 | } 24 | return new SamplingReconstructionFeatureMap (this, 25 | lc.getAcf() == null? ActivationFactory.getAf().getTanh(): lc.getAcf() 26 | , previousLayer, 27 | r, c, lc.isFullConnection, 28 | lc.isFullConnection ? 
null : lc.getFeatureMapAllocationMatrix()[fmIndex], fmIndex); 29 | } 30 | 31 | public void accept(ICNNLayerVisitor visitor) { 32 | IConvAutoEncoderLayerVisitor vi = (IConvAutoEncoderLayerVisitor)visitor; 33 | vi.visitPoolingReconstructionLayer(this); 34 | } 35 | 36 | } 37 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/SubSamplingKernal.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn; 2 | 3 | import java.io.Serializable; 4 | 5 | public class SubSamplingKernal implements IConvolutionKernal, Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | double wW; 12 | double b; 13 | 14 | double deltawW; 15 | double deltab; 16 | 17 | boolean initwW; 18 | boolean initB; 19 | 20 | int ckRows; 21 | int ckColumns; 22 | 23 | int fmapOfPreviousLayer; 24 | 25 | public int getFmapOfPreviousLayer() { 26 | return fmapOfPreviousLayer; 27 | } 28 | public void setFmapOfPreviousLayer(int fmapOfPreviousLayer) { 29 | this.fmapOfPreviousLayer = fmapOfPreviousLayer; 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/cae/IConvAutoEncoderLayerVisitor.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn.cae; 2 | 3 | import deepDriver.dl.aml.cnn.CNNReconstructionLayer; 4 | import deepDriver.dl.aml.cnn.ICNNLayerVisitor; 5 | import deepDriver.dl.aml.cnn.SamplingReconstructionLayer; 6 | 7 | public interface IConvAutoEncoderLayerVisitor extends ICNNLayerVisitor { 8 | 9 | public void visitCNNReconstructionLayer(CNNReconstructionLayer layer); 10 | 11 | public void visitPoolingReconstructionLayer(SamplingReconstructionLayer layer); 12 | } 13 | -------------------------------------------------------------------------------- 
/src/deepDriver/dl/aml/cnn/distribution/DataStreamDistUtil.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn.distribution; 2 | 3 | import deepDriver.dl.aml.cnn.CacheAbleDataStream; 4 | import deepDriver.dl.aml.cnn.IDataMatrix; 5 | import deepDriver.dl.aml.cnn.IDataStream; 6 | import deepDriver.dl.aml.common.distribution.CommonSlave; 7 | import deepDriver.dl.aml.distribution.ResourceMaster; 8 | 9 | public class DataStreamDistUtil { 10 | 11 | int cnt; 12 | int cap = 4096; 13 | 14 | public int getCap() { 15 | return cap; 16 | } 17 | 18 | public void setCap(int cap) { 19 | this.cap = cap; 20 | } 21 | 22 | public int getCnt() { 23 | return cnt; 24 | } 25 | 26 | public void setCnt(int cnt) { 27 | this.cnt = cnt; 28 | } 29 | 30 | public void distributeDs(IDataStream is, int num) throws Exception { 31 | ResourceMaster rm = ResourceMaster.getInstance(); 32 | is.reset(); 33 | CacheAbleDataStream [] iss = new CacheAbleDataStream[num]; 34 | for (int i = 0; i < iss.length; i++) { 35 | iss[i] = new CacheAbleDataStream(cap); 36 | } 37 | int i = 0; 38 | while (is.hasNext()) { 39 | cnt ++; 40 | IDataMatrix [] idm = is.next(); 41 | CacheAbleDataStream ids = iss[i++]; 42 | ids.add(idm); 43 | if (i > iss.length - 1) { 44 | i = 0; 45 | if (iss[i].getCnt() >= cap) { 46 | rm.distributeCommand(CommonSlave.CTASKPIECE); 47 | rm.distributeObjects(iss); 48 | 49 | iss = new CacheAbleDataStream[num]; 50 | for (int j = 0; j < iss.length; j++) { 51 | iss[j] = new CacheAbleDataStream(cap); 52 | } 53 | } 54 | } 55 | 56 | } 57 | if (iss[0].getCnt() < cap) { 58 | rm.distributeCommand(CommonSlave.CTASKPIECE); 59 | rm.distributeObjects(iss); 60 | } 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/distribution/test/TestCNNSlave.java: -------------------------------------------------------------------------------- 1 | package 
/**
 * Loads the non-empty lines of a CSV file into memory, optionally treating the
 * first non-empty line as a header.
 */
public class CsvImgLoader implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Trimmed data lines, in the order they were loaded. */
    private List<String> imgs = new ArrayList<String>();

    /** The header line captured by {@link #loadImg(String)}, if any. */
    String header;

    /**
     * While {@code true}, the next non-empty line read by
     * {@link #loadImg(String)} is stored as the header instead of as data.
     * The flag is cleared once a header has been consumed and is not reset
     * automatically for later files.
     */
    boolean isHeader = true;

    public boolean isHeader() {
        return isHeader;
    }

    public void setHeader(boolean isHeader) {
        this.isHeader = isHeader;
    }

    /** @return the number of data lines loaded so far */
    public int size() {
        return imgs.size();
    }

    /**
     * @param id zero-based index of the data line
     * @return the trimmed line at that index
     */
    public String get(int id) {
        return imgs.get(id);
    }

    /** @return the live backing list of data lines (not a copy) */
    public List<String> getImgs() {
        return imgs;
    }

    public void setImgs(List<String> imgs) {
        this.imgs = imgs;
    }

    /**
     * Reads {@code file} as UTF-8 and appends every non-blank line, trimmed,
     * to the loaded data. Blank lines are skipped; the first non-blank line is
     * stored as the header when {@link #isHeader()} is {@code true}.
     *
     * @param file path of the file to read
     * @throws Exception if the file cannot be opened or read
     */
    public void loadImg(String file) throws Exception {
        // try-with-resources closes the reader even when readLine() fails.
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                new FileInputStream(new File(file)), "utf-8"))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                String content = line.trim();
                if (content.length() == 0) {
                    continue;
                }
                if (isHeader) {
                    isHeader = false;
                    header = content;
                    continue;
                }
                imgs.add(content);
            }
        }
    }

    /**
     * Appends one line, trimmed, to the loaded data; {@code null} is ignored.
     * Unlike {@link #loadImg(String)}, an empty line is still added and no
     * header handling takes place.
     *
     * @param content the line to add, may be {@code null}
     * @throws Exception declared for compatibility with existing callers
     */
    public void loadSingle(String content) throws Exception {
        if (content != null) {
            imgs.add(content.trim());
        }
    }
}
1; 23 | 24 | public IDataMatrix [] next() { 25 | cnt ++; 26 | double [][] ndata = new double[data.length][]; 27 | DataMatrix dataMatrix = new DataMatrix(); 28 | dataMatrix.setMatrix(ndata); 29 | dataMatrix.setTarget(null); 30 | for (int i = 0; i < data.length; i++) { 31 | ndata[i] = new double[data[i].length]; 32 | for (int j = 0; j < data[i].length; j++) { 33 | ndata[i][j] = (data[i][j] * scaler - omin)/(omax - omin) * 34 | (max - min) + min; 35 | } 36 | } 37 | return new IDataMatrix [] {dataMatrix}; 38 | } 39 | 40 | @Override 41 | public boolean hasNext() { 42 | return cnt < 1; 43 | } 44 | 45 | @Override 46 | public boolean reset() { 47 | cnt = 0; 48 | return true; 49 | } 50 | 51 | @Override 52 | public IDataMatrix [] next(Object pos) { 53 | return null; 54 | } 55 | 56 | @Override 57 | public IDataStream[] splitStream(int segments) { 58 | return null; 59 | } 60 | 61 | @Override 62 | public int splitCnt(int segments) { 63 | // TODO Auto-generated method stub 64 | return 0; 65 | } 66 | 67 | } 68 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/nets/VggNet.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn.nets; 2 | 3 | public class VggNet { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/test/DataMetrics.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn.test; 2 | 3 | public class DataMetrics { 4 | 5 | double max; 6 | double min; 7 | double avg; 8 | double stdErr; 9 | double sum; 10 | int cnt; 11 | 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/test/HelloVo.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn.test; 2 | 3 | import java.io.Serializable; 
4 | 5 | public class HelloVo implements Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | 12 | /** 13 | * 14 | */ 15 | protected int [] type = {1, 2, 3}; 16 | 17 | protected String name = "hello"; 18 | 19 | String wd = "world"; 20 | 21 | public String getName() { 22 | return name; 23 | } 24 | 25 | public void setName(String name) { 26 | this.name = name; 27 | } 28 | 29 | public String getWd() { 30 | return wd; 31 | } 32 | 33 | public void setWd(String wd) { 34 | this.wd = wd; 35 | } 36 | 37 | 38 | } 39 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/test/SingleResult.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn.test; 2 | 3 | public class SingleResult { 4 | 5 | private int label; 6 | private double prob; 7 | public int getLabel() { 8 | return label; 9 | } 10 | public void setLabel(int label) { 11 | this.label = label; 12 | } 13 | public double getProb() { 14 | return prob; 15 | } 16 | public void setProb(double prob) { 17 | this.prob = prob; 18 | } 19 | 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/cnn/test/TestHello.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.cnn.test; 2 | 3 | import java.io.File; 4 | 5 | import deepDriver.dl.aml.distribution.Fs; 6 | 7 | public class TestHello { 8 | 9 | public static void read(File dir) throws Exception { 10 | String mfile = dir.getAbsolutePath()+"\\helloVo.m"; 11 | HelloVo hv = (HelloVo) Fs.readObjFromFile(mfile); 12 | System.out.println("Read from "+mfile+", "+hv.name); 13 | } 14 | 15 | public static void save(File dir) throws Exception { 16 | HelloVo hv = new HelloVo(); 17 | String mfile = dir.getAbsolutePath()+"\\helloVo.m"; 18 | Fs.writeObject2File(mfile, hv); 19 | System.out.println("Save into 
"+mfile); 20 | } 21 | 22 | 23 | public static void main(String[] args) throws Exception { 24 | String sf = System.getProperty("user.dir"); 25 | File dir = new File(sf, "data"); 26 | dir.mkdirs(); 27 | // save(dir); 28 | read(dir); 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/common/CommonArch.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.common; 2 | 3 | public class CommonArch { 4 | 5 | 6 | } 7 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/common/ICommonLayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.common; 2 | 3 | import deepDriver.dl.aml.costFunction.ICostFunction; 4 | 5 | public interface ICommonLayer { 6 | 7 | public void setICostFunction(ICostFunction cf); 8 | 9 | public ICostFunction getICostFunction(); 10 | 11 | 12 | } 13 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/common/ICommonLayerConfigurator.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.common; 2 | 3 | public interface ICommonLayerConfigurator { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/common/ICommonModel.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.common; 2 | 3 | public interface ICommonModel { 4 | 5 | 6 | } 7 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/common/distribution/CommonSlave.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.common.distribution; 2 | 3 | import deepDriver.dl.aml.distribution.Error; 4 | 
/**
 * Thread that looks up a class by name and runs its {@code main(String[])}
 * method with an empty argument array.
 */
public class Job extends Thread {

    /** Fully qualified name of the class whose {@code main} is invoked. */
    String mainClzz;

    /**
     * @param mainClzz fully qualified name of the class declaring {@code main(String[])}
     */
    public Job(String mainClzz) {
        super();
        this.mainClzz = mainClzz;
    }

    /**
     * Loads {@code mainClzz} and invokes its declared {@code main(String[])}.
     * Any failure (class not found, no such method, or an exception thrown by
     * the invoked method) is printed and does not propagate.
     */
    @Override
    public void run() {
        try {
            Class<?> clzz = Class.forName(mainClzz);
            Method entry = clzz.getDeclaredMethod("main", String[].class);
            // A static main needs no receiver; an instance is only created for
            // the unusual case of a non-static method named main.
            Object receiver = java.lang.reflect.Modifier.isStatic(entry.getModifiers())
                    ? null
                    : clzz.getDeclaredConstructor().newInstance();
            // The String[] must be passed as ONE argument. The previous
            // invoke(obj, null) supplied a null argument list, which always
            // failed with "wrong number of arguments".
            entry.invoke(receiver, (Object) new String[0]);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}
/**
 * Integer identifiers for the supported cost-function kinds.
 */
public class CostFunctionFactory {

    /** Identifier for the soft-max cost function. */
    public static final int SOFT_MAX = 1;

    /** Identifier for the regression cost function. */
    public static final int REGRESSION = 2;
}
-------------------------------------------------------------------------------- /src/deepDriver/dl/aml/costFunction/DummyCostFunction.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.costFunction; 2 | 3 | import java.io.Serializable; 4 | import java.util.List; 5 | 6 | 7 | import deepDriver.dl.aml.ann.ILayer; 8 | import deepDriver.dl.aml.ann.INeuroUnit; 9 | import deepDriver.dl.aml.ann.imp.LayerImp; 10 | import deepDriver.dl.aml.ann.imp.NeuroUnitImp; 11 | 12 | public class DummyCostFunction implements ICostFunction, Serializable { 13 | 14 | /** 15 | * 16 | */ 17 | private static final long serialVersionUID = 1L; 18 | 19 | LayerImp layer; 20 | 21 | int zZIndex = 0; 22 | 23 | public int getzZIndex() { 24 | return zZIndex; 25 | } 26 | 27 | public void setzZIndex(int zZIndex) { 28 | this.zZIndex = zZIndex; 29 | } 30 | 31 | @Override 32 | public double [] activate() { 33 | List neuros = layer.getNeuros(); 34 | double [] yt = new double[neuros.size()]; 35 | for (int i = 0; i < neuros.size(); i++) { 36 | NeuroUnitImp nu = (NeuroUnitImp) neuros.get(i); 37 | yt[i] = nu.getAas()[zZIndex]; 38 | } 39 | return yt; 40 | } 41 | 42 | public double caculateStdError() { 43 | return layer.getStdError(new double[][]{target}); 44 | } 45 | 46 | double [] target; 47 | @Override 48 | public void caculateCostError() { 49 | 50 | } 51 | 52 | public LayerImp getLayer() { 53 | return layer; 54 | } 55 | 56 | public void setLayer(ILayer layer) { 57 | this.layer = (LayerImp) layer; 58 | } 59 | 60 | public double[] getTarget() { 61 | return target; 62 | } 63 | 64 | public void setTarget(double[] target) { 65 | this.target = target; 66 | } 67 | 68 | public static void main(String[] args) { 69 | System.out.println(Math.exp(-0.00928597904706061)); 70 | } 71 | 72 | @Override 73 | public double verfiyResult(double[] targets, double[] results) { 74 | return 0; 75 | } 76 | 77 | } 
-------------------------------------------------------------------------------- /src/deepDriver/dl/aml/costFunction/ICostFunction.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.costFunction; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.ann.ILayer; 6 | 7 | public interface ICostFunction extends Serializable { 8 | 9 | public int getzZIndex(); 10 | 11 | public void setzZIndex(int zZIndex); 12 | 13 | public double [] activate(); 14 | 15 | public double caculateStdError(); 16 | 17 | public void caculateCostError(); 18 | 19 | public void setLayer(ILayer layer); 20 | 21 | public void setTarget(double[] target); 22 | 23 | public double verfiyResult(double [] targets, double [] results) ; 24 | 25 | } 26 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/costFunction/PositiveTask.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.costFunction; 2 | 3 | import java.io.Serializable; 4 | 5 | public class PositiveTask extends Task implements Serializable { 6 | private static final long serialVersionUID = 1L; 7 | 8 | public boolean checkRule(double [] target) { 9 | for (int i = 0; i < target.length; i++) { 10 | if (target[i] < 0) { 11 | return false; 12 | } 13 | } 14 | return true; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/costFunction/Task.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.costFunction; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.ann.imp.NeuroUnitImp; 6 | 7 | public class Task implements Serializable { 8 | 9 | /** 10 | * 11 | */ 12 | private static final long serialVersionUID = 1L; 13 | public static int CF_STD = 1; 14 | public static int CF_SOFTMAX = 2; 15 | 16 | // public static 
int SS_NORMAL = 0; 17 | // public static int SS_GRL = 1; 18 | 19 | double grlStatus = 0; 20 | 21 | int neuroLen; 22 | int costType; 23 | int resultLen; 24 | 25 | NeuroUnitImp [] nus; 26 | double [] zZs; 27 | 28 | String name; 29 | 30 | public double getGrlStatus() { 31 | return grlStatus; 32 | } 33 | 34 | public void setGrlStatus(double grlStatus) { 35 | this.grlStatus = grlStatus; 36 | } 37 | 38 | public String getName() { 39 | return name; 40 | } 41 | 42 | public void setName(String name) { 43 | this.name = name; 44 | } 45 | 46 | public int getNeuroLen() { 47 | return neuroLen; 48 | } 49 | 50 | public void setNeuroLen(int neuroLen) { 51 | this.neuroLen = neuroLen; 52 | } 53 | public int getCostType() { 54 | return costType; 55 | } 56 | public void setCostType(int costType) { 57 | this.costType = costType; 58 | } 59 | 60 | public boolean checkRule(double [] target) { 61 | return true; 62 | } 63 | 64 | // public int getResultLen() { 65 | // return resultLen; 66 | // } 67 | // 68 | // public void setResultLen(int resultLen) { 69 | // this.resultLen = resultLen; 70 | // } 71 | 72 | public NeuroUnitImp[] getNus() { 73 | return nus; 74 | } 75 | 76 | public void setNus(NeuroUnitImp[] nus) { 77 | this.nus = nus; 78 | } 79 | 80 | public double[] getzZs() { 81 | return zZs; 82 | } 83 | 84 | public void setzZs(double[] zZs) { 85 | this.zZs = zZs; 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/AsycMaster.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | public abstract class AsycMaster { 7 | protected P2PServer talkServer = new P2PServer(); 8 | protected boolean done = false; 9 | public static String TrainCommand = "-c train"; 10 | public static String TaskCommand = "-c task"; 11 | public static String SubjectCommand = "-c subject"; 12 | 
public static String CollectErrorCommand = "-c collectError"; 13 | public static String CollectSubjectCommand = "-c collectSubject"; 14 | int cnt = 0; 15 | 16 | List slaveThreads = new ArrayList(); 17 | public void train() throws Exception { 18 | talkServer.setup(getClientsNum()); 19 | talkServer.collectState(); 20 | 21 | List clients = talkServer.getClients(); 22 | for (int i = 0; i < getClientsNum(); i++) { 23 | slaveThreads.add(new AsycSlaveServeThread(i, clients.get(i), this)); 24 | } 25 | 26 | talkServer.distributeCommand(TaskCommand); 27 | talkServer.distributeObjects(splitTasks()); 28 | Object obj = getDistributeSubject(); 29 | talkServer.distributeObject(obj); 30 | 31 | for (int i = 0; i < slaveThreads.size(); i++) { 32 | slaveThreads.get(i).start(); 33 | System.out.println(i+" client thread started."); 34 | } 35 | 36 | for (int i = 0; i < slaveThreads.size(); i++) { 37 | slaveThreads.get(i).join(); 38 | } 39 | System.out.println("Master exit"); 40 | } 41 | 42 | public abstract void testOnMaster() throws Exception; 43 | 44 | public abstract int getClientsNum(); 45 | 46 | public abstract Object [] splitTasks(); 47 | 48 | public abstract Object getDistributeSubject(); 49 | 50 | public abstract double caculateErrorLastTime(Object [] objs); 51 | 52 | public abstract void mergeSubject(Object [] objs); 53 | 54 | public abstract boolean isCltSrvSameMode(Object [] objs); 55 | 56 | } 57 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/AsycSlave.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution; 2 | 3 | public abstract class AsycSlave { 4 | protected P2PClient talkClient = new P2PClient(); 5 | 6 | String currentCommand; 7 | public void train() throws Exception { 8 | prepareData(false); 9 | talkClient.setup(); 10 | talkClient.responseReady(); 11 | 12 | String command = talkClient.receiveCommand(); 13 | 
System.out.println("Receive command from server: "+command); 14 | setTask(talkClient.receiveObj()); 15 | setSubject(talkClient.receiveObj()); 16 | if (Master.TaskCommand.equals(command)) { 17 | } else if (Master.SubjectCommand.equals(command)) { 18 | } 19 | while(true) { 20 | long l1 = System.currentTimeMillis(); 21 | trainLocal(); 22 | long l2 = System.currentTimeMillis(); 23 | System.out.println("Training time cost: "+(l2 - l1)); 24 | talkClient.sendObj(getLocalSubject()); 25 | talkClient.sendObj(getError()); 26 | setSubject(talkClient.receiveObj()); 27 | long l3 = System.currentTimeMillis(); 28 | System.out.println("Switch data cost: "+(l3 - l2)); 29 | System.out.println("The threads num per CPU should be: "+(l3 - l2)/(l2 - l1)); 30 | } 31 | } 32 | 33 | public abstract void prepareData(boolean isServer) throws Exception; 34 | 35 | public abstract void setTask(Object obj) throws Exception; 36 | 37 | public abstract void trainLocal() throws Exception; 38 | 39 | public abstract Error getError(); 40 | 41 | public abstract void setSubject(Object obj); 42 | 43 | public abstract Object getLocalSubject(); 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/AsycSlaveServeThread.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution; 2 | 3 | 4 | public class AsycSlaveServeThread extends Thread { 5 | P2PBase p2PBase; 6 | ClientVo cv; 7 | AsycMaster asycMaster; 8 | int index; 9 | 10 | public AsycSlaveServeThread(int index, ClientVo cv, AsycMaster asycMaster) { 11 | super(); 12 | this.index = index; 13 | this.cv = cv; 14 | this.asycMaster = asycMaster; 15 | p2PBase = new P2PBase(cv.getSocket(), cv.getOos(), 16 | cv.getOis()); 17 | } 18 | 19 | @Override 20 | public void run() { 21 | while (true) { 22 | // System.out.println("thread is ready to collect inf from client"); 23 | Object sub = p2PBase.receiveObj(); 24 | Object 
err = p2PBase.receiveObj(); 25 | synchronized (asycMaster) { 26 | System.out.println("Prepare to merge client "+index 27 | +", "+this.p2PBase.socket); 28 | Object [] subs = new Object [] {sub}; 29 | if (sub == null) { 30 | System.out.println("Its sub uploaded is null.."); 31 | } 32 | if (asycMaster.isCltSrvSameMode(subs)) { 33 | asycMaster.mergeSubject(subs); 34 | asycMaster.caculateErrorLastTime(new Object [] {err}); 35 | Error error = (Error) err; 36 | double avgErr = error.getErr()/(double)error.getCnt(); 37 | if (error.isReady()) { 38 | avgErr = error.getErr(); 39 | } 40 | System.out.println("Prepare to merge client "+index 41 | +", run "+error.getCnt()+" samples, with avg error "+avgErr); 42 | } 43 | Object obj = asycMaster.getDistributeSubject(); 44 | p2PBase.sendObj(obj); 45 | if (asycMaster.done) { 46 | try { 47 | asycMaster.testOnMaster(); 48 | } catch (Exception e) { 49 | e.printStackTrace(); 50 | } 51 | break; 52 | } 53 | } 54 | // try { 55 | // this.sleep(0); 56 | // } catch (InterruptedException e) { 57 | // e.printStackTrace(); 58 | // } 59 | } 60 | } 61 | 62 | 63 | } 64 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/ClientVo.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution; 2 | 3 | import java.io.IOException; 4 | import java.io.ObjectInputStream; 5 | import java.io.ObjectOutputStream; 6 | import java.net.Socket; 7 | 8 | public class ClientVo { 9 | Socket socket; 10 | // BufferedReader is; 11 | // PrintWriter os; 12 | ObjectOutputStream oos; 13 | ObjectInputStream ois; 14 | 15 | public ObjectInputStream getOis() { 16 | return ois; 17 | } 18 | 19 | public void setOis(ObjectInputStream ois) { 20 | this.ois = ois; 21 | } 22 | 23 | public void rebuild() { 24 | try { 25 | oos = new ObjectOutputStream(socket.getOutputStream()); 26 | // ois = new ObjectInputStream(socket.getInputStream()); 27 | // oos = new 
/**
 * One link in a chain of command handlers. {@code CommandFilterManager}
 * offers a command to a filter and, when {@link #filtCommand(String)}
 * returns {@code false}, passes it on to {@link #nextCommandFilter()}.
 */
public interface CommandFilter {

    /**
     * Offers a command to this filter.
     *
     * @param command the command text
     * @return {@code false} to have the command passed to the next filter
     */
    boolean filtCommand(String command);

    /** @return the filter that follows this one in the chain */
    CommandFilter nextCommandFilter();

}
/**
 * Process-wide key/value store for distribution settings, for example the
 * server host and port that {@code P2PClient} looks up.
 *
 * <p>Backed by a plain {@link HashMap}; no synchronization is performed.
 */
public class DistributionEnvCfg {
    static DistributionEnvCfg cfg = new DistributionEnvCfg();
    Map<String, Object> envHash = new HashMap<String, Object>();

    private DistributionEnvCfg() {
    }

    /** @return the shared configuration instance */
    public static DistributionEnvCfg getCfg() {
        return cfg;
    }

    /** Replaces the shared configuration instance. */
    public static void setCfg(DistributionEnvCfg cfg) {
        DistributionEnvCfg.cfg = cfg;
    }

    /** @return the value stored under {@code key}, or {@code null} if absent */
    public Object get(String key) {
        return envHash.get(key);
    }

    /**
     * Reads a numeric setting.
     *
     * @param key setting name
     * @return the stored value as an {@code int}, or 0 when the key is absent
     * @throws ClassCastException if the stored value is not a {@link Number}
     */
    public int getInt(String key) {
        Object obj = get(key);
        if (obj == null) {
            return 0;
        }
        // Accept any Number (Integer, Long, Short, ...) rather than only Integer.
        return ((Number) obj).intValue();
    }

    /**
     * Reads a text setting.
     *
     * @return the stored string, or {@code null} when the key is absent
     * @throws ClassCastException if the stored value is not a {@link String}
     */
    public String getString(String key) {
        return (String) envHash.get(key);
    }

    /** Stores {@code obj} under {@code key}, replacing any previous value. */
    public void set(String key, Object obj) {
        envHash.put(key, obj);
    }
}
/**
 * Small helpers for writing a single serialized object to a file and
 * reading it back.
 */
public class Fs {

    String file;

    /**
     * Serializes {@code obj} into {@code file}, replacing any existing content.
     * The file is closed even when serialization fails.
     *
     * @param file path of the file to write
     * @param obj  object to serialize
     * @throws Exception on any I/O or serialization failure
     */
    public static void writeObject2File(String file, Object obj) throws Exception {
        FileOutputStream fos = new FileOutputStream(file);
        try {
            ObjectOutputStream oos = new ObjectOutputStream(fos);
            oos.writeUnshared(obj);
            // Push buffered bytes down before the underlying stream is closed.
            oos.flush();
        } finally {
            fos.close();
        }
    }

    /**
     * Same effect as {@link #writeObject2File(String, Object)}; the path is
     * used exactly as given.
     */
    public static void writeObj2FileWithTs(String file, Object obj) throws Exception {
        writeObject2File(file, obj);
    }

    /**
     * Reads the object stored at the beginning of {@code file}.
     * The file is closed even when deserialization fails.
     *
     * @param file path of the file to read
     * @return the deserialized object
     * @throws Exception on any I/O or deserialization failure
     */
    public static Object readObjFromFile(String file) throws Exception {
        FileInputStream fis = new FileInputStream(file);
        try {
            ObjectInputStream is = new ObjectInputStream(fis);
            return is.readUnshared();
        } finally {
            fis.close();
        }
    }

}
abstract class Master { 4 | protected P2PServer talkServer = new P2PServer(); 5 | protected boolean done = false; 6 | public static String TrainCommand = "-c train"; 7 | public static String TaskCommand = "-c task"; 8 | public static String SubjectCommand = "-c subject"; 9 | public static String CollectErrorCommand = "-c collectError"; 10 | public static String CollectSubjectCommand = "-c collectSubject"; 11 | int cnt = 0; 12 | 13 | boolean setup = false; 14 | public void setup() throws Exception { 15 | talkServer.setup(getClientsNum()); 16 | talkServer.collectState(); 17 | setup = true; 18 | } 19 | 20 | public void train() throws Exception { 21 | if (!setup) { 22 | setup(); 23 | } 24 | while (true) { 25 | talkServer.distributeCommand(TaskCommand); 26 | talkServer.distributeObjects(splitTasks()); 27 | talkServer.distributeCommand(SubjectCommand); 28 | Object obj = getDistributeSubject(); 29 | talkServer.distributeObject(obj); 30 | // Fs.writeObj2FileWithTs("D:\\6.workspace\\ANN\\seq2seqCfg", obj); 31 | talkServer.distributeCommand(TrainCommand); 32 | talkServer.distributeCommand(CollectSubjectCommand); 33 | mergeSubject(talkServer.collectObjs()); 34 | testOnMaster(); 35 | talkServer.distributeCommand(CollectErrorCommand); 36 | caculateErrorLastTime(talkServer.collectObjs()); 37 | if (done) { 38 | testOnMaster(); 39 | break; 40 | } 41 | } 42 | } 43 | 44 | // public int setup() { 45 | // return talkServer.getClients().size(); 46 | // } 47 | public abstract void testOnMaster() throws Exception; 48 | 49 | public abstract int getClientsNum(); 50 | 51 | public abstract Object [] splitTasks(); 52 | 53 | public abstract Object getDistributeSubject(); 54 | // 55 | // public abstract void trainOnSlave(); 56 | 57 | public abstract double caculateErrorLastTime(Object [] objs); 58 | 59 | public abstract void mergeSubject(Object [] objs); 60 | 61 | } 62 | -------------------------------------------------------------------------------- 
/src/deepDriver/dl/aml/distribution/P2PClient.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution; 2 | 3 | import java.io.ObjectInputStream; 4 | import java.io.ObjectOutputStream; 5 | import java.net.Socket; 6 | 7 | public class P2PClient extends P2PBase { 8 | 9 | public void setup() { 10 | try { 11 | int port = DistributionEnvCfg.getCfg().getInt(P2PServer.KEY_SRV_PORT); 12 | String host = DistributionEnvCfg.getCfg().getString(P2PServer.KEY_SRV_HOST); 13 | if (port > 0) { 14 | sport = port; 15 | } 16 | if (host != null) { 17 | master = host; 18 | } 19 | socket=new Socket(master,sport); 20 | socket.setKeepAlive(true); 21 | socket.setSoTimeout(1000 * 60 * 60); 22 | // os = new PrintWriter(socket.getOutputStream()); 23 | // is = new BufferedReader(new InputStreamReader(socket.getInputStream())); 24 | oos = new ObjectOutputStream(socket.getOutputStream()); 25 | // oos = new ObjectOutputStream(socket.getOutputStream()); 26 | // ois = new ObjectInputStream(socket.getInputStream()); 27 | ois = new ObjectInputStream(socket.getInputStream()); 28 | 29 | System.out.println("slave setup, and connect to master"); 30 | } catch (Exception e) { 31 | } 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/Slave.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution; 2 | 3 | public abstract class Slave { 4 | protected P2PClient talkClient = new P2PClient(); 5 | 6 | String currentCommand; 7 | public void train() throws Exception { 8 | talkClient.setup(); 9 | talkClient.responseReady(); 10 | while(true) { 11 | String command = talkClient.receiveCommand(); 12 | System.out.println("Receive command from server: "+command); 13 | if (command == null) { 14 | System.out.println("EORROR OCCURED ON SERVER, EXIT"); 15 | break; 16 | } 17 | else if 
(Master.TaskCommand.equals(command)) { 18 | setTask(talkClient.receiveObj()); 19 | } else if (Master.SubjectCommand.equals(command)) { 20 | setSubject(talkClient.receiveObj()); 21 | } else if (Master.TrainCommand.equals(command)) { 22 | //afraid it may be time out... 23 | trainLocal(); 24 | } else if (Master.CollectSubjectCommand.equals(command)) { 25 | talkClient.sendObj(getLocalSubject()); 26 | } else if (Master.CollectErrorCommand.equals(command)) { 27 | talkClient.sendObj(getError()); 28 | } else { 29 | handleOthers(command); 30 | } 31 | currentCommand = command; 32 | } 33 | } 34 | 35 | public void handleOthers(String command) throws Exception { 36 | 37 | } 38 | 39 | public abstract void setTask(Object obj) throws Exception; 40 | 41 | public abstract void trainLocal() throws Exception; 42 | 43 | public abstract Error getError(); 44 | 45 | public abstract void setSubject(Object obj); 46 | 47 | public abstract Object getLocalSubject(); 48 | 49 | } 50 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/modelParallel/PartialCallback.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution.modelParallel; 2 | 3 | public interface PartialCallback { 4 | public void runPartial(int offset, int runLen); 5 | } -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/modelParallel/ThreadParallel.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution.modelParallel; 2 | 3 | import java.io.Serializable; 4 | 5 | public class ThreadParallel implements Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | 12 | class PartialThread extends Thread { 13 | PartialCallback p; 14 | int offset; 15 | int runLen; 16 | 17 | public PartialThread(PartialCallback p, int offset, 
int runLen) { 18 | super(); 19 | this.p = p; 20 | this.offset = offset; 21 | this.runLen = runLen; 22 | } 23 | 24 | @Override 25 | public void run() { 26 | p.runPartial(offset, runLen); 27 | } 28 | }; 29 | 30 | public void runMutipleThreads(int length, PartialCallback p, int tn) { 31 | int eachPart = length/tn; 32 | PartialThread [] ps = new PartialThread[tn]; 33 | for (int i = 0; i < tn; i++) { 34 | int offset = i * eachPart; 35 | int runLen = eachPart; 36 | if (i == tn - 1) { 37 | runLen = length - i * eachPart; 38 | } 39 | ps[i] = new PartialThread(p, offset, runLen); 40 | ps[i].start(); 41 | } 42 | for (int i = 0; i < ps.length; i++) { 43 | try { 44 | ps[i].join(); 45 | } catch (InterruptedException e) { 46 | e.printStackTrace(); 47 | } 48 | } 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/test/HelloCnt.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution.test; 2 | 3 | import java.io.Serializable; 4 | 5 | public class HelloCnt implements Serializable { 6 | /** 7 | * 8 | */ 9 | transient int [] k = new int[30]; 10 | private int i = 0; 11 | private int j = 0; 12 | 13 | public int[] getK() { 14 | return k; 15 | } 16 | public void setK(int[] k) { 17 | this.k = k; 18 | } 19 | public int getI() { 20 | return i; 21 | } 22 | public void setI(int i) { 23 | this.i = i; 24 | } 25 | public int getJ() { 26 | return j; 27 | } 28 | public void setJ(int j) { 29 | this.j = j; 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/test/HelloWrapper.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution.test; 2 | 3 | import java.io.Serializable; 4 | 5 | public class HelloWrapper implements Serializable { 6 | /** 7 | * 8 | */ 9 | private static final long 
serialVersionUID = 1L; 10 | HelloCnt helloCnt; 11 | 12 | } 13 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/distribution/test/cl/Client.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.distribution.test.cl; 2 | 3 | import java.io.ObjectInputStream; 4 | import java.io.ObjectOutputStream; 5 | import java.net.Socket; 6 | 7 | public class Client { 8 | public void run() throws Exception { 9 | System.out.println("employeeNumber= " 10 | + joe .getEmployeeNumber()); 11 | System.out.println("employeeName= " 12 | + joe .getEmployeeName()); 13 | 14 | Socket socketConnection = new Socket("127.0.0.1", 11111); 15 | 16 | 17 | ObjectOutputStream clientOutputStream = new 18 | ObjectOutputStream(socketConnection.getOutputStream()); 19 | ObjectInputStream clientInputStream = new 20 | ObjectInputStream(socketConnection.getInputStream()); 21 | 22 | clientOutputStream.writeObject(joe); 23 | clientOutputStream.flush(); 24 | 25 | Employee joe2 = (Employee)clientInputStream.readObject(); 26 | System.out.println(joe); 27 | System.out.println(joe2); 28 | System.out.println("employeeNumber= " 29 | + joe2 .getEmployeeNumber()); 30 | System.out.println("employeeName= " 31 | + joe2 .getEmployeeName()); 32 | joe2.setEmployeeName("aaaa"); 33 | joe2.setEmployeeNumber(1); 34 | clientOutputStream.writeObject(joe2); 35 | clientOutputStream.flush(); 36 | 37 | joe2 = (Employee)clientInputStream.readObject(); 38 | System.out.println(joe); 39 | System.out.println(joe2); 40 | System.out.println("employeeNumber= " 41 | + joe2 .getEmployeeNumber()); 42 | System.out.println("employeeName= " 43 | + joe2 .getEmployeeName()); 44 | joe2.setEmployeeName("bbbb"); 45 | joe2.setEmployeeNumber(2); 46 | clientOutputStream.writeObject(joe2); 47 | clientOutputStream.flush(); 48 | 49 | clientOutputStream.close(); 50 | clientInputStream.close(); 51 | } 52 | Employee joe = new Employee(150, "Joe"); 
/**
 * Serializable value holder with a number and a name, used by the socket
 * round-trip test client.
 */
public class Employee implements Serializable {

    private static final long serialVersionUID = 1L;

    private int employeeNumber;
    private String employeeName;

    /**
     * @param num  initial employee number
     * @param name initial employee name
     */
    public Employee(int num, String name) {
        this.employeeNumber = num;
        this.employeeName = name;
    }

    /** @return the current employee number */
    public int getEmployeeNumber() {
        return this.employeeNumber;
    }

    /** @return the current employee name */
    public String getEmployeeName() {
        return this.employeeName;
    }

    /** @param num new employee number */
    public void setEmployeeNumber(int num) {
        this.employeeNumber = num;
    }

    /** @param name new employee name */
    public void setEmployeeName(String name) {
        this.employeeName = name;
    }
}
14 | int rhNum; 15 | int memoryNum; 16 | int memoryLength; 17 | 18 | int maxTime = 100; 19 | 20 | int trainingLoop = 100000; 21 | 22 | int ldecayLoop = 20000; 23 | 24 | public DNCConfigurator(double l, double m, int maxTime, ANN ann, LSTMConfigurator cfg, int yLen, int rhNum, int memoryNum, int memoryLength) { 25 | this.l = l; 26 | this.m = m; 27 | this.yLen = yLen; 28 | this.rhNum = rhNum; 29 | this.memoryNum = memoryNum; 30 | this.memoryLength = memoryLength; 31 | 32 | this.maxTime = maxTime; 33 | 34 | controller = new DNCController(ann, cfg, this); 35 | 36 | memory = new DNCMemory(memoryNum, memoryLength, this); 37 | 38 | readHeads = new DNCReadHead[rhNum]; 39 | for (int i = 0; i < readHeads.length; i++) { 40 | readHeads[i] = new DNCReadHead(this); 41 | } 42 | 43 | writeHead = new DNCWriteHead(this); 44 | } 45 | 46 | public int getMaxTime() { 47 | return maxTime; 48 | } 49 | 50 | public void setMaxTime(int maxTime) { 51 | this.maxTime = maxTime; 52 | } 53 | 54 | double l = 0.001; 55 | double m = 0.1; 56 | 57 | double ml = 0.0001; 58 | 59 | public double getMl() { 60 | return ml; 61 | } 62 | 63 | public void setMl(double ml) { 64 | this.ml = ml; 65 | } 66 | 67 | public double getL() { 68 | return l; 69 | } 70 | 71 | public void setL(double l) { 72 | this.l = l; 73 | } 74 | 75 | public double getM() { 76 | return m; 77 | } 78 | 79 | public void setM(double m) { 80 | this.m = m; 81 | } 82 | 83 | public int getLdecayLoop() { 84 | return ldecayLoop; 85 | } 86 | 87 | public void setLdecayLoop(int ldecayLoop) { 88 | this.ldecayLoop = ldecayLoop; 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/dnc/ITxtStream.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.dnc; 2 | 3 | import deepDriver.dl.aml.lstm.IStream; 4 | 5 | public interface ITxtStream extends IStream { 6 | 7 | public int [] getTargetPos(); 8 | 9 | } 10 | 
-------------------------------------------------------------------------------- /src/deepDriver/dl/aml/dnc/test/babi/FullPBabiStream.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.dnc.test.babi; 2 | 3 | import deepDriver.dl.aml.string.Dictionary; 4 | 5 | public class FullPBabiStream extends BabiStream { 6 | 7 | public FullPBabiStream(Dictionary dic, int t) { 8 | super(dic, t); 9 | } 10 | 11 | public void next(Object pos) { 12 | if (pa == null) { 13 | sampleTT = null; 14 | return; 15 | } 16 | int [] is = pa.getFullTxt(); 17 | int [] a = pa.getFullAnswer(); 18 | tpos = pa.getFullAnswerPos(); 19 | sampleTT = new double[is.length][]; 20 | targetTT = new double[a.length][]; 21 | StringBuffer sb = null; 22 | StringBuffer sb2 = null; 23 | if (out) { 24 | sb = new StringBuffer(); 25 | sb2 = new StringBuffer(); 26 | } 27 | 28 | for (int i = 0; i < a.length; i++) { 29 | int ai = a[i]; 30 | double [] tw = targetTT[i] = new double[targetFeatureNum]; 31 | if (ai >= 1) { 32 | tw[ai - 1] = 1; 33 | } 34 | } 35 | 36 | 37 | for (int j = 0; j < is.length; j++) { 38 | int si = is[j]; 39 | double [] sw = sampleTT[j] = new double[sampleFeatureNum]; 40 | if (si >= 1) { 41 | sw[si - 1] = 1; 42 | } 43 | 44 | } 45 | if (out) { 46 | System.out.println("t:"+sb2.toString()); 47 | System.out.println("s:"+sb.toString()); 48 | } 49 | pa = pa.getNext(); 50 | } 51 | 52 | int [] tpos; 53 | 54 | public int[] getTargetPos() { 55 | return tpos; 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/dnn/DNN4Stream.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.dnn; 2 | 3 | public class DNN4Stream { 4 | 5 | 6 | 7 | } 8 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/dnn/distribute/test/DNNSlaveTest.java: 
-------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.dnn.distribute.test; 2 | 3 | import deepDriver.dl.aml.distribution.DistributionEnvCfg; 4 | import deepDriver.dl.aml.distribution.P2PServer; 5 | import deepDriver.dl.aml.dnn.distribute.DNNSlave; 6 | 7 | public class DNNSlaveTest { 8 | 9 | public static void main(String[] args) throws Exception { 10 | String host = "127.0.0.1"; 11 | if (args != null && args.length >= 1) { 12 | host = args[0]; 13 | } 14 | System.out.println("Connet to Server: "+host); 15 | DistributionEnvCfg.getCfg().set(P2PServer.KEY_SRV_HOST, host); 16 | DistributionEnvCfg.getCfg().set(P2PServer.KEY_SRV_PORT, 8034); 17 | 18 | DNNSlave ds = new DNNSlave(); 19 | ds.train(); 20 | } 21 | 22 | } 23 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/dnn/test/StreamAdapter.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.dnn.test; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | 7 | import deepDriver.dl.aml.ann.InputParameters; 8 | import deepDriver.dl.aml.cnn.ConvolutionNeuroNetwork; 9 | import deepDriver.dl.aml.cnn.IDataMatrix; 10 | import deepDriver.dl.aml.cnn.IDataStream; 11 | 12 | public class StreamAdapter { 13 | 14 | public void loadFromStream(IDataStream is, InputParameters ip) { 15 | List inputList = new ArrayList(); 16 | List resultList = new ArrayList(); 17 | while (is.hasNext()) { 18 | IDataMatrix [] dm = is.next(); 19 | inputList.add(matrix2Vector(dm[ConvolutionNeuroNetwork.MatrixTargetIndex].getMatrix())); 20 | resultList.add(dm[ConvolutionNeuroNetwork.MatrixTargetIndex].getTarget()); 21 | } 22 | double [][] in = new double[inputList.size()][]; 23 | double [][] ta = new double[inputList.size()][]; 24 | for (int i = 0; i < ta.length; i++) { 25 | in[i] = inputList.get(i); 26 | ta[i] = resultList.get(i); 27 | } 28 | ip.setInput(in); 29 | 
ip.setResult2(ta); 30 | inputList.clear(); 31 | resultList.clear(); 32 | } 33 | 34 | public double [] matrix2Vector(double [][] matrix) { 35 | double [] v = new double[matrix.length * matrix[0].length]; 36 | int cnt = 0; 37 | for (int i = 0; i < matrix.length; i++) { 38 | for (int j = 0; j < matrix[i].length; j++) { 39 | v[cnt ++] = matrix[i][j]; 40 | } 41 | } 42 | return v; 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/dnn/test/TestDNN4Hwr.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.dnn.test; 2 | 3 | import java.io.File; 4 | 5 | 6 | import deepDriver.dl.aml.ann.InputParameters; 7 | import deepDriver.dl.aml.cnn.img.CsvImgLoader; 8 | import deepDriver.dl.aml.cnn.img.ImgDataStream; 9 | import deepDriver.dl.aml.costFunction.SoftMax4ANN; 10 | import deepDriver.dl.aml.dnn.DNN; 11 | 12 | public class TestDNN4Hwr { 13 | 14 | public void train(String file, String tfile) throws Exception { 15 | int kLength = 10; 16 | CsvImgLoader imgLoader = new CsvImgLoader(); 17 | imgLoader.loadImg(file); 18 | ImgDataStream is = new ImgDataStream(imgLoader, kLength); 19 | 20 | StreamAdapter streamAdapter = new StreamAdapter(); 21 | InputParameters ip = new InputParameters(); 22 | streamAdapter.loadFromStream(is, ip); 23 | 24 | ip.setIterationNum(3000); 25 | ip.setLamda(0.0000001); 26 | ip.setLayerNum(8); 27 | ip.setNeuros(new int[]{90, 90, 90, 90, 90, 90, 90, kLength}); 28 | 29 | DNN dnn = new DNN(); 30 | dnn.setCf(new SoftMax4ANN()); 31 | dnn.setkLength(kLength); 32 | dnn.getaNNCfg().setDropOut(0); 33 | 34 | dnn.trainModel(ip); 35 | } 36 | 37 | public static void main(String[] args) throws Exception { 38 | TestDNN4Hwr test = new TestDNN4Hwr(); 39 | String sf = "E:\\0.workspace\\4.data\\cnn"; 40 | File fsf = new File(sf); 41 | if (!fsf.exists()) { 42 | sf = System.getProperty("user.dir"); 43 | } 44 | File dir = new File(sf, "data"); 45 
| dir.mkdirs(); 46 | test.train(dir.getAbsolutePath()+"\\kaggleTest\\modelTrain", 47 | dir.getAbsolutePath()+"\\kaggleTest\\modelTest"); 48 | // test.train(String file, String tfile); 49 | 50 | 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/fnn/FractalNet.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.fnn; 2 | 3 | import deepDriver.dl.aml.cnn.ConvolutionNeuroNetwork; 4 | import deepDriver.dl.aml.cnn.FractalBlock; 5 | import deepDriver.dl.aml.cnn.ICNNLayer; 6 | import deepDriver.dl.aml.cnn.LayerConfigurator; 7 | 8 | public class FractalNet extends ConvolutionNeuroNetwork { 9 | 10 | public ICNNLayer createCNNLayer(LayerConfigurator lc, ICNNLayer previous) { 11 | ICNNLayer layer = null; 12 | if (LayerConfigurator.FRACTAL_BLOCK_LAYER == lc.getType()) { 13 | layer = new FractalBlock(lc, previous); 14 | return layer; 15 | } 16 | return super.createCNNLayer(lc, previous); 17 | } 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/linearReg/ISubject2Optimized.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.linearReg; 2 | 3 | public interface ISubject2Optimized { 4 | 5 | public int getThetasNum(); 6 | 7 | public void initSubjectFunction(double [][] xVector, double [] y); 8 | 9 | public double cacluateSubject(double [] thetas); 10 | 11 | public void updateThetas(double [] thetas); 12 | 13 | public double getThetaDecent(int index); 14 | 15 | public double [] getInitTheta(); 16 | 17 | public void setInitTheta(double [] thetas); 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/linearReg/LinearExpression.java: -------------------------------------------------------------------------------- 1 | package 
deepDriver.dl.aml.linearReg; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.cart.DataSet; 6 | 7 | public class LinearExpression implements Serializable { 8 | 9 | private static final long serialVersionUID = 1L; 10 | double [] thetas; 11 | public LinearExpression(double[] thetas) { 12 | super(); 13 | this.thetas = thetas; 14 | } 15 | 16 | public double[] predict(DataSet ds) { 17 | double [][] vars = ds.getDependentVars(); 18 | double [] ys = new double[vars.length]; 19 | for (int i = 0; i < ys.length; i++) { 20 | for (int j = 0; j < thetas.length; j++) { 21 | ys[i] = ys[i] + thetas[j] * vars[i][j]; 22 | } 23 | } 24 | return ys; 25 | } 26 | 27 | public double[] getThetas() { 28 | return thetas; 29 | } 30 | 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/linearReg/LinearRegression.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.linearReg; 2 | 3 | import deepDriver.dl.aml.cart.DataSet; 4 | 5 | public class LinearRegression { 6 | GradientDecentOptimizer gradientDecentOptimizer = 7 | new GradientDecentOptimizer(); 8 | 9 | public LinearExpression fit(DataSet ds) { 10 | LinearFunctionSubject linearFunctionSubject = new LinearFunctionSubject(); 11 | double [] thetas = gradientDecentOptimizer. 
/**
 * Divides feature columns, and later the coefficients learned on them, by a per-column
 * range.
 *
 * <p>The per-column maximum and minimum both start at 0, so the range of a column always
 * includes 0. Values are divided by that range only; nothing is subtracted.
 */
public class ParameterScaler {
    /** Per-column maximum seen by the last {@link #scaleParameters} call (never below 0). */
    double [] maxSet;
    /** Per-column minimum seen by the last {@link #scaleParameters} call (never above 0). */
    double [] minSet;

    /**
     * Divides each coefficient by the range of its column, in place.
     * {@link #scaleParameters} must have been called first.
     *
     * @param thetas coefficients to rescale; modified and returned
     * @return the same {@code thetas} array
     */
    public double [] scaleCoefficients(double [] thetas) {
        for (int k = 0; k < thetas.length; k++) {
            thetas[k] = rescale(thetas[k], k);
        }
        return thetas;
    }

    /**
     * Records the per-column range of {@code xVector} and divides every value by it, in
     * place. The column count is taken from the first row.
     *
     * @param xVector sample rows; modified and returned
     * @return the same {@code xVector} array
     */
    public double [] [] scaleParameters(double[][] xVector) {
        int width = xVector[0].length;
        maxSet = new double[width];
        minSet = new double[width];
        for (double [] row : xVector) {
            for (int col = 0; col < width; col++) {
                double value = row[col];
                if (value > maxSet[col]) {
                    maxSet[col] = value;
                }
                if (value < minSet[col]) {
                    minSet[col] = value;
                }
            }
        }
        for (double [] row : xVector) {
            for (int col = 0; col < width; col++) {
                row[col] = rescale(row[col], col);
            }
        }
        return xVector;
    }

    /** Returns {@code value} divided by the range of column {@code index}, or unchanged when that range is empty. */
    private double rescale(double value, int index) {
        if (maxSet[index] == minSet[index]) {
            return value;
        }
        return value / (maxSet[index] - minSet[index]);
    }

}
= 0; i < cnt; i++) { 14 | vars[i] = new double[columns]; 15 | for (int j = 0; j < columns; j++) { 16 | if (j == 0) { 17 | vars[i][j] = i ; 18 | } else { 19 | vars[i][j] = i + j; 20 | } 21 | 22 | } 23 | inVars[i] =coef * i+1 ; 24 | lables[i] = "la"+i; 25 | } 26 | DataSet ds = new DataSet(); 27 | ds.setDependentVars(vars); 28 | ds.setIndependentVars(inVars); 29 | ds.setLabels(lables); 30 | return ds; 31 | } 32 | 33 | public static void main(String[] args) { 34 | LinearRegression reg = new LinearRegression(); 35 | LinearExpression le = reg.fit(creatDs(1)); 36 | DataSet ds = creatDs(0.75); 37 | double [] ys = le.predict(ds); 38 | for (int i = 0; i < ys.length; i++) { 39 | System.out.println(ds.getLabels()[i]+","+ds.getIndependentVars()[i]+","+ys[i]); 40 | } 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/linearReg/test/TestSelfProduced.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.linearReg.test; 2 | 3 | 4 | import deepDriver.dl.aml.cart.DataSet; 5 | import deepDriver.dl.aml.linearReg.LinearExpression; 6 | import deepDriver.dl.aml.linearReg.LinearRegression; 7 | import deepDriver.dl.aml.utils.AccuracyCaculator; 8 | 9 | public class TestSelfProduced { 10 | public static DataSet creatDs(double coef) { 11 | int cnt = 7; 12 | int columns = 2; 13 | double [] [] vars = new double[cnt][columns]; 14 | double [] inVars = new double[cnt]; 15 | double [] pv = {79.32 ,87.28 ,455.35 ,113.91 ,86.15 ,179.35 ,248.71,127.5,46.00 ,88.15,145.62 ,61.30 ,28.40 ,41.04}; 16 | double [] vv = {145.603105,146.592232,769.615582,214.942849,170.157229,368.989821,525.367893,461.566178,200.951821,405.199137,911.204181,460.704428,235.400016,367.576954}; 17 | String [] lables = new String[cnt]; 18 | for (int i = 0; i < cnt; i++) { 19 | vars[i] = new double[columns]; 20 | vars[i][0] = pv[i]; 21 | vars[i][1] = 1; 22 | inVars[i] = vv[i]; 23 | lables[i] = 
"epi"+i; 24 | } 25 | DataSet ds = new DataSet(); 26 | ds.setDependentVars(vars); 27 | ds.setIndependentVars(inVars); 28 | ds.setLabels(lables); 29 | return ds; 30 | } 31 | 32 | public static void main(String[] args) { 33 | LinearRegression reg = new LinearRegression(); 34 | LinearExpression le = reg.fit(creatDs(1)); 35 | DataSet ds = creatDs(0.75); 36 | double [] ys = le.predict(ds); 37 | for (int i = 0; i < ys.length; i++) { 38 | System.out.println(ds.getLabels()[i]+","+ds.getIndependentVars()[i]+","+ys[i]); 39 | } 40 | double [] ts = le.getThetas(); 41 | for (int i = 0; i < ts.length; i++) { 42 | System.out.println("t"+i+":"+ts[i]); 43 | } 44 | 45 | AccuracyCaculator acc = new AccuracyCaculator(); 46 | double ac = acc.caculateAccuracy(ds.getIndependentVars(), ys); 47 | System.out.println("ac:"+ac); 48 | } 49 | 50 | } 51 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lrate/LearningRateManager.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lrate; 2 | 3 | public interface LearningRateManager { 4 | 5 | public double adjustML(double err, double lrate); 6 | 7 | } 8 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lrate/StepReductionLR.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lrate; 2 | 3 | import java.io.Serializable; 4 | 5 | public class StepReductionLR implements LearningRateManager, Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | int stepsCnt = 40000; 12 | double minLr = 0.0001; 13 | double reductionRate = 0.1; 14 | 15 | int cnt = 0; 16 | @Override 17 | public double adjustML(double err, double lrate) { 18 | cnt ++; 19 | double nl = lrate * reductionRate; 20 | if (cnt % stepsCnt == 0 && nl >= minLr) { 21 | lrate = nl; 22 | } 23 | return lrate; 24 | } 25 | 
public int getStepsCnt() { 26 | return stepsCnt; 27 | } 28 | public void setStepsCnt(int stepsCnt) { 29 | this.stepsCnt = stepsCnt; 30 | } 31 | public double getMinLr() { 32 | return minLr; 33 | } 34 | public void setMinLr(double minLr) { 35 | this.minLr = minLr; 36 | } 37 | public double getReductionRate() { 38 | return reductionRate; 39 | } 40 | public void setReductionRate(double reductionRate) { 41 | this.reductionRate = reductionRate; 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/BiCell.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public class BiCell extends BiRNNNeuroVo implements ICell { 4 | 5 | /** 6 | * 7 | */ 8 | private static final long serialVersionUID = 1L; 9 | ICell cell; 10 | public BiCell(RNNNeuroVo vo) { 11 | super(vo); 12 | cell = (ICell) vo; 13 | } 14 | 15 | @Override 16 | public double[] getSc() { 17 | return cell.getSc(); 18 | } 19 | 20 | @Override 21 | public void setSc(double[] sc) { 22 | cell.setSc(sc); 23 | } 24 | 25 | @Override 26 | public double[] getDeltaSc() { 27 | return cell.getDeltaSc(); 28 | } 29 | 30 | @Override 31 | public void setDeltaSc(double[] deltaSc) { 32 | cell.setDeltaSc(deltaSc); 33 | } 34 | 35 | @Override 36 | public double[] getCZz() { 37 | return cell.getCZz(); 38 | } 39 | 40 | @Override 41 | public void setCZz(double[] scZz) { 42 | cell.setCZz(scZz); 43 | } 44 | 45 | @Override 46 | public double[] getDeltaC() { 47 | return cell.getDeltaC(); 48 | } 49 | 50 | @Override 51 | public void setDeltaC(double[] deltaC) { 52 | cell.setDeltaC(deltaC); 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/BiRNNLayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public class BiRNNLayer extends 
RNNLayer { 4 | private static final long serialVersionUID = 1L; 5 | BiRNNNeuroVo [] vos1; 6 | public BiRNNLayer(int nodeNN, 7 | int t, boolean inHidenLayer, int previousNNN, int nextLayerNN, LayerCfg lc) { 8 | super(nodeNN, t, inHidenLayer, previousNNN, nextLayerNN, lc); 9 | vos1 = new BiRNNNeuroVo[vos0.length]; 10 | for (int i = 0; i < vos1.length; i++) { 11 | vos1[i] = new BiRNNNeuroVo(vos0[i]); 12 | } 13 | } 14 | 15 | public void reverse(int lt) { 16 | for (int i = 0; i < vos1.length; i++) { 17 | vos1[i].reverse(lt); 18 | } 19 | } 20 | 21 | public void reverseBack() { 22 | for (int i = 0; i < vos1.length; i++) { 23 | vos1[i].reverseBack(); 24 | } 25 | } 26 | 27 | public RNNNeuroVo[] getRNNNeuroVos() { 28 | return vos1; 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/BiRNNNeuroVo.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public class BiRNNNeuroVo extends RNNNeuroVo { 4 | /** 5 | * 6 | */ 7 | private static final long serialVersionUID = 1L; 8 | 9 | SimpleNeuroVo [] orignalNvs; 10 | SimpleNeuroVo [] reverseNvs; 11 | 12 | RNNNeuroVo real; 13 | public BiRNNNeuroVo(RNNNeuroVo vo) { 14 | this.real = vo; 15 | this.neuroVos = real.neuroVos; 16 | this.orignalNvs = real.neuroVos; 17 | orignalNvs = neuroVos; 18 | reverseNvs = new SimpleNeuroVo[neuroVos.length]; 19 | } 20 | 21 | public BiRNNNeuroVo(int t, boolean inHidenLayer, int previousNNN, 22 | int LayerNN, int blockNN, int nextLayerNN, LayerCfg lc) { 23 | super(t, inHidenLayer, previousNNN, LayerNN, blockNN, nextLayerNN, lc); 24 | orignalNvs = neuroVos; 25 | reverseNvs = new SimpleNeuroVo[neuroVos.length]; 26 | } 27 | 28 | public void reverse(int lt) { 29 | int cnt = 0; 30 | for (int i = lt - 1; i >= 0; i--) { 31 | reverseNvs[cnt++] = neuroVos[i]; 32 | } 33 | neuroVos = reverseNvs; 34 | } 35 | 36 | public void reverseBack() { 37 | neuroVos = 
orignalNvs; 38 | } 39 | 40 | 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/Context.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | import java.io.Serializable; 4 | 5 | public class Context implements Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | ContextLayer [] contextLayers; 12 | 13 | public ContextLayer[] getContextLayers() { 14 | return contextLayers; 15 | } 16 | 17 | public void setContextLayers(ContextLayer[] contextLayers) { 18 | this.contextLayers = contextLayers; 19 | } 20 | 21 | public Context() { 22 | 23 | } 24 | 25 | double [] preCxtSc = null; 26 | double [] preCxtAa = null; 27 | public Context(double[] preCxtSc, double[] preCxtAa) { 28 | super(); 29 | this.preCxtSc = preCxtSc; 30 | this.preCxtAa = preCxtAa; 31 | } 32 | public double[] getPreCxtSc() { 33 | return preCxtSc; 34 | } 35 | public void setPreCxtSc(double[] preCxtSc) { 36 | this.preCxtSc = preCxtSc; 37 | } 38 | public double[] getPreCxtAa() { 39 | return preCxtAa; 40 | } 41 | public void setPreCxtAa(double[] preCxtAa) { 42 | this.preCxtAa = preCxtAa; 43 | } 44 | 45 | 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/ContextLayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | import java.io.Serializable; 4 | 5 | public class ContextLayer implements Serializable { 6 | 7 | private static final long serialVersionUID = 1L; 8 | double [] preCxtSc = null; 9 | double [] preCxtAa = null; 10 | public ContextLayer(double[] preCxtSc, double[] preCxtAa) { 11 | super(); 12 | this.preCxtSc = preCxtSc; 13 | this.preCxtAa = preCxtAa; 14 | } 15 | public double[] getPreCxtSc() { 16 | return preCxtSc; 17 | } 18 | public void setPreCxtSc(double[] 
preCxtSc) { 19 | this.preCxtSc = preCxtSc; 20 | } 21 | public double[] getPreCxtAa() { 22 | return preCxtAa; 23 | } 24 | public void setPreCxtAa(double[] preCxtAa) { 25 | this.preCxtAa = preCxtAa; 26 | } 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/CxtLeverager.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | 7 | public class CxtLeverager implements IPreCxtProvider, ICxtConsumer, Serializable { 8 | 9 | /** 10 | * 11 | */ 12 | private static final long serialVersionUID = 1L; 13 | 14 | int providerIdx = 0; 15 | 16 | List arr = new ArrayList(); 17 | int consumerIdx = 0; 18 | boolean complete = false; 19 | public void reset() { 20 | providerIdx = 0; 21 | } 22 | 23 | public boolean hasNext() { 24 | if(providerIdx <= arr.size() - 1) { 25 | return true; 26 | } 27 | return false; 28 | } 29 | 30 | public Context next() { 31 | Context ctx = arr.get(providerIdx); 32 | providerIdx ++; 33 | return ctx; 34 | } 35 | 36 | @Override 37 | public void addContext(Context cxt) { 38 | if (!complete) { 39 | arr.add(cxt); 40 | } else { 41 | arr.set(consumerIdx, cxt); 42 | } 43 | 44 | if (requireObj != null) { 45 | synchronized (requireObj) { 46 | requireObj.notify(); 47 | } 48 | requireObj = null; 49 | } 50 | consumerIdx ++; 51 | } 52 | 53 | @Override 54 | public void complete() { 55 | System.out.println("There are "+consumerIdx+" contexts"); 56 | complete = true; 57 | consumerIdx = 0; 58 | } 59 | 60 | @Override 61 | public boolean isCompleted() { 62 | return complete; 63 | } 64 | 65 | Object requireObj; 66 | @Override 67 | public void require(Object obj) { 68 | try { 69 | synchronized (obj) { 70 | obj.wait(); 71 | } 72 | } catch (InterruptedException e) { 73 | e.printStackTrace(); 74 | } 75 | requireObj = obj; 76 | } 77 | 78 | } 79 | 
-------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/CxtLeverager4S2sTraining.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | import java.io.Serializable; 4 | 5 | public class CxtLeverager4S2sTraining implements IPreCxtProvider, ICxtConsumer, Serializable { 6 | 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | 12 | public CxtLeverager4S2sTraining(LSTM qLSTM, LSTM aLSTM) { 13 | super(); 14 | this.qLSTM = qLSTM; 15 | this.aLSTM = aLSTM; 16 | } 17 | 18 | Context currentContext = null; 19 | public void reset() { 20 | } 21 | 22 | public boolean hasNext() { 23 | if(currentContext != null) { 24 | return true; 25 | } 26 | return false; 27 | } 28 | 29 | public Context next() { 30 | Context ctx = currentContext; 31 | currentContext = null; 32 | return ctx; 33 | } 34 | 35 | @Override 36 | public void addContext(Context cxt) { 37 | currentContext = cxt; 38 | } 39 | 40 | @Override 41 | public void complete() { 42 | } 43 | 44 | @Override 45 | public boolean isCompleted() { 46 | return false; 47 | } 48 | 49 | LSTM qLSTM; 50 | LSTM aLSTM; 51 | Object requireObj; 52 | @Override 53 | public void require(Object obj) { 54 | Object pos = aLSTM.is.getPos(); 55 | qLSTM.is.next(pos); 56 | qLSTM.test(qLSTM.is.getSampleTT(), qLSTM.is.getTarget()); 57 | // qLSTM.test(sample, targets); 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/IBPTT.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | 4 | public interface IBPTT extends IRNNLayerVisitor { 5 | 6 | public Context getPreCxts(); 7 | 8 | public void setPreCxts(Context preCxts); 9 | 10 | public Context getHLContext(); 11 | 12 | public double [][] fTT(double [][] sample, boolean test); 13 | 14 | public double 
runEpich(double [][] sample, 15 | double [][] targets); 16 | 17 | public void fTT4RNNLayer(RNNLayer layer); 18 | 19 | public void fTT4RNNLayer(LSTMLayer layer); 20 | 21 | public void bpTT4RNNLayer(RNNLayer layer); 22 | 23 | public void bpTT4RNNLayer(LSTMLayer layer); 24 | 25 | public void fTT4RNNLayer(ProjectionLayer layer); 26 | 27 | public void bpTT4RNNLayer(ProjectionLayer layer); 28 | 29 | // public void fTT4RNNLayer(BiLstmLayer layer); 30 | // 31 | // public void bpTT4RNNLayer(BiLstmLayer layer); 32 | 33 | // public void updateWw4RNNLayer(RNNLayer layer); 34 | // 35 | // public void updateWw4RNNLayer(LSTMLayer layer); 36 | 37 | } 38 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/IBlock.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface IBlock { 4 | 5 | public IInputGate getInputGate(); 6 | 7 | public void setInputGate(IInputGate inputGate); 8 | 9 | public IOutputGate getOutPutGate(); 10 | 11 | public void setOutPutGate(IOutputGate outPutGate); 12 | 13 | public IForgetGate getForgetGate(); 14 | 15 | public void setForgetGate(IForgetGate forgetGate); 16 | 17 | public ICell[] getCells(); 18 | 19 | public void setCells(ICell[] cells); 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/ICell.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface ICell extends IRNNNeuroVo { 4 | 5 | public double [] getSc(); 6 | 7 | public void setSc(double [] sc); 8 | 9 | public double[] getDeltaSc(); 10 | 11 | public void setDeltaSc(double[] deltaSc); 12 | 13 | public double[] getCZz(); 14 | 15 | public void setCZz(double[] scZz); 16 | 17 | public double[] getDeltaC(); 18 | 19 | public void setDeltaC(double[] deltaC); 20 | 21 | } 22 | 
-------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/ICxtConsumer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface ICxtConsumer { 4 | 5 | public void addContext(Context cxt); 6 | 7 | public void complete(); 8 | 9 | 10 | } 11 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/IForgetGate.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface IForgetGate extends IRNNNeuroVo { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/IInputGate.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface IInputGate extends IRNNNeuroVo { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/ILSTMNeuro.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface ILSTMNeuro { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/IOutputGate.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface IOutputGate extends IRNNNeuroVo { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/IPreCxtProvider.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface IPreCxtProvider { 4 | 5 | public void reset(); 6 | 7 | public boolean hasNext(); 8 | 9 | public 
Context next(); 10 | 11 | public boolean isCompleted(); 12 | 13 | public void require(Object obj); 14 | 15 | } 16 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/IRNNLayer.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface IRNNLayer { 4 | 5 | public void fTT(IBPTT bptt); 6 | 7 | public void bpTT(IBPTT bptt); 8 | 9 | public RNNNeuroVo [] getRNNNeuroVos(); 10 | 11 | public void setRNNNeuroVos(RNNNeuroVo [] rnnvos); 12 | 13 | public void updateWw(IRNNLayerVisitor visitor); 14 | 15 | public LayerCfg getLc(); 16 | 17 | public void setLc(LayerCfg lc); 18 | 19 | // public ICell [] getCells(); 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/IRNNLayerVisitor.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface IRNNLayerVisitor { 4 | 5 | public void updateWw4RNNLayer(RNNLayer layer); 6 | 7 | public void updateWw4RNNLayer(LSTMLayer layer); 8 | 9 | public void updateWw4RNNLayer(ProjectionLayer layer); 10 | 11 | // public void updateWw4RNNLayer(BiLstmLayer layer); 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/IRNNNeuroVo.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm; 2 | 3 | public interface IRNNNeuroVo { 4 | 5 | public double[] getLwWs(); 6 | 7 | public void setLwWs(double[] lwWs); 8 | 9 | public double[] getDeltaLwWs(); 10 | 11 | public void setDeltaLwWs(double[] deltaLwWs); 12 | 13 | public double[] getwWs() ; 14 | 15 | public void setwWs(double[] wWs); 16 | 17 | public double[] getRwWs(); 18 | 19 | public void setRwWs(double[] rwWs); 20 | 21 | public SimpleNeuroVo[] getNvTT(); 22 | 23 | public 
/**
 * Cursor-style source of samples and targets, each exposed as a two-dimensional array.
 */
public interface IStream {

    /** Resets the stream; presumably rewinds it to its first element - confirm with implementations. */
    void reset();

    /** @return whether another element is available */
    boolean hasNext();

    /** Advances to the next element. */
    void next();

    /** @return the sample of the current element */
    double [][] getSampleTT();

    /** @return the target of the current element */
    double [][] getTarget();

    /** @return the length reported for a sample */
    int getSampleTTLength();

    /** @return the number of features in a sample */
    int getSampleFeatureNum();

    /** @return the number of features in a target */
    int getTargetFeatureNum();

    /** @return an opaque position token that can be passed to {@link #next(Object)} */
    Object getPos();

    /**
     * Advances using a position token.
     *
     * @param pos token obtained from {@link #getPos()}, possibly of another stream
     */
    void next(Object pos);

    /**
     * @param cnt requested number of parts
     * @return the streams this stream was split into
     */
    IStream[] splitStream(int cnt);

    /**
     * @param cnt requested number of parts
     * @return a count derived from {@code cnt}; its exact meaning is set by the implementation
     */
    int splitCnt(int cnt);

}
/**
 * Holder of a single {@code double} period value.
 */
public interface ITimePeriod {

    /** @return the period */
    double getPeriod();

    /** @param period the period to store */
    void setPeriod(double period);

}
/**
 * Variant of {@link LSTMLayer} that builds one {@link Block} per node and exposes the
 * cells of all blocks as one flat array.
 */
public class LSTMLayerV2 extends LSTMLayer implements IRNNLayer, Serializable {

    private static final long serialVersionUID = 1L;
    // Hides the field of the same name in LSTMLayer. The superclass constructor still
    // fills its own array; the accessors below read and write this one.
    Block [] blocks;
    // Cells of all blocks, collected in block order.
    Cell [] cells;
    // Number of cells taken into account per block when sizing the flat array.
    int cellN = 1;

    /**
     * Creates {@code nodeNN} blocks and gathers their cells.
     * All arguments are also passed to the superclass constructor, unchanged.
     */
    public LSTMLayerV2(int nodeNN, int t, boolean inHidenLayer, int previousNNN, int nextLayerNN, LayerCfg lc) {
        super(nodeNN, t, inHidenLayer, previousNNN, nextLayerNN, lc);
        blocks = new Block[nodeNN];
        cells = new Cell[nodeNN * cellN];
        int cnt = 0;
        for (int i = 0; i < blocks.length; i++) {
            // Every block receives the total cell count of the layer and the per-block cell count.
            blocks[i] = new Block(cells.length, cellN, t, inHidenLayer, previousNNN, nextLayerNN, lc);
            // The cast assumes the block hands back a Cell[]; anything else fails at run time.
            Cell [] bcs = (Cell[]) blocks[i].getCells();
            for (int j = 0; j < bcs.length; j++) {
                cells[cnt ++] = bcs[j];
            }
        }
    }

    /** Returns the flat cell array as the neurons of this layer. */
    @Override
    public RNNNeuroVo[] getRNNNeuroVos() {
        return cells;
    }

    /** Returns the flat cell array. */
    public ICell [] getCells() {
        return cells;
    }

    /**
     * Hands this layer to the forward pass of {@code bptt}. The overload is selected at
     * compile time from the static type of {@code this}.
     */
    @Override
    public void fTT(IBPTT bptt) {
        bptt.fTT4RNNLayer(this);
    }

    /** Hands this layer to the backward pass of {@code bptt}. */
    @Override
    public void bpTT(IBPTT bptt) {
        bptt.bpTT4RNNLayer(this);
    }

    /** Returns the per-node block array declared in this class. */
    public Block[] getBlocks() {
        return blocks;
    }

    /** Replaces the block array declared in this class; {@code cells} is not rebuilt. */
    public void setBlocks(Block[] blocks) {
        this.blocks = blocks;
    }

    /** Hands this layer to the weight update of {@code bptt}. */
    @Override
    public void updateWw(IRNNLayerVisitor bptt) {
        bptt.updateWw4RNNLayer(this);
    }

    /** Ignores its argument. */
    @Override
    public void setRNNNeuroVos(RNNNeuroVo[] rnnvos) {
    }

}
/**
 * Serializable per-layer configuration; currently a single integer setting.
 */
public class LayerCfg implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Attention length; 0 until set. */
    int attentionLength;

    /** @param attentionLength the attention length to store */
    public void setAttentionLength(int attentionLength) {
        this.attentionLength = attentionLength;
    }

    /** @return the stored attention length */
    public int getAttentionLength() {
        return this.attentionLength;
    }

}
| --------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/PosValue.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm;

/**
 * Mutable holder for an integer position and a double value.
 */
public class PosValue {

    int pos;
    double value;

    /** @return the stored position */
    public int getPos() {
        return pos;
    }

    /** @param pos the position to store */
    public void setPos(int pos) {
        this.pos = pos;
    }

    /** @return the stored value */
    public double getValue() {
        return value;
    }

    /** @param value the value to store */
    public void setValue(double value) {
        this.value = value;
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/PreCxtProvider.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm;

/**
 * Iterator-like source of {@code double[]} vectors.
 * NOTE(review): the name suggests these are "previous context" vectors for an
 * LSTM; the exact contract of each method is defined by the implementations,
 * which are not visible here.
 */
public interface PreCxtProvider {

    /** Resets the provider; presumably rewinds iteration (confirm in implementations). */
    public void reset();

    /** @return whether another vector can be fetched with {@link #next()} */
    public boolean hasNext();

    /** @return the next vector */
    public double [] next();

    /** @return completion flag; semantics are implementation-defined */
    public boolean isCompleted();

    /** Passes an implementation-defined request object to the provider. */
    public void require(Object obj);

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/RNNLayer.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm;

import java.io.Serializable;


/**
 * A layer made of {@link RNNNeuroVo} units that dispatches forward, backward and
 * weight-update passes to a visitor.
 */
public class RNNLayer implements IRNNLayer, Serializable {
    /**
     *
     */
    private static final long serialVersionUID = 1L;
    RNNNeuroVo [] vos0;

    LayerCfg lc;

    /**
     * Allocates {@code nodeNN} neuron value objects, each constructed with the
     * same arguments.
     *
     * @param nodeNN       number of neurons in this layer
     * @param t            forwarded to every {@link RNNNeuroVo}
     *                     (presumably the sequence length; confirm)
     * @param inHidenLayer forwarded to every {@link RNNNeuroVo}
     * @param previousNNN  forwarded to every {@link RNNNeuroVo}
     * @param nextLayerNN  forwarded to every {@link RNNNeuroVo}
     * @param lc           layer configuration, kept and forwarded
     */
    public RNNLayer(int nodeNN,
            int t, boolean inHidenLayer, int previousNNN, int nextLayerNN, LayerCfg lc) {
        this.lc = lc;
        vos0 = new RNNNeuroVo[nodeNN];
        for (int i = 0; i < vos0.length; i++) {
//          vos0[i] = new RNNNeuroVo(t, inHidenLayer, previousNNN, nodeNN, nodeNN);
            vos0[i] = new RNNNeuroVo(t, inHidenLayer, previousNNN, nodeNN, 0, nextLayerNN, lc);
        }
    }

    public LayerCfg getLc() {
        return lc;
    }

    public void setLc(LayerCfg lc) {
        this.lc = lc;
    }

    @Override
    public RNNNeuroVo[] getRNNNeuroVos() {
        return vos0;
    }

    /** Forward pass: delegates to {@code bptt.fTT4RNNLayer(this)}. */
    @Override
    public void fTT(IBPTT bptt) {
        bptt.fTT4RNNLayer(this);
    }

    /** Backward pass: delegates to {@code bptt.bpTT4RNNLayer(this)}. */
    @Override
    public void bpTT(IBPTT bptt) {
        bptt.bpTT4RNNLayer(this);
    }

    /** Weight update: delegates to {@code bptt.updateWw4RNNLayer(this)}. */
    public void updateWw(IRNNLayerVisitor bptt) {
        bptt.updateWw4RNNLayer(this);
    }

    @Override
    public void setRNNNeuroVos(RNNNeuroVo[] rnnvos) {
        this.vos0 = rnnvos;
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/RecurrentNeuroNetwork.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm;

/** Empty placeholder class; it declares no members. */
public class RecurrentNeuroNetwork {

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/Seq2SeqLSTMConfigurator.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm;

import java.io.Serializable;

/**
 * Serializable pair of LSTM configurations (the "Q" and the "A" one) plus a loop
 * counter and a flag selecting which of the two is active.
 */
public class Seq2SeqLSTMConfigurator implements Serializable {
    /**
     *
     */
    private static final long serialVersionUID = 1L;
    LSTMConfigurator qlSTMConfigurator;
    LSTMConfigurator alSTMConfigurator;
    int loop = 0;
    boolean testQ = true;

    /**
     * @param qlSTMConfigurator configuration of the Q-side LSTM
     * @param alSTMConfigurator configuration of the A-side LSTM
     */
    public Seq2SeqLSTMConfigurator(LSTMConfigurator qlSTMConfigurator,
            LSTMConfigurator alSTMConfigurator) {
        super();
        this.qlSTMConfigurator = qlSTMConfigurator;
        this.alSTMConfigurator = alSTMConfigurator;
    }

    public LSTMConfigurator getQlSTMConfigurator() {
        return qlSTMConfigurator;
    }

    public void setQlSTMConfigurator(LSTMConfigurator qlSTMConfigurator) {
        this.qlSTMConfigurator = qlSTMConfigurator;
    }

    public LSTMConfigurator getAlSTMConfigurator() {
        return alSTMConfigurator;
    }

    public void setAlSTMConfigurator(LSTMConfigurator alSTMConfigurator) {
        this.alSTMConfigurator = alSTMConfigurator;
    }

    /** @return the Q-mode flag; defaults to {@code true} */
    public boolean isTestQ() {
        return testQ;
    }

    public void setTestQ(boolean testQ) {
        this.testQ = testQ;
    }

    /** @return the loop counter; defaults to 0 */
    public int getLoop() {
        return loop;
    }

    public void setLoop(int loop) {
        this.loop = loop;
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/SimpleNeuroVo.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm;

import java.io.Serializable;

/**
 * Serializable value object with three doubles ({@code aA}, {@code zZ},
 * {@code deltaZz}) and a drop-out flag.
 * NOTE(review): presumably activation, pre-activation and its gradient; confirm
 * against the code that fills them.
 */
public class SimpleNeuroVo implements Serializable {
    /**
     *
     */
    private static final long serialVersionUID = 1L;
    boolean dropOut = false;
    double aA;
    double zZ;
    double deltaZz;

    public double getaA() {
        return aA;
    }

    public void setaA(double aA) {
        this.aA = aA;
    }

    public double getzZ() {
        return zZ;
    }

    public void setzZ(double zZ) {
        this.zZ = zZ;
    }

    public double getDeltaZz() {
        return deltaZz;
    }

    public void setDeltaZz(double deltaZz) {
        this.deltaZz = deltaZz;
    }

    /** @return the drop-out flag; defaults to {@code false} */
    public boolean isDropOut() {
        return dropOut;
    }

    public void setDropOut(boolean dropOut) {
        this.dropOut = dropOut;
    }
}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/apps/ner/test/NerTaggerVerify.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.apps.ner.test;

import java.util.Properties;

import deepDriver.dl.aml.lstm.LSTM;
import deepDriver.dl.aml.lstm.apps.ner.NerStream;
import deepDriver.dl.aml.lstm.apps.ner.NerTagger;
import deepDriver.dl.aml.lstm.apps.util.StringUtils;

/**
 * Command-line entry point that evaluates an NER model on the train, dev and
 * test datasets.
 */
public class NerTaggerVerify {

    static int qFile = 0;
    static int testA = 2;

    /**
     * Parses {@code args} into properties, loads the three datasets, creates
     * the model from the training set's feature/target sizes, and calls
     * {@code testModel} for each split whose {@code devFile} /
     * {@code trainFile} / {@code testFile} property is present.
     *
     * @param args {@code -key value} pairs, see {@link StringUtils#parseArgs}
     * @throws Exception propagated from dataset loading, model creation or testing
     */
    public static void main(String[] args) throws Exception {
        NerTagger tagger = new NerTagger();
        Properties prop = StringUtils.parseArgs(args);

        // Read from local property files
        // String pf =
        // "C:\\workspace\\DeepDriver\\properties\\tagger_default.properties";
        // InputStream in = new BufferedInputStream(new FileInputStream(pf));
        // Properties prop = new Properties();
        // prop.load(in);
        // prop.setProperty("sqFile",
        // "C:\\workspace\\DeepDriver\\data\\china_daily_1472631418124_0.m");

        System.out.println("###### Reading Training Datasets ######");
        NerStream psTrain = tagger.loadDataset(prop, "train", false);
        System.out.println("###### Reading Development Datasets ######");
        NerStream psDev = tagger.loadDataset(prop, "dev", false);
        System.out.println("###### Reading Test Datasets ######");
        NerStream psTest = tagger.loadDataset(prop, "test", false);

        int tSize = psTrain.getTargetFeatureNum();
        int fSize = psTrain.getSampleFeatureNum();
        final LSTM qlstm = tagger.createModel(prop, tSize, fSize);

        if (prop.getProperty("devFile") != null) {
            System.out.println("###### Evaluate Development Datasets ######");
            qlstm.testModel(psDev);
        }
        if (prop.getProperty("trainFile") != null) {
            System.out.println("###### Evaluate Training Datasets ######");
            qlstm.testModel(psTrain);
        }
        if (prop.getProperty("testFile") != null) {
            System.out.println("###### Evaluate Test Datasets ######");
            qlstm.testModel(psTest);
        }

    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/apps/pos/test/PosTaggerTest.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.apps.pos.test;

import java.util.List;
import java.util.Properties;

import deepDriver.dl.aml.lstm.apps.pos.PosTagger;
import deepDriver.dl.aml.lstm.apps.util.TaggedWord;

/**
 * Demo for Pos Tagger predict method API
 * */

public class PosTaggerTest {

    /**
     * Builds a tagger from hard-coded dictionary and model paths, tags one
     * sample sentence and prints the result.
     */
    public static void main(String[] args) {

        Properties prop = new Properties();
        prop.setProperty("dictFile", "C:/workspace/DeepDriver/models/POS/199801_dict.txt");
        prop.setProperty("sqFile", "C:/workspace/DeepDriver/models/POS/china_daily_1472638305564_3.m");

        String s1 = "新华社 北京 十二月 二十五日 电 ( 记者 )";
        PosTagger tagger = new PosTagger(prop);
        // Element type restored from the otherwise unused TaggedWord import;
        // NOTE(review): confirm against PosTagger.predict's declared return type.
        List<TaggedWord> line = tagger.predict(s1);
        System.out.println(line.toString());

    }
}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/apps/pos/test/PosTaggerVerify.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.apps.pos.test;

import java.util.Properties;

import deepDriver.dl.aml.lstm.LSTM;
import deepDriver.dl.aml.lstm.apps.pos.PosStream;
import deepDriver.dl.aml.lstm.apps.pos.PosTagger;
import deepDriver.dl.aml.lstm.apps.util.StringUtils;

/**
 * Command-line entry point that evaluates a POS model on the train, dev and
 * test datasets.
 */
public class PosTaggerVerify {

    static int qFile = 0;
    static int testA = 2;

    /**
     * Parses {@code args} into properties, loads the three datasets, creates
     * the model from the training set's feature/target sizes, and calls
     * {@code testModel} for each split whose {@code devFile} /
     * {@code trainFile} / {@code testFile} property is present.
     *
     * @param args {@code -key value} pairs, see {@link StringUtils#parseArgs}
     * @throws Exception propagated from dataset loading, model creation or testing
     */
    public static void main(String[] args) throws Exception {
        PosTagger tagger = new PosTagger();
        Properties prop = StringUtils.parseArgs(args);

        // Read from local property files
        // String pf =
        // "C:\\workspace\\DeepDriver\\properties\\tagger_default.properties";
        // InputStream in = new BufferedInputStream(new FileInputStream(pf));
        // Properties prop = new Properties();
        // prop.load(in);
        // prop.setProperty("sqFile",
        // "C:\\workspace\\DeepDriver\\data\\china_daily_1472631418124_0.m");

        System.out.println("###### Reading Training Datasets ######");
        PosStream psTrain = tagger.loadDataset(prop, "train", false);
        System.out.println("###### Reading Development Datasets ######");
        PosStream psDev = tagger.loadDataset(prop, "dev", false);
        System.out.println("###### Reading Test Datasets ######");
        PosStream psTest = tagger.loadDataset(prop, "test", false);

        int tSize = psTrain.getTargetFeatureNum();
        int fSize = psTrain.getSampleFeatureNum();
        final LSTM qlstm = tagger.createModel(prop, tSize, fSize);

        if (prop.getProperty("devFile") != null) {
            System.out.println("###### Evaluate Development Datasets ######");
            qlstm.testModel(psDev);
        }
        if (prop.getProperty("trainFile") != null) {
            System.out.println("###### Evaluate Training Datasets ######");
            qlstm.testModel(psTrain);
        }
        if (prop.getProperty("testFile") != null) {
            System.out.println("###### Evaluate Test Datasets ######");
            qlstm.testModel(psTest);
        }

    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/apps/util/FeatureFactory.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.apps.util;

import java.util.List;

/**
 * Feature Factory to control the embedding feature
 *
 * */
public class FeatureFactory {

    Embedding embedding;
    int dim; // embedding dimensions

    /**
     * @param embedding word-vector lookup; its {@code dim} field is cached here
     */
    public FeatureFactory(Embedding embedding) {
        this.embedding = embedding;
        this.dim = embedding.dim;
    }

    /**
     * Concatenate inputs to the first layer of LSTM model.
     *
     * Builds the feature vector for the token at {@code index} by concatenating
     * the word vectors of the {@code 2 * window + 1} tokens centred on it.
     * Window positions that fall before the start or after the end of the
     * sentence contribute an all-zero vector of length {@code dim}.
     *
     * @param sen    the sentence
     * @param index  position of the centre token in {@code sen}
     * @param window number of context tokens taken on each side
     * @return the concatenated vector
     */
    public double[] getEmbedFeature(List<TaggedWord> sen, int index, int window) {
        int ngram = 2 * window + 1;
        int length = sen.size(); // length of sentences
        int first = index - window;
        // Rows start zero-filled, so out-of-sentence positions are already
        // padded on both sides; only in-range rows need to be overwritten.
        double[][] mat = new double[ngram][dim];
        for (int i = first; i <= index + window; i++) {
            if (i >= 0 && i < length) {
                mat[i - first] = embedding.getWordVec(sen.get(i).word());
            }
        }
        return concat(mat);
    }

    /** Flattens {@code mat} row by row, using the first row's length as the row width. */
    private double[] concat(double[][] mat) {
        int nrow = mat.length;
        int ncol = mat[0].length;
        double[] vec = new double[nrow * ncol];
        for (int i = 0; i < nrow; i++) {
            System.arraycopy(mat[i], 0, vec, i * ncol, ncol);
        }
        return vec;
    }

    /**
     * Not implemented yet.
     *
     * @return always {@code null}
     */
    public double[] getOneHotFeature(List<TaggedWord> sen, int index) {
        // To Do
        return null;
    }

}
--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/apps/util/StringUtils.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.apps.util;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
import java.util.Set;

/**
 * String helpers used by the tagger command-line tools.
 */
public class StringUtils {

    /**
     * Joins the elements of {@code list} with {@code delimiter} between them.
     *
     * @param list      elements to join; each is appended via {@code StringBuilder.append}
     * @param delimiter text inserted between consecutive elements
     * @return the joined text, or an empty string for an empty list
     */
    public static String joinString(List<?> list, String delimiter) {
        StringBuilder builder = new StringBuilder();
        if (list.size() > 0) {
            builder.append(list.get(0));
        }
        for (int i = 1; i < list.size(); i++) {
            builder.append(delimiter).append(list.get(i));
        }
        return builder.toString();
    }

    /**
     * Parses command-line arguments of the form {@code -key value}.
     *
     * A pair is recognised when an argument starting with {@code '-'} is
     * directly followed by a non-empty argument that does not start with
     * {@code '-'}; the key is stored without its leading dash. Every other
     * non-empty argument that does not start with {@code '-'} and was not
     * consumed by a pair is collected, space-separated, under the key
     * {@code "OTHERS"}. Flags without a value are ignored.
     *
     * @param args raw arguments; {@code null} or empty yields empty properties
     * @return the parsed properties
     */
    public static Properties parseArgs(String[] args) {
        Properties prop = new Properties();
        if (args == null || args.length == 0) {
            return prop;
        }

        // Positions already consumed as a flag or as a flag's value.
        Set<Integer> consumed = new HashSet<Integer>();
        for (int i = 0; i < (args.length - 1); i++) {
            String flag = args[i];
            String value = args[i + 1];
            if (!flag.isEmpty() && flag.charAt(0) == '-'
                    && !value.isEmpty() && value.charAt(0) != '-') {
                prop.setProperty(flag.substring(1), value); // key without '-', value
                consumed.add(i);
                consumed.add(i + 1);
            }
        }

        List<String> remainingArgs = new ArrayList<String>();
        for (int i = 0; i < args.length; i++) {
            String arg = args[i];
            // Look up by position: the set holds indices, so testing the
            // argument text would never match and flag values would be
            // reported a second time under OTHERS.
            if (!arg.isEmpty() && arg.charAt(0) != '-' && !consumed.contains(i)) {
                remainingArgs.add(arg);
            }
        }
        if (!remainingArgs.isEmpty()) {
            prop.setProperty("OTHERS", joinString(remainingArgs, " "));
        }
        return prop;
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/apps/util/TaggedWord.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.apps.util;

/**
 * Util Class for Labeling Tagged Word [word/tag]
 * */

public class TaggedWord {

    private String word;
    private String tag;

    private static final String DIVIDER = "/";

    public TaggedWord() {
        super();
    }

    /** Creates a word without a tag. */
    public TaggedWord(String word) {
        this.word = word;
    }

    public TaggedWord(String word, String tag) {
        this.word = word;
        this.tag = tag;
    }

    public String tag() {
        return tag;
    }

    public void setTag(String tag) {
        this.tag = tag;
    }

    public String word() {
        return word;
    }

    public void setWord(String word) {
        this.word = word;
    }

    /** @return {@code word/tag}, using the default divider {@code "/"} */
    @Override
    public String toString() {
        return toString(DIVIDER);
    }

    /**
     * @param divider text placed between word and tag
     * @return word, divider and tag concatenated; a {@code null} part is
     *         rendered as the text {@code "null"}
     */
    public String toString(String divider) {
        return word + divider + tag;
    }

    /** Parses {@code word/tag} using the default divider {@code "/"}. */
    public void setFromString(String taggedWord) {
        setFromString(taggedWord, DIVIDER);
    }

    /**
     * Splits {@code taggedWord} at the last occurrence of {@code divider}.
     * Without a divider the whole text becomes the word and the tag is set to
     * {@code null}.
     *
     * @param taggedWord text of the form {@code word<divider>tag}
     * @param divider    separator between word and tag; may be longer than one character
     */
    public void setFromString(String taggedWord, String divider) {
        int where = taggedWord.lastIndexOf(divider);
        if (where >= 0) {
            setWord(taggedWord.substring(0, where));
            // Skip the whole divider, not just one character, so that
            // multi-character dividers do not leak into the tag.
            setTag(taggedWord.substring(where + divider.length()));
        } else {
            setWord(taggedWord);
            setTag(null);
        }
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/apps/wordSegmentation/WordSegment.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.apps.wordSegmentation;

import java.io.Serializable;

/**
 * Serializable node of a doubly linked chain of segments. Each node carries its
 * words both as strings and as ints.
 * NOTE(review): {@code wordsInt} presumably holds dictionary ids of
 * {@code words}; confirm against the code that fills it.
 */
public class WordSegment implements Serializable {
    /**
     *
     */
    private static final long serialVersionUID = 1L;

    String [] words;
    int [] wordsInt;

    WordSegment previous;
    WordSegment next;

    // --- words ---

    public String[] getWords() {
        return words;
    }

    public void setWords(String[] words) {
        this.words = words;
    }

    public int[] getWordsInt() {
        return wordsInt;
    }

    public void setWordsInt(int[] wordsInt) {
        this.wordsInt = wordsInt;
    }

    // --- links ---

    public WordSegment getPrevious() {
        return previous;
    }

    public void setPrevious(WordSegment previous) {
        this.previous = previous;
    }

    public WordSegment getNext() {
        return next;
    }

    public void setNext(WordSegment next) {
        this.next = next;
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/attentionEnDecoder/AttentionCfg.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.attentionEnDecoder;

import java.io.Serializable;

import deepDriver.dl.aml.lstm.LSTMConfigurator;
import deepDriver.dl.aml.lstm.LstmAttention;

/**
 * Serializable bundle of the two LSTM configurations and the attention object
 * of an attention encoder/decoder, plus a display name.
 */
public class AttentionCfg implements Serializable {

    /**
     *
     */
    private static final long serialVersionUID = 1L;

    LSTMConfigurator qcfg;
    LSTMConfigurator acfg;
    LstmAttention attention;
    String name = "attention";

    /**
     * @param qcfg      Q-side LSTM configuration
     * @param acfg      A-side LSTM configuration
     * @param attention attention component
     */
    public AttentionCfg(LSTMConfigurator qcfg, LSTMConfigurator acfg,
            LstmAttention attention) {
        this.qcfg = qcfg;
        this.acfg = acfg;
        this.attention = attention;
    }

    /** @return the name; {@code "attention"} unless changed */
    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public LSTMConfigurator getQcfg() {
        return qcfg;
    }

    public void setQcfg(LSTMConfigurator qcfg) {
        this.qcfg = qcfg;
    }

    public LSTMConfigurator getAcfg() {
        return acfg;
    }

    public void setAcfg(LSTMConfigurator acfg) {
        this.acfg = acfg;
    }

    public LstmAttention getAttention() {
        return attention;
    }

    public void setAttention(LstmAttention attention) {
        this.attention = attention;
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/attentionEnDecoder/test/TestAttentionEnDecoderQABbSystem.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.attentionEnDecoder.test;

import deepDriver.dl.aml.lstm.attentionEnDecoder.AttentionEnDecoderLSTM;

/**
 * Entry point that bootstraps an {@code AttEn2DeSetup} and trains an
 * {@link AttentionEnDecoderLSTM} with the configurations and streams it provides.
 */
public class TestAttentionEnDecoderQABbSystem {

    static int qFile = 1;
    static int testA = 2;

    /**
     * @param args unused
     * @throws Exception propagated from bootstrap or training
     */
    public static void main(String[] args) throws Exception {
        AttEn2DeSetup setup = new AttEn2DeSetup();
        setup.setSetupDic(true);
        setup.bootstrap(null, false);

        AttentionEnDecoderLSTM model =
                new AttentionEnDecoderLSTM(setup.getQcfg(), setup.getAcfg());
        model.trainModel(setup.getQsi(), setup.getAsi(), false);
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/beamSearch/BeamLayer.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.beamSearch;

import java.util.ArrayList;
import java.util.List;

/**
 * Holds the list of entries of one beam-search layer.
 * NOTE(review): the element type is not declared here; presumably
 * {@code BeamNode}. Confirm against the callers.
 */
public class BeamLayer {

    List bns = new ArrayList();

    /** @return the live, mutable list (initially empty) */
    public List getBns() {
        return bns;
    }

    /** Replaces the list wholesale. */
    public void setBns(List bns) {
        this.bns = bns;
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/beamSearch/BeamNode.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.beamSearch;

import java.util.ArrayList;
import java.util.List;

/**
 * One node of a beam-search tree: a link to its parent, the layer it is
 * attached to, a probability and a position (both stored as doubles).
 */
public class BeamNode {
    BeamNode parent;
//  List children = new ArrayList();
    BeamLayer bl;

    double prob;
    double pos;

    /**
     * @param parent parent node
     * @param bl     layer this node is attached to
     * @param prob   probability stored with the node
     * @param pos    position stored with the node
     */
    public BeamNode(BeamNode parent, BeamLayer bl, double prob, double pos) {
        this.parent = parent;
        this.bl = bl;
        this.prob = prob;
        this.pos = pos;
    }

    public BeamNode getParent() {
        return parent;
    }

    public void setParent(BeamNode parent) {
        this.parent = parent;
    }

    public BeamLayer getBl() {
        return bl;
    }

    public void setBl(BeamLayer bl) {
        this.bl = bl;
    }

//  public List getChildren() {
//      return children;
//  }
//  public void setChildren(List children) {
//      this.children = children;
//  }

    public double getProb() {
        return prob;
    }

    public void setProb(double prob) {
        this.prob = prob;
    }

    public double getPos() {
        return pos;
    }

    public void setPos(double pos) {
        this.pos = pos;
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/bidirection/test/TestBiLstmWS.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.bidirection.test;

/** Placeholder entry point; {@code main} is empty. */
public class TestBiLstmWS {

    public static void main(String[] args) {
        // intentionally empty
    }

}
--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/conversation/Encoder2DecoderConversation.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.conversation;


/** Empty placeholder class; it declares no members. */
public class Encoder2DecoderConversation {

    //


}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/data/CfgDataCleaner.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.data;

/**
 * Drops the references held by nested {@code double} arrays by overwriting
 * every slot with {@code null}.
 */
public class CfgDataCleaner {

    /**
     * Clears every 2-D slice of {@code data} and then nulls the slice itself.
     * Safe to call repeatedly on the same array: slices that are already
     * {@code null} are skipped, and a {@code null} argument is ignored.
     *
     * @param data array to clear in place; may be {@code null}
     */
    public static void clean(double [][][] data) {
        if (data == null) {
            return;
        }
        for (int i = 0; i < data.length; i++) {
            // A slice is null after a previous clean() of the same array.
            clean(data[i]);
            data[i] = null;
        }
    }

    /**
     * Nulls every row of {@code data}.
     *
     * @param data array to clear in place; {@code null} is ignored
     */
    public static void clean(double [][] data) {
        if (data == null) {
            return;
        }
        for (int i = 0; i < data.length; i++) {
            data[i] = null;
        }
    }

    /**
     * No-op: a primitive array holds no references to drop. Kept as an
     * instance method for compatibility with existing callers.
     *
     * @param data ignored
     */
    public void clean(double [] data) {
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/data/LSTMCfgData.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.data;

import java.io.Serializable;

/**
 * Serializable carrier for a 3-D block of configuration values together with a
 * type code and a loop counter.
 * NOTE(review): the meaning of the type codes and of the three array axes is
 * defined by the code that fills this object, which is not visible here.
 */
public class LSTMCfgData implements Serializable {

    double [][][] cfg = null;
    int type;
    int loop;

    public int getType() {
        return type;
    }

    public void setType(int type) {
        this.type = type;
    }

    public int getLoop() {
        return loop;
    }

    public void setLoop(int loop) {
        this.loop = loop;
    }

    /** @return the stored array itself (no copy); {@code null} until set */
    public double[][][] getCfg() {
        return cfg;
    }

    /** Stores the given array by reference (no copy). */
    public void setCfg(double[][][] cfg) {
        this.cfg = cfg;
    }


}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/distribution/LSTMSlave.java:
-------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution; 2 | 3 | import deepDriver.dl.aml.distribution.Error; 4 | import deepDriver.dl.aml.distribution.Slave; 5 | import deepDriver.dl.aml.lstm.IStream; 6 | import deepDriver.dl.aml.lstm.LSTM; 7 | import deepDriver.dl.aml.lstm.LSTMWwArrayTranslator; 8 | 9 | public class LSTMSlave extends Slave { 10 | 11 | LSTM lstm; 12 | static int mb = 1024; 13 | IStream is; 14 | @Override 15 | public void setTask(Object obj) throws Exception { 16 | is = (IStream) obj; 17 | } 18 | 19 | double err = 0; 20 | Error error = new Error(); 21 | @Override 22 | public void trainLocal() throws Exception { 23 | if (lstm.getbPTT() == null) { 24 | lstm.setbPTT(lstm.createBPTT()); 25 | } 26 | 27 | err = 0; 28 | for (int i = 0; i < mb; i++) { 29 | if (!is.hasNext()) { 30 | is.reset(); 31 | } 32 | is.next(); 33 | err = err + lstm.runEpich(is.getSampleTT(), is.getTarget()); 34 | // cnn.getcNNBP().runTrainEpich(new IDataMatrix[] { dm }, dm.getTarget()); 35 | // err = err + cnn.getcNNBP().getStdError(); 36 | 37 | } 38 | } 39 | 40 | @Override 41 | public Error getError() { 42 | error.setErr(err); 43 | return error; 44 | } 45 | 46 | @Override 47 | public void setSubject(Object obj) { 48 | if (obj instanceof LSTM) { 49 | lstm = (LSTM) obj; 50 | } else { 51 | swWs = (double [][]) obj; 52 | // cnnMerger.merge(cnn, swWs, true); 53 | translator.update(lstm.getCfg(), swWs, true); 54 | } 55 | } 56 | 57 | LSTMWwArrayTranslator translator = new LSTMWwArrayTranslator(); 58 | double [][] wWs; 59 | double [][] swWs; 60 | @Override 61 | public Object getLocalSubject() { 62 | if (wWs == null) { 63 | } 64 | wWs = new double[lstm.getCfg().getLayers().length][]; 65 | translator.update(lstm.getCfg(), wWs, false); 66 | return wWs; 67 | } 68 | 69 | } 70 | -------------------------------------------------------------------------------- 
/src/deepDriver/dl/aml/lstm/distribution/Seq2SeqAsycSlaveV6Thread.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.distribution;

/**
 * Thread that runs one {@link Seq2SeqAsycSlaveV6}'s training loop.
 */
public class Seq2SeqAsycSlaveV6Thread extends Thread {
    Seq2SeqAsycSlaveV6 slave;

    /** @param slave the slave whose {@code train()} is executed by this thread */
    public Seq2SeqAsycSlaveV6Thread(Seq2SeqAsycSlaveV6 slave) {
        this.slave = slave;
    }

    /**
     * Calls {@code slave.train()}. Any exception is printed to standard error
     * and the thread then ends normally.
     */
    @Override
    public void run() {
        try {
            slave.train();
        } catch (Exception trainingFailure) {
            trainingFailure.printStackTrace();
        }
    }

}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/distribution/Seq2SeqLSTMBoostrapper.java:
--------------------------------------------------------------------------------
package deepDriver.dl.aml.lstm.distribution;


import deepDriver.dl.aml.lstm.NeuroNetworkArchitecture;
import deepDriver.dl.aml.lstm.Seq2SeqLSTM;
import deepDriver.dl.aml.string.ANFixedStreamImpV2;
import deepDriver.dl.aml.string.Dictionary;
import deepDriver.dl.aml.string.NFixedStreamImpV2;

/**
 * Contract for objects that prepare the data and build the components of a
 * sequence-to-sequence LSTM run, and expose them through accessors.
 */
public interface Seq2SeqLSTMBoostrapper {

    /**
     * Prepares the data.
     *
     * @param isServer role flag; its effect is implementation-defined
     */
    void prepareData(boolean isServer) throws Exception;

    /**
     * Builds the components for the given task.
     *
     * @param task      task slice to set up for
     * @param need4Test flag whose effect is implementation-defined
     */
    void bootstrap(SimpleTask task, boolean need4Test) throws Exception;

    NeuroNetworkArchitecture getNna();

    void setNna(NeuroNetworkArchitecture nna);

    NFixedStreamImpV2 getQsi();

    void setQsi(NFixedStreamImpV2 qsi);

    ANFixedStreamImpV2 getAsi();

    void setAsi(ANFixedStreamImpV2 asi);

    Seq2SeqLSTM getSeq2SeqLSTM();

    void setSeq2SeqLSTM(Seq2SeqLSTM seq2SeqLSTM);

    Dictionary getDic();

    void setDic(Dictionary dic);
}

--------------------------------------------------------------------------------
/src/deepDriver/dl/aml/lstm/distribution/Seq2SeqSlaveV3.java:
-------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution; 2 | 3 | 4 | import deepDriver.dl.aml.lstm.LSTMConfigurator; 5 | import deepDriver.dl.aml.lstm.LSTMWwUpdater; 6 | import deepDriver.dl.aml.lstm.data.CfgDataTransfer; 7 | import deepDriver.dl.aml.lstm.data.LSTMCfgData; 8 | 9 | public class Seq2SeqSlaveV3 extends Seq2SeqSlave { 10 | 11 | public Seq2SeqSlaveV3(Seq2SeqLSTMBoostrapper boot) { 12 | super(boot); 13 | } 14 | 15 | LSTMCfgData cfgFromSrv; 16 | CfgDataTransfer cfgDataTransfer = new CfgDataTransfer(); 17 | public void setSubject(Object obj) { 18 | cfgFromSrv = (LSTMCfgData) obj; 19 | testQ = Seq2SeqMasterV2.isQMode(cfgFromSrv); 20 | System.out.println("Subject is from server, " + 21 | "in Q mode? "+ testQ+", with round"+cfgFromSrv.getLoop()); 22 | seq2SeqLSTM.getCfg().setTestQ(testQ); 23 | if (testQ) { 24 | cfgDataTransfer.copyData2Cfg(cfgFromSrv, seq2SeqLSTM.getCfg().getQlSTMConfigurator().getLayers()); 25 | } else { 26 | cfgDataTransfer.copyData2Cfg(cfgFromSrv, seq2SeqLSTM.getCfg().getAlSTMConfigurator().getLayers()); 27 | } 28 | System.out.println("Already fresh Ww from server" ); 29 | } 30 | 31 | LSTMWwUpdater wWchecker = new LSTMWwUpdater(true, true); 32 | public void testWw(LSTMConfigurator fcfg) { 33 | System.out.println("Prepare to check Wws"); 34 | wWchecker.updatewWs(fcfg, seq2SeqLSTM.getCfg().getQlSTMConfigurator()); 35 | System.out.println("Done to check Wws"); 36 | } 37 | 38 | @Override 39 | public Object getLocalSubject() { 40 | System.out.println("Prepare the local cfg.." ); 41 | if (seq2SeqLSTM.getCfg().isTestQ()) { 42 | cfgFromSrv = cfgDataTransfer.loadCfg( 43 | seq2SeqLSTM.getCfg().getQlSTMConfigurator(). 44 | getLayers()); 45 | } else { 46 | cfgFromSrv = cfgDataTransfer.loadCfg( 47 | seq2SeqLSTM.getCfg().getAlSTMConfigurator(). 
48 | getLayers()); 49 | } 50 | System.out.println("Get the local cfg ready" ); 51 | return cfgFromSrv; 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/SimpleTask.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.distribution.ITask; 6 | 7 | public class SimpleTask implements ITask, Serializable { 8 | /** 9 | * 10 | */ 11 | private static final long serialVersionUID = 1L; 12 | int start; 13 | int end; 14 | int mbatch; 15 | 16 | public SimpleTask(int start, int end, int mbatch) { 17 | super(); 18 | this.start = start; 19 | this.end = end; 20 | this.mbatch = mbatch; 21 | } 22 | public int getStart() { 23 | return start; 24 | } 25 | public void setStart(int start) { 26 | this.start = start; 27 | } 28 | public int getEnd() { 29 | return end; 30 | } 31 | public void setEnd(int end) { 32 | this.end = end; 33 | } 34 | public int getMbatch() { 35 | return mbatch; 36 | } 37 | public void setMbatch(int mbatch) { 38 | this.mbatch = mbatch; 39 | } 40 | 41 | } 42 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/Heoo.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | public class Heoo { 4 | 5 | double i; 6 | 7 | 8 | } 9 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/Test.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | public class Test { 4 | 5 | public static void main(String[] args) throws InterruptedException { 6 | double [][][] aa = new double[1000][][]; 7 | for (int i = 0; 
i < aa.length; i++) { 8 | aa[i] = new double[6][]; 9 | for (int j = 0; j < aa[i].length; j++) { 10 | aa[i][j] = new double[1000]; 11 | for (int j2 = 0; j2 < aa.length; j2++) { 12 | aa[i][j][j2] = 0.1; 13 | } 14 | } 15 | } 16 | int cnt = 0; 17 | while (true) { 18 | System.out.println(""+cnt ++); 19 | aa = new double[1000][][]; 20 | for (int i = 0; i < aa.length; i++) { 21 | aa[i] = new double[6][]; 22 | for (int j = 0; j < aa[i].length; j++) { 23 | aa[i][j] = new double[1000]; 24 | for (int j2 = 0; j2 < aa.length; j2++) { 25 | aa[i][j][j2] = 0.1; 26 | } 27 | } 28 | } 29 | // Thread.sleep(100); 30 | } 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestLSTMSlave.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | import deepDriver.dl.aml.distribution.DistributionEnvCfg; 4 | import deepDriver.dl.aml.distribution.P2PServer; 5 | import deepDriver.dl.aml.lstm.distribution.LSTMSlave; 6 | 7 | public class TestLSTMSlave { 8 | 9 | public static void main(String[] args) throws Exception { 10 | String host = "127.0.0.1"; 11 | if (args != null && args.length >= 1) { 12 | host = args[0]; 13 | } 14 | System.out.println("Connet to Server: "+host); 15 | DistributionEnvCfg.getCfg().set(P2PServer.KEY_SRV_HOST, host); 16 | DistributionEnvCfg.getCfg().set(P2PServer.KEY_SRV_PORT, 8034); 17 | 18 | LSTMSlave lstmSlave = new LSTMSlave(); 19 | lstmSlave.train(); 20 | } 21 | 22 | } 23 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestMultipleSeq2SeqAsycSlave.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import 
deepDriver.dl.aml.lstm.distribution.Seq2SeqAsycSlaveV6; 7 | import deepDriver.dl.aml.lstm.distribution.Seq2SeqAsycSlaveV6Thread; 8 | 9 | public class TestMultipleSeq2SeqAsycSlave { 10 | public static void main(String[] args) throws Exception { 11 | // Seq2SeqSlaveV2 slave = new Seq2SeqSlaveV2(new Seq2SeqLSTMSetup()); 12 | int sn = TestSeq2SeqAsycMaster.CLIENT_NUM; 13 | List ths = 14 | new ArrayList(); 15 | for (int i = 0; i < sn; i++) { 16 | ths.add(new Seq2SeqAsycSlaveV6Thread(new Seq2SeqAsycSlaveV6(new Seq2SeqLSTMSetup()))); 17 | } 18 | for (int i = 0; i < sn; i++) { 19 | ths.get(i).start(); 20 | } 21 | for (int i = 0; i < sn; i++) { 22 | ths.get(i).join(); 23 | } 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestPath.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | public class TestPath { 4 | 5 | public static void main(String[] args) { 6 | String s = System.getProperty("java.class.path"); 7 | s = System.getProperty("user.dir"); 8 | System.out.println(s); 9 | } 10 | 11 | } 12 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestQAMaster.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | import java.io.File; 4 | 5 | 6 | import deepDriver.dl.aml.distribution.DistributionEnvCfg; 7 | import deepDriver.dl.aml.lstm.distribution.Seq2SeqAsycMasterV6; 8 | 9 | public class TestQAMaster { 10 | static int CLIENT_NUM = 2; 11 | 12 | public static void main(String[] args) throws Exception { 13 | DistributionEnvCfg.getCfg(). set(Seq2SeqLSTMSetup.KEY_FS_ROOT, args[0]); 14 | DistributionEnvCfg.getCfg(). 
set(Seq2SeqLSTMSetup.KEY_TEST_FILE, args[1]); 15 | Seq2SeqAsycMasterV6 master = null; 16 | int sn = CLIENT_NUM; 17 | if (args.length > 2) { 18 | sn = Integer.parseInt(args[2]); 19 | } 20 | if (args.length > 3) { 21 | System.out.println("There are params passed in."); 22 | String sf = System.getProperty("user.dir"); 23 | File mf = new File(sf, "data/"+args[3]); 24 | if (mf.exists()) { 25 | master = new Seq2SeqAsycMasterV6( 26 | sn, new Seq2SeqLSTMSetup(), mf.getAbsolutePath()); 27 | } else { 28 | master = new Seq2SeqAsycMasterV6( 29 | sn, new Seq2SeqLSTMSetup(), null); 30 | } 31 | } else { 32 | master = new Seq2SeqAsycMasterV6( 33 | sn, new Seq2SeqLSTMSetup()); 34 | } 35 | master.train(); 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestQAWorker.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | 4 | import deepDriver.dl.aml.distribution.DistributionEnvCfg; 5 | import deepDriver.dl.aml.distribution.P2PServer; 6 | import deepDriver.dl.aml.lstm.distribution.Seq2SeqAsycSlaveV6; 7 | 8 | public class TestQAWorker { 9 | public static void main(String[] args) throws Exception { 10 | DistributionEnvCfg.getCfg(). set(Seq2SeqLSTMSetup.KEY_FS_ROOT, args[0]); 11 | DistributionEnvCfg.getCfg(). set(Seq2SeqLSTMSetup.KEY_TEST_FILE, args[1]); 12 | DistributionEnvCfg.getCfg(). set(P2PServer.KEY_SRV_HOST, args[2]); 13 | // DistributionEnvCfg.getCfg(). 
set(P2PServer., "10.1.242.48"); 14 | 15 | // Seq2SeqSlaveV2 slave = new Seq2SeqSlaveV2(new Seq2SeqLSTMSetup()); 16 | Seq2SeqAsycSlaveV6 slave = new Seq2SeqAsycSlaveV6(new Seq2SeqLSTMSetup()); 17 | slave.train(); 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestSeq2SeqAsycMaster.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | import java.io.File; 4 | 5 | import deepDriver.dl.aml.lstm.distribution.Seq2SeqAsycMasterV6; 6 | 7 | public class TestSeq2SeqAsycMaster { 8 | static int CLIENT_NUM = 4; 9 | 10 | public static void main(String[] args) throws Exception { 11 | Seq2SeqAsycMasterV6 master = null; 12 | int sn = CLIENT_NUM; 13 | if (args.length > 0) { 14 | System.out.println("There are params passed in."); 15 | String sf = System.getProperty("user.dir"); 16 | File mf = new File(sf, "data/"+args[0]); 17 | if (mf.exists()) { 18 | master = new Seq2SeqAsycMasterV6( 19 | sn, new Seq2SeqLSTMSetup(), mf.getAbsolutePath()); 20 | } else { 21 | master = new Seq2SeqAsycMasterV6( 22 | sn, new Seq2SeqLSTMSetup(), null); 23 | } 24 | } else { 25 | master = new Seq2SeqAsycMasterV6( 26 | sn, new Seq2SeqLSTMSetup()); 27 | } 28 | master.train(); 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestSeq2SeqAsycSlave.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | import deepDriver.dl.aml.lstm.distribution.Seq2SeqAsycSlaveV6; 4 | 5 | public class TestSeq2SeqAsycSlave { 6 | public static void main(String[] args) throws Exception { 7 | // Seq2SeqSlaveV2 slave = new Seq2SeqSlaveV2(new Seq2SeqLSTMSetup()); 8 | Seq2SeqAsycSlaveV6 slave = new Seq2SeqAsycSlaveV6(new Seq2SeqLSTMSetup()); 9 | 
slave.train(); 10 | } 11 | 12 | } 13 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestSeq2SeqMaster.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | import java.io.File; 4 | 5 | 6 | import deepDriver.dl.aml.distribution.DistributionEnvCfg; 7 | import deepDriver.dl.aml.lstm.distribution.Seq2SeqMasterV5; 8 | 9 | /** Master entry point: args[0] is stored under KEY_FS_ROOT and args[1] under KEY_TEST_FILE; an optional args[2] names a file under the working directory's data/ folder whose absolute path is handed to Seq2SeqMasterV5 when the file exists (null is passed when it does not). */ public class TestSeq2SeqMaster { 10 | 11 | public static void main(String[] args) throws Exception { 12 | DistributionEnvCfg.getCfg(). set(Seq2SeqLSTMSetup.KEY_FS_ROOT, args[0]); 13 | DistributionEnvCfg.getCfg(). set(Seq2SeqLSTMSetup.KEY_TEST_FILE, args[1]); 14 | Seq2SeqMasterV5 master = null; 15 | /* guard now matches the index read below; the previous "> 3" ignored a third argument unless a fourth was also supplied */ if (args.length > 2) { 16 | System.out.println("There are params passed in."); 17 | String sf = System.getProperty("user.dir"); 18 | File mf = new File(sf, "data/"+args[2]); 19 | if (mf.exists()) { 20 | master = new Seq2SeqMasterV5( 21 | 4, new Seq2SeqLSTMSetup(), mf.getAbsolutePath()); 22 | } else { 23 | master = new Seq2SeqMasterV5( 24 | 4, new Seq2SeqLSTMSetup(), null); 25 | } 26 | } else { 27 | master = new Seq2SeqMasterV5( 28 | 4, new Seq2SeqLSTMSetup()); 29 | } 30 | master.train(); 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestSeq2SeqMaster4Srv48.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | import java.io.File; 4 | 5 | 6 | import deepDriver.dl.aml.distribution.DistributionEnvCfg; 7 | import deepDriver.dl.aml.distribution.P2PServer; 8 | import deepDriver.dl.aml.lstm.distribution.Seq2SeqMasterV5; 9 | 10 | public class TestSeq2SeqMaster4Srv48 { 11 | 12 | public static void main(String[] args) throws Exception { 13 | DistributionEnvCfg.getCfg(). 
14 | set(P2PServer.KEY_SRV_HOST, "10.1.242.48"); 15 | Seq2SeqMasterV5 master = null; 16 | if (args.length > 0) { 17 | System.out.println("There are params passed in."); 18 | String sf = System.getProperty("user.dir"); 19 | File mf = new File(sf, "data/"+args[0]); 20 | if (mf.exists()) { 21 | master = new Seq2SeqMasterV5( 22 | 4, new Seq2SeqLSTMSetup(), mf.getAbsolutePath()); 23 | } else { 24 | master = new Seq2SeqMasterV5( 25 | 4, new Seq2SeqLSTMSetup(), null); 26 | } 27 | } else { 28 | master = new Seq2SeqMasterV5( 29 | 4, new Seq2SeqLSTMSetup()); 30 | } 31 | master.train(); 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestSeq2SeqSlave.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | 4 | import deepDriver.dl.aml.distribution.DistributionEnvCfg; 5 | import deepDriver.dl.aml.distribution.P2PServer; 6 | import deepDriver.dl.aml.lstm.distribution.Seq2SeqSlaveV5; 7 | 8 | public class TestSeq2SeqSlave { 9 | public static void main(String[] args) throws Exception { 10 | DistributionEnvCfg.getCfg(). set(Seq2SeqLSTMSetup.KEY_FS_ROOT, args[0]); 11 | DistributionEnvCfg.getCfg(). set(Seq2SeqLSTMSetup.KEY_TEST_FILE, args[1]); 12 | DistributionEnvCfg.getCfg(). 
set(P2PServer.KEY_SRV_HOST, args[2]); 13 | // Seq2SeqSlaveV2 slave = new Seq2SeqSlaveV2(new Seq2SeqLSTMSetup()); 14 | Seq2SeqSlaveV5 slave = new Seq2SeqSlaveV5(new Seq2SeqLSTMSetup()); 15 | slave.train(); 16 | } 17 | 18 | } 19 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/distribution/test/TestSeq2SeqSlave4Srv48.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.distribution.test; 2 | 3 | 4 | import deepDriver.dl.aml.distribution.DistributionEnvCfg; 5 | import deepDriver.dl.aml.distribution.P2PServer; 6 | import deepDriver.dl.aml.lstm.distribution.Seq2SeqSlaveV5; 7 | 8 | public class TestSeq2SeqSlave4Srv48 { 9 | public static void main(String[] args) throws Exception { 10 | DistributionEnvCfg.getCfg(). 11 | set(P2PServer.KEY_SRV_HOST, "10.1.242.48"); 12 | // Seq2SeqSlaveV2 slave = new Seq2SeqSlaveV2(new Seq2SeqLSTMSetup()); 13 | Seq2SeqSlaveV5 slave = new Seq2SeqSlaveV5(new Seq2SeqLSTMSetup()); 14 | slave.train(); 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/enDecoder/test/TestQAEnDecoderBabySystem.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.enDecoder.test; 2 | 3 | import deepDriver.dl.aml.lstm.attentionEnDecoder.AttentionEnDecoderLSTM; 4 | 5 | public class TestQAEnDecoderBabySystem { 6 | 7 | static int qFile = 1; 8 | static int testA = 2; 9 | 10 | public static void main(String[] args) throws Exception { 11 | Encoder2DecoderSetup encoder2DecoderSetup = new Encoder2DecoderSetup(); 12 | encoder2DecoderSetup.setSetupDic(true); 13 | encoder2DecoderSetup.bootstrap(null, false); 14 | AttentionEnDecoderLSTM encoderDecoderLSTM = new AttentionEnDecoderLSTM(encoder2DecoderSetup.getQcfg(), 15 | encoder2DecoderSetup.getAcfg()); 16 | 
encoderDecoderLSTM.trainModel(encoder2DecoderSetup.getQsi(), 17 | encoder2DecoderSetup.getAsi(), false); 18 | 19 | 20 | } 21 | 22 | } 23 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/hred/HredBPTT.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.hred; 2 | 3 | public class HredBPTT { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/imp/Cell.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.imp; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.lstm.ICell; 6 | import deepDriver.dl.aml.lstm.LayerCfg; 7 | import deepDriver.dl.aml.lstm.RNNNeuroVo; 8 | 9 | public class Cell extends RNNNeuroVo implements ICell, Serializable { 10 | 11 | /** 12 | * 13 | */ 14 | private static final long serialVersionUID = 1L; 15 | // transient double [] sc; 16 | // transient double [] deltaSc; 17 | // 18 | // transient double [] cZz; 19 | // transient double [] deltaC; 20 | double [] sc; 21 | double [] deltaSc; 22 | 23 | double [] cZz; 24 | double [] deltaC; 25 | 26 | public Cell(int t, boolean inHidenLayer, int previousNNN, int layerNN, int blockNN, int nextLayerNN, 27 | LayerCfg lc) { 28 | super(t, inHidenLayer, previousNNN, layerNN, blockNN, nextLayerNN, lc); 29 | } 30 | 31 | public double [] getSc() { 32 | return sc; 33 | } 34 | 35 | public void setSc(double [] sc) { 36 | this.sc = sc; 37 | } 38 | 39 | public double[] getDeltaSc() { 40 | return deltaSc; 41 | } 42 | 43 | public void setDeltaSc(double[] deltaSc) { 44 | this.deltaSc = deltaSc; 45 | } 46 | 47 | public double[] getCZz() { 48 | return cZz; 49 | } 50 | 51 | public void setCZz(double[] scZz) { 52 | this.cZz = scZz; 53 | } 54 | 55 | public double[] getDeltaC() { 56 | return deltaC; 57 | } 58 | 59 | public void setDeltaC(double[] 
deltaC) { 60 | this.deltaC = deltaC; 61 | } 62 | 63 | 64 | 65 | } 66 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/imp/ForgetGate.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.imp; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.lstm.IForgetGate; 6 | import deepDriver.dl.aml.lstm.RNNNeuroVo; 7 | 8 | public class ForgetGate extends RNNNeuroVo implements IForgetGate, Serializable { 9 | 10 | /** 11 | * 12 | */ 13 | private static final long serialVersionUID = 1L; 14 | 15 | public ForgetGate(int t, boolean inHidenLayer, int previousNNN, 16 | int LayerNN, int blockNN, int nextLayerNN) { 17 | super(t, inHidenLayer, previousNNN, LayerNN, blockNN, nextLayerNN, null); 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/imp/InputGate.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.imp; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.lstm.IInputGate; 6 | import deepDriver.dl.aml.lstm.RNNNeuroVo; 7 | 8 | public class InputGate extends RNNNeuroVo implements IInputGate, Serializable { 9 | 10 | /** 11 | * 12 | */ 13 | private static final long serialVersionUID = 1L; 14 | 15 | public InputGate(int t, boolean inHidenLayer, int previousNNN, 16 | int LayerNN, int blockNN, int nextLayerNN) { 17 | super(t, inHidenLayer, previousNNN, LayerNN, blockNN, nextLayerNN, null); 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/imp/LSTMNeuro.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.imp; 2 | 3 | import deepDriver.dl.aml.lstm.ILSTMNeuro; 4 | 5 | public class LSTMNeuro implements ILSTMNeuro { 6 | double 
[][] asTT; 7 | double [][] zZsTT; 8 | double [] weights; 9 | 10 | } 11 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/imp/OutputGate.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.imp; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.lstm.IOutputGate; 6 | import deepDriver.dl.aml.lstm.RNNNeuroVo; 7 | 8 | public class OutputGate extends RNNNeuroVo implements IOutputGate, Serializable { 9 | 10 | /** 11 | * 12 | */ 13 | private static final long serialVersionUID = 1L; 14 | 15 | public OutputGate(int t, boolean inHidenLayer, int previousNNN, 16 | int LayerNN, int blockNN, int nextLayerNN) { 17 | super(t, inHidenLayer, previousNNN, LayerNN, blockNN, nextLayerNN, null); 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/imp/TanhAf.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.imp; 2 | 3 | import java.io.Serializable; 4 | 5 | import deepDriver.dl.aml.ann.IActivationFunction; 6 | 7 | /** Hyperbolic-tangent activation function computed from Math.exp. */ public class TanhAf implements IActivationFunction, Serializable { 8 | 9 | /** 10 | * 11 | */ 12 | private static final long serialVersionUID = 1L; 13 | 14 | /** Returns (e^(2x) - 1) / (e^(2x) + 1); inputs above 100 return 1 directly. */ @Override 15 | public double activate(double x) { 16 | /* without this guard a large enough x overflows Math.exp to Infinity and the quotient below becomes NaN */ if (x > 100) { 17 | return 1; 18 | } 19 | double t = Math.exp(2 * x); 20 | /* for very negative x, t underflows to 0 and this evaluates to -1 */ return (t - 1)/(t + 1); 21 | } 22 | 23 | /** Returns 1 - activate(x)^2; x is the same input that is given to activate, not an already activated value. Inputs above 100 return 0. */ @Override 24 | public double deActivate(double x) { 25 | if (x > 100) { 26 | return 0; 27 | } 28 | double a = activate(x); 29 | return 1 - a * a; 30 | } 31 | 32 | /** Prints activate and deActivate for a very large positive input, a very large negative input and zero. */ public static void main(String[] args) { 33 | TanhAf tanhAf = new TanhAf(); 34 | System.out.println(tanhAf.activate(1190000000)); 35 | System.out.println(tanhAf.deActivate(1190000000)); 36 | 37 | System.out.println(tanhAf.activate(-1190000000)); 38 | 
System.out.println(tanhAf.deActivate(-1190000000)); 39 | System.out.println(tanhAf.activate(0)); 40 | System.out.println(tanhAf.deActivate(0)); 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/imp/TimePreriod.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.imp; 2 | 3 | import deepDriver.dl.aml.lstm.ITimePeriod; 4 | 5 | public class TimePreriod implements ITimePeriod { 6 | double period; 7 | 8 | public double getPeriod() { 9 | return period; 10 | } 11 | 12 | public void setPeriod(double period) { 13 | this.period = period; 14 | } 15 | 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/test/Test1.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.test; 2 | 3 | import java.util.Random; 4 | 5 | public class Test1 { 6 | 7 | public static void main(String[] args) throws InterruptedException { 8 | // Map mp = new HashMap(); 9 | // int a = 1; 10 | // double b = 2; 11 | // mp.put(a, b); 12 | // System.out.println(mp.get(a)); 13 | // mp.put(a, mp.get(a) + 2); 14 | // System.out.println(mp.get(a)); 15 | Random rd = new Random(10000); 16 | int k = 100; 17 | for (int i = 0; i < k; i++) { 18 | System.out.println(rd.nextDouble()); 19 | } 20 | System.out.println("......"); 21 | Thread.sleep(3000); 22 | Random rd1 = new Random(10000); 23 | for (int i = 0; i < k; i++) { 24 | System.out.println(rd1.nextDouble()); 25 | } 26 | } 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/lstm/test/TestQa.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.lstm.test; 2 | 3 | 4 | import deepDriver.dl.aml.distribution.DistributionEnvCfg; 5 | import 
deepDriver.dl.aml.lstm.conversation.Seq2SeqConversation; 6 | 7 | public class TestQa { 8 | 9 | public static void main(String[] args) throws Exception { 10 | Seq2SeqConversation converstion = new Seq2SeqConversation(); 11 | String root = "D:\\6.workspace\\ANN\\lstm\\QaModel\\"; 12 | Seq2SeqBabySysSetup s2s = new Seq2SeqBabySysSetup(); 13 | s2s.setThreadsNum(4); 14 | DistributionEnvCfg.getCfg(). set(Seq2SeqBabySysSetup.KEY_FS_ROOT, root); 15 | DistributionEnvCfg.getCfg(). set(Seq2SeqBabySysSetup.KEY_TEST_FILE, 16 | // "talk2016.txt"); 17 | "talk2015_2016.txt"); 18 | // DistributionEnvCfg.getCfg(). set(P2PServer.KEY_SRV_HOST, args[2]); 19 | 20 | /**converstion.load(s2s, 21 | root+"qModel_v1466671795460_1.m" , root+"aModel_v1467890329260_0.m",// 22 | root+"v.m");**/ 23 | 24 | converstion.load(s2s, 25 | root+"qModel_v1516_1468479736787_0.m" , 26 | root+"aModel_v1516_1468846020244_0.m",// 27 | root+"v_1516.m"); 28 | // converstion.testQ("感谢"); 29 | long l = System.currentTimeMillis(); 30 | converstion.testQas("怎么我的是黄金套餐还要购买才能看电影?", 3, 41); //我账号密码是什么?我看视频一直缓冲为什么如何取消自动续费 你工号多少,我投诉你"" 31 | System.out.println((System.currentTimeMillis() - l)+" costed."); 32 | 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/ContentBasedWeighting.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public class ContentBasedWeighting { 4 | 5 | double [][] matrix; 6 | double [][] dm; 7 | 8 | public double[][] getMatrix() { 9 | return matrix; 10 | } 11 | 12 | public void setMatrix(double[][] matrix) { 13 | this.matrix = matrix; 14 | } 15 | 16 | double [] sims; 17 | double [] k; 18 | double beta; 19 | double [] dk; 20 | 21 | double [] sm; 22 | 23 | public double[] getSm() { 24 | return sm; 25 | } 26 | 27 | public void setSm(double[] sm) { 28 | this.sm = sm; 29 | } 30 | 31 | public double [] weighting(double [] k, double beta) { 32 | 
this.k = k; 33 | this.beta = beta; 34 | /* allocate on first use, and again whenever the matrix row count differs from the cached buffer (e.g. after setMatrix), so the loop below never indexes past matrix or skips rows */ if (sims == null || sims.length != matrix.length) { 35 | sims = new double[matrix.length]; 36 | } 37 | for (int i = 0; i < sims.length; i++) { 38 | sims[i] = MathUtil.cos(k, matrix[i]); 39 | } 40 | sm = MathUtil.softMax(sims, beta); 41 | return sm; 42 | } 43 | 44 | double dbeta = 0; 45 | 46 | /** Re-runs weighting(k, beta), then fills dm, dbeta and dk from da; dk is the sum over all matrix rows of the per-row difCos contribution. Returns dk. */ public double [] backWeighting(double [] da, double [] k, double beta) { 47 | weighting(k, beta); 48 | 49 | dm = MathUtil.allocate(matrix.length, matrix[0].length); 50 | double [] dsims = MathUtil.difSoftMax4Weighting(da, sims, beta); 51 | dbeta = MathUtil.difSoftMax4Beta(da, sims, beta); 52 | dk = new double[k.length]; 53 | 54 | for (int i = 0; i < matrix.length; i++) { 55 | double [] dk1 = new double[k.length]; 56 | MathUtil.difCos(dsims[i], dk1, k, matrix[i]); 57 | MathUtil.plus2V(dk1, dk); 58 | 59 | MathUtil.difCos(dsims[i], dm[i], matrix[i], k); 60 | } 61 | return dk; 62 | } 63 | 64 | public void update(double l, double m) { 65 | 66 | } 67 | 68 | public double[][] getDm() { 69 | return dm; 70 | } 71 | 72 | public double[] getDk() { 73 | return dk; 74 | } 75 | 76 | public double getDbeta() { 77 | return dbeta; 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/IExp.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public interface IExp { 4 | 5 | public void compute(Object obj); 6 | 7 | public void difCompute(Object obj); 8 | 9 | public Object getR(); 10 | 11 | } 12 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/IExp4Function.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public interface IExp4Function { 4 | /* 5 | * y = f(x), x is a vector 6 | * **/ 7 | public void compute(double [] x); 8 | 9 | public void difCompute(double dy, double [] x); 10 | 11 | public double 
getR(); 12 | 13 | public double [] getDv(); 14 | 15 | public void resetDv(); 16 | 17 | public void update(double l, double m); 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/IMatrixExp.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public interface IMatrixExp { 4 | 5 | public double[] getRs(); 6 | 7 | public void setRs(double[] rs); 8 | 9 | public void compute(double [] x); 10 | 11 | public void difCompute(double [] dy, double [] x); 12 | 13 | public double[] getDv(); 14 | 15 | public void update(double l, double m); 16 | 17 | public void resetDv(); 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/LinearMatrixExp.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public class LinearMatrixExp implements IMatrixExp { 4 | 5 | LinearExp [] linearExps; 6 | 7 | double [] rs; 8 | 9 | double [] dr; 10 | 11 | public LinearMatrixExp(int lNum, int length) { 12 | super(); 13 | linearExps = new LinearExp[lNum]; 14 | for (int i = 0; i < linearExps.length; i++) { 15 | linearExps[i] = new LinearExp(length); 16 | } 17 | rs = new double[lNum]; 18 | } 19 | 20 | public double[] getRs() { 21 | return rs; 22 | } 23 | 24 | public void setRs(double[] rs) { 25 | this.rs = rs; 26 | } 27 | 28 | public void compute(double [] x) { 29 | for (int i = 0; i < linearExps.length; i++) { 30 | linearExps[i].compute(x); 31 | rs[i] = linearExps[i].getR(); 32 | } 33 | } 34 | 35 | public void difCompute(double [] dy, double [] x) { 36 | dr = dy; 37 | for (int i = 0; i < linearExps.length; i++) { 38 | linearExps[i].difCompute(dy[i], x); 39 | } 40 | } 41 | 42 | public double[] getDv() { 43 | double [] dv = new double[linearExps[0].getDv().length]; 44 | for (int i = 0; i < linearExps.length; 
i++) { 45 | MathUtil.plus2V(linearExps[i].getDv(), 1.0, dv); 46 | } 47 | return dv; 48 | } 49 | 50 | public void update(double l, double m) { 51 | for (int i = 0; i < linearExps.length; i++) { 52 | linearExps[i].update(l, m); 53 | } 54 | } 55 | 56 | public void resetDv() { 57 | for (int i = 0; i < linearExps.length; i++) { 58 | linearExps[i].resetDv(); 59 | } 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/LinearRegression.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public class LinearRegression { 4 | 5 | 6 | 7 | } 8 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/OnePlusExp.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public class OnePlusExp implements IExp4Function { 4 | 5 | LinearExp le; 6 | double r; 7 | 8 | double dr; 9 | double dlr; 10 | 11 | public OnePlusExp(int length) { 12 | this.le = new LinearExp(length); 13 | } 14 | 15 | public void compute(double [] x) { 16 | le.compute(x); 17 | r = MathUtil.onePlus(le.getR()); 18 | } 19 | 20 | public double [] getX() { 21 | return le.x; 22 | } 23 | 24 | public double getR() { 25 | return r; 26 | } 27 | 28 | @Override 29 | public void difCompute(double dy, double [] x) { 30 | dr = dy; 31 | le.compute(x); 32 | dlr = MathUtil.difOnePlus(le.getR()) * dy; 33 | le.difCompute(dlr, x); 34 | } 35 | 36 | @Override 37 | public double[] getDv() { 38 | return le.getDv(); 39 | } 40 | 41 | @Override 42 | public void update(double l, double m) { 43 | le.update(l, m); 44 | } 45 | 46 | public void resetDv() { 47 | le.resetDv(); 48 | } 49 | 50 | public static void main(String[] args) { 51 | OnePlusExp one = new OnePlusExp(6); 52 | double [] x = {0,0,0,0,0,0}; 53 | one.compute(x); 54 | System.out.println(one.getR()); 55 | 
System.out.println(MathUtil.onePlus(0)); 56 | 57 | } 58 | 59 | } 60 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/SigmodExp.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public class SigmodExp implements IExp4Function { 4 | 5 | double r; 6 | LinearExp le; 7 | 8 | double dr; 9 | double dlr; 10 | 11 | 12 | public SigmodExp(int length) { 13 | this.le = new LinearExp(length); 14 | } 15 | 16 | public void compute(double [] x) { 17 | le.compute(x); 18 | r = MathUtil.sigmod(le.getR()); 19 | } 20 | 21 | public double getR() { 22 | return r; 23 | } 24 | 25 | @Override 26 | public void difCompute(double dy, double [] x) { 27 | le.compute(x); 28 | dr = dy; 29 | dlr = MathUtil.difSigmod(le.getR()) * dy; 30 | le.difCompute(dlr, x); 31 | } 32 | 33 | public double[] getDv() { 34 | return le.getDv(); 35 | } 36 | 37 | @Override 38 | public void update(double l, double m) { 39 | le.update(l, m); 40 | } 41 | 42 | @Override 43 | public void resetDv() { 44 | le.resetDv(); 45 | } 46 | 47 | public double[] getX() { 48 | return le.x; 49 | } 50 | 51 | public double[] getPara() { 52 | return le.parameters; 53 | } 54 | 55 | public double[] getDl() { 56 | return le.deltaPara; 57 | } 58 | 59 | public double[] getDl2() { 60 | return le.deltaPara2; 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/SigmodMatrixExp.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public class SigmodMatrixExp implements IMatrixExp { 4 | 5 | SigmodExp [] exps; 6 | 7 | double [] rs; 8 | 9 | double [] dr; 10 | 11 | public SigmodMatrixExp(int lNum, int length) { 12 | super(); 13 | exps = new SigmodExp[lNum]; 14 | for (int i = 0; i < exps.length; i++) { 15 | exps[i] = new SigmodExp(length); 16 | } 17 | rs = new 
double[lNum]; 18 | } 19 | 20 | public double[] getRs() { 21 | return rs; 22 | } 23 | 24 | public void setRs(double[] rs) { 25 | this.rs = rs; 26 | } 27 | 28 | public void compute(double [] x) { 29 | for (int i = 0; i < exps.length; i++) { 30 | exps[i].compute(x); 31 | rs[i] = exps[i].getR(); 32 | } 33 | } 34 | 35 | public void difCompute(double [] dy, double [] x) { 36 | dr = dy; 37 | for (int i = 0; i < exps.length; i++) { 38 | exps[i].difCompute(dy[i], x); 39 | } 40 | } 41 | 42 | public double[] getDv() { 43 | double [] dv = new double[exps[0].getDv().length]; 44 | for (int i = 0; i < exps.length; i++) { 45 | MathUtil.plus2V(exps[i].getDv(), 1.0, dv); 46 | } 47 | return dv; 48 | } 49 | 50 | public void update(double l, double m) { 51 | for (int i = 0; i < exps.length; i++) { 52 | exps[i].update(l, m); 53 | } 54 | } 55 | 56 | public void resetDv() { 57 | for (int i = 0; i < exps.length; i++) { 58 | exps[i].resetDv(); 59 | } 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/SoftMaxExp.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math; 2 | 3 | public class SoftMaxExp { 4 | 5 | SigmodMatrixExp me; 6 | double [] r; 7 | 8 | double beta; 9 | 10 | double [] dr; 11 | double [] dlr; 12 | 13 | public SoftMaxExp(int lNum, int length, double beta) { 14 | this.beta = beta; 15 | me = new SigmodMatrixExp(lNum, length); 16 | } 17 | 18 | public void compute(double[] x) { 19 | me.compute(x); 20 | double [] y = me.getRs(); 21 | r = MathUtil.softMax(y, beta); 22 | } 23 | 24 | public void difCompute(double [] dy, double [] x) { 25 | dr = dy; 26 | me.compute(x); 27 | dlr = MathUtil.difSoftMax4Weighting(dy, me.getRs(), beta); 28 | me.difCompute(dlr, x); 29 | } 30 | 31 | public double [] getR() { 32 | return r; 33 | } 34 | 35 | public double[] getDv() { 36 | return me.getDv(); 37 | } 38 | 39 | public void update(double l, double m) { 
40 | me.update(l, m); 41 | } 42 | 43 | public void resetDv() { 44 | me.resetDv(); 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/test/Test.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math.test; 2 | 3 | import deepDriver.dl.aml.math.MathUtil; 4 | 5 | public class Test { 6 | 7 | public static void main(String[] args) { 8 | double b = -2; 9 | double a = Math.sqrt(b); 10 | System.out.println(a); 11 | if (MathUtil.isNaN(a)) { 12 | System.out.println(" a it is "+a); 13 | } 14 | if (MathUtil.isNaN(b)) { 15 | System.out.println("b it is "+b); 16 | } 17 | } 18 | 19 | } 20 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/test/TestCos.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math.test; 2 | 3 | import deepDriver.dl.aml.math.MathUtil; 4 | 5 | public class TestCos { 6 | 7 | public static void testDifCos() { 8 | double [] v1 = {1, 1}; 9 | double [] v2 = {-1, -2}; 10 | double [] dv1 = new double[v1.length]; 11 | System.out.println(MathUtil.cos(v1, v2)); 12 | MathUtil.difCos(0.1, dv1, v1, v2); 13 | for (int i = 0; i < dv1.length; i++) { 14 | System.out.println(dv1[i]); 15 | } 16 | } 17 | 18 | public static void testCos() { 19 | double [] v1 = {1, 1}; 20 | double [][] v2 = { 21 | {-1, -1}, 22 | {0, 1}, 23 | {1, 0}, 24 | {0, -1}, 25 | {0, -2}, 26 | {1, -1}, 27 | {-1, 1}, 28 | {-1, 0}, 29 | {-2, 0} 30 | }; 31 | for (int i = 0; i < v2.length; i++) { 32 | System.out.println("v2["+i+ 33 | "] cos is "+ MathUtil.cos(v1, v2[i])); 34 | } 35 | 36 | } 37 | 38 | public static void testDifSoftmax() { 39 | double [] v1 = {1.0, 0, -1.0, 0.7, -0.7}; 40 | double [] dr = {0.1, 0.2, 0.3, 0.4, 0.5}; 41 | double [] sf = MathUtil.difSoftMax4Weighting(dr, v1, 1.5); 42 | for (int i = 0; i < sf.length; i++) { 43 | 
double d = sf[i]; 44 | System.out.println(d); 45 | } 46 | double db = MathUtil.difSoftMax4Beta(dr, v1, 1.5); 47 | System.out.println("db is "+db); 48 | } 49 | 50 | public static void testSoftmax() { 51 | double [] v1 = {1.0, 0, -1.0, 0.7, -0.7}; 52 | double [] sf = MathUtil.softMax(v1, 2); 53 | for (int i = 0; i < sf.length; i++) { 54 | double d = sf[i]; 55 | System.out.println(d); 56 | } 57 | } 58 | 59 | public static void main(String[] args) { 60 | // testCos(); 61 | // testSoftmax(); 62 | // testDifSoftmax(); 63 | testDifCos(); 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/test/TestIExp4Function.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math.test; 2 | 3 | import deepDriver.dl.aml.math.LinearExp; 4 | 5 | public class TestIExp4Function { 6 | 7 | public static void testCompute() { 8 | LinearExp le = new LinearExp(5); 9 | double [] x = new double[5]; 10 | for (int i = 0; i < x.length; i++) { 11 | x[i] = 1.0; 12 | } 13 | 14 | le.compute(x); 15 | double r = le.getR(); 16 | double [] v = le.getParameters(); 17 | System.out.println("r: "+r); 18 | print(v); 19 | 20 | } 21 | 22 | public static void print(double [] v) { 23 | for (int i = 0; i < v.length; i++) { 24 | System.out.println(v[i]); 25 | } 26 | } 27 | 28 | public static void testDifCompute() { 29 | LinearExp le = new LinearExp(5); 30 | double [] x = new double[5]; 31 | for (int i = 0; i < x.length; i++) { 32 | x[i] = 1.0; 33 | } 34 | le.difCompute(1.0, x); 35 | le.difCompute(1.0, x); 36 | // le.update(0.1, 0.1); 37 | System.out.println("p:"); 38 | print(le.getParameters()); 39 | System.out.println("dv:"); 40 | print(le.getDv()); 41 | System.out.println("dp:"); 42 | print(le.getDeltaPara()); 43 | le.update(0.1, 0.1); 44 | System.out.println("p:"); 45 | print(le.getParameters()); 46 | } 47 | 48 | public static void main(String[] args) { 49 | // 
TestLineExp.testCompute(); 50 | testDifCompute(); 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/test/TestJcuBlas.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math.test; 2 | 3 | import deepDriver.dl.aml.math.JCudaBlasMathFunction; 4 | 5 | public class TestJcuBlas { 6 | 7 | public static void main(String[] args) { 8 | float [][] a = new float[][]{ 9 | {1,1,1}, 10 | {1,1,1} 11 | }; 12 | float [][] b = new float[][]{ 13 | {1,1}, 14 | {1,1}, 15 | {1,1} 16 | }; 17 | JCudaBlasMathFunction jbmf = new JCudaBlasMathFunction(); 18 | float [][] c = new float[a.length][]; 19 | for (int i = 0; i < c.length; i++) { 20 | c[i] = new float[b[0].length]; 21 | } 22 | jbmf.multiple(a, b, c); 23 | syso(c); 24 | } 25 | 26 | public static void syso(float [][] a) { 27 | for (int i = 0; i < a.length; i++) { 28 | for (int j = 0; j < a[i].length; j++) { 29 | System.out.print(a[i][j]+","); 30 | } 31 | System.out.println(); 32 | } 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/math/test/TestMathUtils.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.math.test; 2 | 3 | import org.jblas.FloatMatrix; 4 | 5 | import deepDriver.dl.aml.math.BlasMathFunction; 6 | 7 | public class TestMathUtils { 8 | 9 | public static void main(String[] args) { 10 | float [][] a = new float[][]{ 11 | {1,1,1}, 12 | {1,1,1} 13 | }; 14 | float [][] b = new float[][]{ 15 | {1,1,1}, 16 | {1,1,1} 17 | }; 18 | float [][] c = new float[][]{ 19 | {1,1,1}, 20 | {1,1,1} 21 | }; 22 | BlasMathFunction bmf = new BlasMathFunction(); 23 | bmf.plus(a, 1.2f, b, 2.3f, c); 24 | // System.out.println(bmf); 25 | syso(c); 26 | } 27 | 28 | public static void syso(float [][] a) { 29 | for (int i = 0; i < a.length; i++) { 30 | for (int j = 0; 
/**
 * Hands out one process-wide {@link Random} instance.
 *
 * <p>Every caller of {@link #getRandom()} receives the same object, so all
 * draws come from a single sequence seeded once at class-initialization time.
 */
public class RandomFactory {

    /**
     * The shared generator, seeded from {@link System#currentTimeMillis()} when
     * the class is initialized.
     *
     * <p>Declared {@code final} so the instance returned by {@link #getRandom()}
     * can never be swapped out. The former {@code transient} modifier was
     * dropped: static fields are never part of an object's serialized form, so
     * it had no effect.
     */
    static final Random random = new Random(System.currentTimeMillis());

    /**
     * Returns the shared generator.
     *
     * @return the same non-null {@link Random} instance on every call
     */
    public static Random getRandom() {
        return random;
    }

}
/**
 * Pairs an input vector with a target vector.
 *
 * <p>Both arrays are stored and returned by reference: no copy is made on the
 * way in or out, and both are {@code null} until set.
 */
public class RelationObject {

    /** Input vector; {@code null} until {@link #setInput(double[])} is called. */
    double[] input;

    /** Target vector; {@code null} until {@link #setTarget(double[])} is called. */
    double[] target;

    /** @return the array last passed to {@link #setInput(double[])} (same reference) */
    public double[] getInput() {
        return this.input;
    }

    /** @return the array last passed to {@link #setTarget(double[])} (same reference) */
    public double[] getTarget() {
        return this.target;
    }

    /** @param input array to keep as the input vector; stored without copying */
    public void setInput(double[] input) {
        this.input = input;
    }

    /** @param target array to keep as the target vector; stored without copying */
    public void setTarget(double[] target) {
        this.target = target;
    }

}
/**
 * Acceptance rule driven by a temperature that is multiplied by a fixed ratio
 * on every decision (the {@code exp(deltaE / t)} test is the usual
 * simulated-annealing criterion).
 */
public class SA {

    /** Temperature restored by {@link #reset()}. */
    double defaultT = 500;
    /** Current temperature. */
    double t = defaultT;
    /** Ratio the temperature is multiplied by on each cooling step. */
    double r = 0.5;

    /** When {@code false}, the sign of {@code deltaE} is flipped before it is judged. */
    boolean optimMax = true;

    /** Number of cooling steps applied by the constructor. */
    int initLoop = 1;

    /** Source of the acceptance threshold, seeded from the wall clock. */
    Random rd = new Random(System.currentTimeMillis());

    /**
     * @param defaultT starting temperature, also the value {@link #reset()} restores
     * @param r        cooling ratio
     * @param optimMax {@code true} if a positive {@code deltaE} is an improvement
     * @param initLoop cooling steps applied immediately; values of zero or less apply none
     */
    public SA(double defaultT, double r, boolean optimMax, int initLoop) {
        this.defaultT = defaultT;
        this.r = r;
        this.optimMax = optimMax;
        this.initLoop = initLoop;

        this.t = defaultT;
        for (int remaining = initLoop; remaining > 0; remaining--) {
            this.t = r * this.t;
        }
    }

    /** Puts the temperature back to {@code defaultT}; the constructor's cooling steps are not re-applied. */
    public void reset() {
        t = defaultT;
    }

    /**
     * Cools the temperature by one step and decides whether to accept a move.
     *
     * @param deltaE change in the objective; negated first when {@code optimMax} is {@code false}
     * @return {@code true} if the (possibly negated) change is positive, or if
     *         {@code exp(change / t)} exceeds a fresh uniform draw from {@code [0, 1)}
     */
    public boolean sa(double deltaE) {
        double gain = optimMax ? deltaE : -deltaE;
        t = t * r;
        // The draw happens on every call, including ones accepted outright.
        double threshold = rd.nextDouble();
        return gain > 0 || Math.exp(gain / t) > threshold;
    }

}
/**
 * Small helper for scoring predictions: a goodness-of-fit figure for numeric
 * series, an arg-max comparison for vectors, and a running hit counter with a
 * console summary.
 */
public class AccuracyCaculator {

    /** Number of samples counted so far. */
    int cnt = 0;
    /** Number of samples counted as correct so far. */
    int correctCnt = 0;
    /** {@link #summaryCp()} prints whenever {@code cnt} is a multiple of this value. */
    int summaryInterval = 200;

    /**
     * Computes {@code 1 - sum((x - px)^2) / sum((x - mean(x))^2)}.
     *
     * <p>No guard is applied: when every {@code x} value is identical the
     * denominator is zero and the result is NaN or negative infinity.
     *
     * @param x  observed values
     * @param px predicted values, read at the same indices as {@code x}
     * @return the score described above
     */
    public double caculateAccuracy(double[] x, double[] px) {
        double total = 0;
        for (double value : x) {
            total = total + value;
        }
        double mean = total / (double) x.length;

        double variation = 0;
        double residual = 0;
        for (int k = 0; k < x.length; k++) {
            double fromMean = x[k] - mean;
            double fromPrediction = x[k] - px[k];
            variation = variation + fromMean * fromMean;
            residual = residual + fromPrediction * fromPrediction;
        }
        return 1 - residual / variation;
    }

    /** @return {@code true} when both vectors have their largest element at the same index */
    public boolean check(double[] ta, double[] tb) {
        return getMaxPos(ta) == getMaxPos(tb);
    }

    /**
     * Finds the index of the largest element.
     *
     * @return the first index holding the maximum (ties keep the earlier index);
     *         0 for an empty array
     */
    public int getMaxPos(double[] ta) {
        int best = 0;
        for (int k = 1; k < ta.length; k++) {
            if (ta[k] > ta[best]) {
                best = k;
            }
        }
        return best;
    }

    /** Counts one more sample. */
    public void cntIncrease() {
        cnt++;
    }

    /** Counts one more correct sample. */
    public void correctCntIncrease() {
        correctCnt++;
    }

    /** Sets both counters back to zero. */
    public void reset() {
        correctCnt = 0;
        cnt = 0;
    }

    /** Prints the summary only when {@code cnt} is a multiple of {@code summaryInterval}. */
    public void summaryCp() {
        if (cnt % summaryInterval != 0) {
            return;
        }
        summary();
    }

    /** Prints both counters and their ratio to standard out. */
    public void summary() {
        double accuracy = (double) correctCnt / (double) cnt;
        System.out.println("All count is: " + cnt
                + ", the correct count is: " + correctCnt
                + ", the accuracy is: " + accuracy);
    }

}
wss.loadWordSegSet("D:\\6.workspace\\p.NLP\\dev.conll"); 23 | 24 | Window4WordSegStream qsi = new Window4WordSegStream(wss); 25 | 26 | NegtiveSampling negtiveSampling = new NegtiveSampling(); 27 | negtiveSampling.w2v(qsi); 28 | 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- /src/deepDriver/dl/aml/w2v/test/VerifyW2v.java: -------------------------------------------------------------------------------- 1 | package deepDriver.dl.aml.w2v.test; 2 | 3 | import java.util.List; 4 | 5 | import deepDriver.dl.aml.distribution.Fs; 6 | import deepDriver.dl.aml.w2v.KeyCntPair; 7 | import deepDriver.dl.aml.w2v.W2V; 8 | 9 | public class VerifyW2v { 10 | 11 | public static void main(String[] args) throws Exception { 12 | W2V w2v = (W2V) Fs.readObjFromFile("D:\\p.output\\build\\data\\" + 13 | "w2v_1515427606816_0.m"); 14 | String k = "狗血"; 15 | List list = w2v.getSimilarity(k, 100); 16 | System.out.println("Similarity with "+k+" is: "); 17 | for (int i = 0; i < 10; i++) { 18 | KeyCntPair kcp = list.get(i); 19 | System.out.println(kcp.getKey()+","+kcp.getValue()); 20 | } 21 | } 22 | 23 | } 24 | --------------------------------------------------------------------------------