├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── data ├── debug │ └── debugWeights │ │ ├── matlab-weights-1.csv │ │ ├── matlab-weights-2.csv │ │ ├── matlab-weights-3.csv │ │ ├── matlab-weights-4.csv │ │ └── matlab-weights-5.csv ├── generated │ └── .gitignore ├── ocr.mat ├── regression-tests │ ├── bcfw-base-cache.csv │ ├── bcfw-base-wavg.csv │ ├── bcfw-base.csv │ ├── dissolve-base-cache.csv │ ├── dissolve-base-wavg.csv │ ├── dissolve-base.csv │ ├── dissolve-frac-parts-wavg-cache.csv │ ├── dissolve-frac-parts-wavg.csv │ ├── dissolve-frac-parts.csv │ └── dissolve-frac.csv └── retrieve_datasets.sh ├── debug └── .gitignore ├── dissolve-struct-application ├── build.sbt ├── project │ └── plugins.sbt └── src │ └── main │ └── scala │ └── ch │ └── ethz │ └── dalab │ └── dissolve │ └── app │ └── DSApp.scala ├── dissolve-struct-examples ├── build.sbt ├── conf │ └── log4j.properties ├── lib │ └── .gitignore ├── project │ └── plugins.sbt └── src │ ├── main │ ├── resources │ │ ├── adj.txt │ │ ├── chain_test.csv │ │ ├── chain_train.csv │ │ ├── imageseg_cattle_test.txt │ │ ├── imageseg_cattle_train.txt │ │ ├── imageseg_colormap.txt │ │ ├── imageseg_lab_freq.txt │ │ ├── imageseg_label_color_map.txt │ │ ├── imageseg_test.txt │ │ ├── imageseg_train.txt │ │ └── noun.txt │ └── scala │ │ └── ch │ │ └── ethz │ │ └── dalab │ │ └── dissolve │ │ └── examples │ │ ├── binaryclassification │ │ ├── AdultBinary.scala │ │ ├── BinaryClassificationDemo.scala │ │ ├── COVBinary.scala │ │ └── RCV1Binary.scala │ │ ├── chain │ │ ├── ChainBPDemo.scala │ │ └── ChainDemo.scala │ │ ├── imageseg │ │ ├── ImageSeg.scala │ │ ├── ImageSegRunner.scala │ │ ├── ImageSegTypes.scala │ │ └── ImageSegUtils.scala │ │ ├── multiclass │ │ └── COVMulticlass.scala │ │ └── utils │ │ └── ExampleUtils.scala │ └── test │ └── scala │ └── ch │ └── ethz │ └── dalab │ └── dissolve │ └── diagnostics │ ├── FeatureFnSpec.scala │ ├── OracleSpec.scala │ ├── StructLossSpec.scala │ └── UnitSpec.scala ├── dissolve-struct-lib ├── .gitignore ├── build.sbt ├── lib │ └── jython.jar ├── project │ ├── assembly.sbt │ └── plugins.sbt └── src │ └── main │ └── scala │ └── ch │ └── ethz │ └── dalab │ └── dissolve │ ├── classification │ ├── BinarySVMWithDBCFW.scala │ ├── BinarySVMWithSSG.scala │ ├── ClassificationUtils.scala │ ├── MultiClassSVMWithDBCFW.scala │ ├── StructSVMModel.scala │ ├── StructSVMWithBCFW.scala │ ├── StructSVMWithDBCFW.scala │ ├── StructSVMWithMiniBatch.scala │ ├── StructSVMWithSSG.scala │ └── Types.scala │ ├── optimization │ ├── BCFWSolver.scala │ ├── DBCFWSolverTuned.scala │ ├── DissolveFunctions.scala │ ├── SSGSolver.scala │ ├── SolverOptions.scala │ └── SolverUtils.scala │ ├── regression │ └── LabeledObject.scala │ └── utils │ └── cli │ ├── CLAParser.scala │ └── Config.scala ├── helpers ├── __init__.py ├── benchmark_runner.py ├── benchmark_setup.py ├── benchmark_utils.py ├── brutus_runner.py ├── brutus_sample.cfg ├── brutus_setup.py ├── buildall.py ├── ocr_helpers.py ├── paths.py └── retrieve_datasets.py ├── paper ├── dissolve-jmlr-software.bbl ├── dissolve-jmlr-software.pdf ├── dissolve-jmlr-software.tex ├── jmlr2e.sty └── references.bib └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # use glob syntax. 
2 | syntax: glob 3 | *.ser 4 | *.class 5 | *~ 6 | *.bak 7 | #*.off 8 | *.old 9 | 10 | # eclipse conf file 11 | .settings 12 | .classpath 13 | .project 14 | .manager 15 | .scala_dependencies 16 | 17 | # idea 18 | .idea 19 | *.iml 20 | 21 | # building 22 | target 23 | build 24 | null 25 | tmp* 26 | temp* 27 | dist 28 | test-output 29 | build.log 30 | 31 | # other scm 32 | .svn 33 | .CVS 34 | .hg* 35 | 36 | # Mac stuff 37 | .DS_Store 38 | 39 | .metadata/ 40 | _site/ 41 | Gemfile.lock 42 | 43 | # switch to regexp syntax. 44 | # syntax: regexp 45 | # ^\.pc/ 46 | 47 | #SHITTY output not in target directory 48 | build.log 49 | *.pyc 50 | 51 | # checkpoint files generated by Spark 52 | checkpoint-files/* 53 | 54 | # Custom stuff 55 | *.cache 56 | *.log 57 | 58 | debug/* 59 | data/generated/* 60 | 61 | dissolve-struct-lib/lib_managed/* 62 | 63 | dissolve-struct-examples/spark-1.* 64 | 65 | # Experiment related scripts/data 66 | conf/* 67 | expt-data/* 68 | figures/* 69 | *.ipynb 70 | setup_cluster.sh 71 | run.sh 72 | 73 | # Data for regression testing 74 | data/debug/* 75 | 76 | # ec2 credentials 77 | ec2_config.json 78 | 79 | # New eclipse introduced some temp files 80 | .cache-main 81 | .cache-tests 82 | 83 | # tribhu's current directory 84 | benchmark-data/ 85 | jars/ 86 | 87 | % latex 88 | *.out 89 | *.synctex.gz 90 | *.blg 91 | *.aux 92 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.10.4 4 | 5 | # Build only the lib package, and exclude examples package from build 6 | before_install: 7 | - cd dissolve-struct-lib 8 | 9 | # whitelist 10 | branches: 11 | only: 12 | - master 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/dalab/dissolve-struct.svg?branch=tuning)](https://travis-ci.org/dalab/dissolve-struct) 2 | [![Release status](https://img.shields.io/badge/release-v0.1-orange.svg)](https://github.com/dalab/dissolve-struct/releases) 3 | 4 | dissolvestruct 5 | =========== 6 | 7 | Distributed solver library for structured output prediction, based on Spark. 8 | 9 | The library is based on the primal-dual BCFW solver, allowing approximate inference oracles, and distributes this algorithm using the recent communication efficient CoCoA scheme. 10 | The interface to the user is the same as in the widely used SVMstruct in the single machine case. 
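As a rough illustration of that interface, here is a minimal sketch mirroring the `DSApp.scala` template in `dissolve-struct-application` (the names `MyModel`, `Pattern`, `Label`, and `trainDataRDD` are placeholders for your own problem-specific types and data):

```scala
object MyModel extends DissolveFunctions[Pattern, Label] {
  // Joint feature map phi(x, y)
  def featureFn(x: Pattern, y: Label): Vector[Double] = ???
  // Structured loss Delta(yPredicted, yTruth)
  def lossFn(yPredicted: Label, yTruth: Label): Double = ???
  // Loss-augmented decoding (maximization oracle)
  def oracleFn(model: StructSVMModel[Pattern, Label], x: Pattern, y: Label): Label = ???
  // Plain decoding: the oracle without loss-augmentation
  def predictFn(model: StructSVMModel[Pattern, Label], x: Pattern): Label =
    oracleFn(model, x, null)
}

// trainDataRDD: RDD[LabeledObject[Pattern, Label]], prepared by the user
val solverOptions = new SolverOptions[Pattern, Label]()
solverOptions.lambda = 0.01
val trainer = new StructSVMWithDBCFW[Pattern, Label](trainDataRDD, MyModel, solverOptions)
val model = trainer.trainModel()
```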
11 | 12 | For more information, checkout the [project page](http://dalab.github.io/dissolve-struct/) 13 | 14 | -------------------------------------------------------------------------------- /data/generated/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | 3 | !.gitignore -------------------------------------------------------------------------------- /data/ocr.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dalab/dissolve-struct/67f37377b74c32cf05d8f43a0e3658a10864f9bf/data/ocr.mat -------------------------------------------------------------------------------- /data/regression-tests/bcfw-base-cache.csv: -------------------------------------------------------------------------------- 1 | # BCFW 2 | # numPasses=5 3 | # doWeightedAveraging=false 4 | # randSeed=42 5 | # sample=perm 6 | # lambda=0.010000 7 | # doLineSearch=true 8 | # enableManualPartitionSize=false 9 | # NUM_PART=1 10 | # enableOracleCache=true 11 | # oracleCacheSize=10 12 | # H=5 13 | # sampleFrac=0.500000 14 | # sampleWithReplacement=false 15 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-bcfw-1416414185362.csv 16 | 17 | round,time,iter,primal,dual,gap,train_error,test_error 18 | 1,0.761000,312,0.740438,0.004261,0.736177,0.289614,0.382340 19 | 2,1.756000,624,0.619814,0.005358,0.614455,0.220483,0.339086 20 | 3,2.816000,936,0.547373,0.006040,0.541333,0.186888,0.315781 21 | 4,3.741000,1248,0.502321,0.006678,0.495643,0.167291,0.294808 22 | 5,4.692000,1560,0.461032,0.007274,0.453758,0.158045,0.284424 23 | -------------------------------------------------------------------------------- /data/regression-tests/bcfw-base-wavg.csv: -------------------------------------------------------------------------------- 1 | # BCFW 2 | # numPasses=5 3 | # doWeightedAveraging=true 4 | # randSeed=42 5 | # sample=perm 6 | # lambda=0.010000 7 | # doLineSearch=true 8 | # enableManualPartitionSize=false 9 | # NUM_PART=1 10 | # enableOracleCache=false 11 | # oracleCacheSize=10 12 | # H=5 13 | # sampleFrac=0.500000 14 | # sampleWithReplacement=false 15 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-bcfw-1416413767561.csv 16 | 17 | round,time,iter,primal,dual,gap,train_error,test_error 18 | 1,0.757000,312,0.649247,0.003156,0.646091,0.232588,0.329987 19 | 2,1.908000,624,0.511866,0.005451,0.506414,0.170380,0.286416 20 | 3,2.888000,936,0.445621,0.007233,0.438388,0.143429,0.267820 21 | 4,3.850000,1248,0.402138,0.008722,0.393416,0.124673,0.258407 22 | 5,4.823000,1560,0.369806,0.010022,0.359784,0.114160,0.253438 23 | -------------------------------------------------------------------------------- /data/regression-tests/bcfw-base.csv: -------------------------------------------------------------------------------- 1 | # BCFW 2 | # numPasses=5 3 | # doWeightedAveraging=false 4 | # randSeed=42 5 | # sample=perm 6 | # lambda=0.010000 7 | # doLineSearch=true 8 | # enableManualPartitionSize=false 9 | # NUM_PART=1 10 | # enableOracleCache=false 11 | # oracleCacheSize=10 12 | # H=5 13 | # sampleFrac=0.500000 14 | # sampleWithReplacement=false 15 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-bcfw-1416401133519.csv 16 | 17 | round,time,iter,primal,dual,gap,train_error,test_error 18 | 1,0.807000,312,0.740438,0.004261,0.736177,0.289614,0.382340 19 | 2,2.497000,624,0.587597,0.007060,0.580537,0.207465,0.336020 20 | 3,3.623000,936,0.482679,0.009151,0.473528,0.151328,0.292496 21 | 
4,4.749000,1248,0.493073,0.010945,0.482128,0.153231,0.305247 22 | 5,5.732000,1560,0.411257,0.012449,0.398809,0.122657,0.288465 23 | -------------------------------------------------------------------------------- /data/regression-tests/dissolve-base-cache.csv: -------------------------------------------------------------------------------- 1 | # numPasses=5 2 | # doWeightedAveraging=false 3 | # randSeed=42 4 | # sample=frac 5 | # lambda=0.010000 6 | # doLineSearch=true 7 | # enableManualPartitionSize=true 8 | # NUM_PART=1 9 | # enableOracleCache=true 10 | # oracleCacheSize=10 11 | # H=5 12 | # sampleFrac=1.000000 13 | # sampleWithReplacement=false 14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416414162557.csv 15 | 16 | # spark.app.name=Chain-DBCFW 17 | # spark.executor.memory=NA 18 | # spark.task.cpus=NA 19 | # spark.local.dir=NA 20 | # spark.default.parallelism=NA 21 | 22 | # indexedTrainDataRDD.partitions.size=1 23 | # indexedPrimalsRDD.partitions.size=1 24 | round,time,primal,dual,gap,train_error,test_error 25 | 1,5,0.740438,0.004261,0.736177,0.289614,0.382340 26 | 2,8,0.619814,0.005358,0.614455,0.220483,0.339086 27 | 3,11,0.547373,0.006040,0.541333,0.186888,0.315781 28 | 4,13,0.502321,0.006678,0.495643,0.167291,0.294808 29 | 5,16,0.461032,0.007274,0.453758,0.158045,0.284424 30 | -------------------------------------------------------------------------------- /data/regression-tests/dissolve-base-wavg.csv: -------------------------------------------------------------------------------- 1 | # numPasses=5 2 | # doWeightedAveraging=true 3 | # randSeed=42 4 | # sample=frac 5 | # lambda=0.010000 6 | # doLineSearch=true 7 | # enableManualPartitionSize=true 8 | # NUM_PART=1 9 | # enableOracleCache=false 10 | # oracleCacheSize=10 11 | # H=5 12 | # sampleFrac=1.000000 13 | # sampleWithReplacement=false 14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416413743225.csv 15 | 16 | # spark.app.name=Chain-DBCFW 17 | # spark.executor.memory=NA 18 | # spark.task.cpus=NA 19 | # spark.local.dir=NA 20 | # spark.default.parallelism=NA 21 | 22 | # indexedTrainDataRDD.partitions.size=1 23 | # indexedPrimalsRDD.partitions.size=1 24 | round,time,primal,dual,gap,train_error,test_error 25 | 1,5,0.649247,0.003156,0.646091,0.232588,0.329987 26 | 2,9,0.511866,0.005451,0.506414,0.170380,0.286416 27 | 3,12,0.445621,0.007233,0.438388,0.143429,0.267820 28 | 4,15,0.402138,0.008722,0.393416,0.124673,0.258407 29 | 5,18,0.369806,0.010022,0.359784,0.114160,0.253438 30 | -------------------------------------------------------------------------------- /data/regression-tests/dissolve-base.csv: -------------------------------------------------------------------------------- 1 | # numPasses=5 2 | # doWeightedAveraging=false 3 | # randSeed=42 4 | # sample=frac 5 | # lambda=0.010000 6 | # doLineSearch=true 7 | # enableManualPartitionSize=true 8 | # NUM_PART=1 9 | # enableOracleCache=false 10 | # oracleCacheSize=10 11 | # H=5 12 | # sampleFrac=1.000000 13 | # sampleWithReplacement=false 14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416401107250.csv 15 | 16 | # spark.app.name=Chain-DBCFW 17 | # spark.executor.memory=NA 18 | # spark.task.cpus=NA 19 | # spark.local.dir=NA 20 | # spark.default.parallelism=NA 21 | 22 | # indexedTrainDataRDD.partitions.size=1 23 | # indexedPrimalsRDD.partitions.size=1 24 | round,time,primal,dual,gap,train_error,test_error 25 | 1,6,0.740438,0.004261,0.736177,0.289614,0.382340 26 | 2,8,0.587597,0.007060,0.580537,0.207465,0.336020 
27 | 3,11,0.482679,0.009151,0.473528,0.151328,0.292496 28 | 4,14,0.493073,0.010945,0.482128,0.153231,0.305247 29 | 5,17,0.411257,0.012449,0.398809,0.122657,0.288465 30 | -------------------------------------------------------------------------------- /data/regression-tests/dissolve-frac-parts-wavg-cache.csv: -------------------------------------------------------------------------------- 1 | # numPasses=5 2 | # doWeightedAveraging=true 3 | # randSeed=42 4 | # sample=frac 5 | # lambda=0.010000 6 | # doLineSearch=true 7 | # enableManualPartitionSize=true 8 | # NUM_PART=4 9 | # enableOracleCache=true 10 | # oracleCacheSize=10 11 | # H=5 12 | # sampleFrac=0.500000 13 | # sampleWithReplacement=false 14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416475800888.csv 15 | 16 | # spark.app.name=Chain-DBCFW 17 | # spark.executor.memory=NA 18 | # spark.task.cpus=NA 19 | # spark.local.dir=NA 20 | # spark.default.parallelism=NA 21 | 22 | # indexedTrainDataRDD.partitions.size=4 23 | # indexedPrimalsRDD.partitions.size=4 24 | round,time,primal,dual,gap,train_error,test_error 25 | 1,6,0.949089,0.000614,0.948474,0.423528,0.483519 26 | 2,9,0.904272,0.000989,0.903284,0.380705,0.450028 27 | 3,12,0.879400,0.001221,0.878179,0.360667,0.436719 28 | 4,15,0.852875,0.001393,0.851483,0.342916,0.422211 29 | 5,19,0.826390,0.001539,0.824851,0.329641,0.407550 30 | -------------------------------------------------------------------------------- /data/regression-tests/dissolve-frac-parts-wavg.csv: -------------------------------------------------------------------------------- 1 | # numPasses=5 2 | # doWeightedAveraging=true 3 | # randSeed=42 4 | # sample=frac 5 | # lambda=0.010000 6 | # doLineSearch=true 7 | # enableManualPartitionSize=true 8 | # NUM_PART=4 9 | # enableOracleCache=false 10 | # oracleCacheSize=10 11 | # H=5 12 | # sampleFrac=0.500000 13 | # sampleWithReplacement=false 14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416475749034.csv 15 | 16 | # spark.app.name=Chain-DBCFW 17 | # spark.executor.memory=NA 18 | # spark.task.cpus=NA 19 | # spark.local.dir=NA 20 | # spark.default.parallelism=NA 21 | 22 | # indexedTrainDataRDD.partitions.size=4 23 | # indexedPrimalsRDD.partitions.size=4 24 | round,time,primal,dual,gap,train_error,test_error 25 | 1,5,0.949089,0.000614,0.948474,0.423528,0.483519 26 | 2,8,0.860342,0.001137,0.859205,0.334412,0.406605 27 | 3,11,0.810682,0.001574,0.809108,0.306459,0.385472 28 | 4,13,0.774036,0.001963,0.772073,0.283006,0.367418 29 | 5,16,0.744513,0.002314,0.742199,0.264717,0.350701 30 | -------------------------------------------------------------------------------- /data/regression-tests/dissolve-frac-parts.csv: -------------------------------------------------------------------------------- 1 | # numPasses=5 2 | # doWeightedAveraging=false 3 | # randSeed=42 4 | # sample=frac 5 | # lambda=0.010000 6 | # doLineSearch=true 7 | # enableManualPartitionSize=true 8 | # NUM_PART=4 9 | # enableOracleCache=false 10 | # oracleCacheSize=10 11 | # H=5 12 | # sampleFrac=0.500000 13 | # sampleWithReplacement=false 14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416475677521.csv 15 | 16 | # spark.app.name=Chain-DBCFW 17 | # spark.executor.memory=NA 18 | # spark.task.cpus=NA 19 | # spark.local.dir=NA 20 | # spark.default.parallelism=NA 21 | 22 | # indexedTrainDataRDD.partitions.size=4 23 | # indexedPrimalsRDD.partitions.size=4 24 | round,time,primal,dual,gap,train_error,test_error 25 | 
1,6,0.942347,0.000898,0.941449,0.421662,0.473316 26 | 2,9,0.842540,0.001582,0.840958,0.337215,0.412921 27 | 3,12,0.785794,0.002142,0.783652,0.302160,0.378707 28 | 4,14,0.749114,0.002636,0.746478,0.282378,0.359638 29 | 5,17,0.712107,0.003064,0.709043,0.255069,0.345087 30 | -------------------------------------------------------------------------------- /data/regression-tests/dissolve-frac.csv: -------------------------------------------------------------------------------- 1 | # numPasses=5 2 | # doWeightedAveraging=false 3 | # randSeed=42 4 | # sample=frac 5 | # lambda=0.010000 6 | # doLineSearch=true 7 | # enableManualPartitionSize=true 8 | # NUM_PART=1 9 | # enableOracleCache=false 10 | # oracleCacheSize=10 11 | # H=5 12 | # sampleFrac=0.500000 13 | # sampleWithReplacement=false 14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416475480026.csv 15 | 16 | # spark.app.name=Chain-DBCFW 17 | # spark.executor.memory=NA 18 | # spark.task.cpus=NA 19 | # spark.local.dir=NA 20 | # spark.default.parallelism=NA 21 | 22 | # indexedTrainDataRDD.partitions.size=1 23 | # indexedPrimalsRDD.partitions.size=1 24 | round,time,primal,dual,gap,train_error,test_error 25 | 1,4,0.832474,0.002271,0.830203,0.347001,0.416671 26 | 2,7,0.795129,0.003703,0.791426,0.316188,0.402175 27 | 3,9,0.665157,0.004785,0.660371,0.258996,0.334717 28 | 4,11,0.610020,0.005626,0.604394,0.220944,0.333295 29 | 5,13,0.565354,0.006300,0.559054,0.210634,0.317841 30 | -------------------------------------------------------------------------------- /data/retrieve_datasets.sh: -------------------------------------------------------------------------------- 1 | GEN_DIR="generated" 2 | 3 | if [ ! -d "$GEN_DIR" ]; then 4 | mkdir $GEN_DIR 5 | fi 6 | 7 | cd $GEN_DIR 8 | 9 | # Adult 10 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a1a 11 | 12 | # Forest Cover (Binary) 13 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/covtype.libsvm.binary.scale.bz2 14 | bzip2 -d covtype.libsvm.binary.scale.bz2 15 | 16 | # Forest Cover (Multiclass) 17 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/covtype.scale.bz2 18 | bzip2 -d covtype.scale.bz2 19 | 20 | # RCV1 21 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/rcv1_train.binary.bz2 22 | bzip2 -d rcv1_train.binary.bz2 23 | 24 | # Factorie jar 25 | wget https://github.com/factorie/factorie/releases/download/factorie-1.0/factorie-1.0.jar 26 | mv factorie-1.0.jar ../../dissolve-struct-examples/lib/ 27 | 28 | cd .. 
29 | -------------------------------------------------------------------------------- /debug/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | 3 | !.gitignore -------------------------------------------------------------------------------- /dissolve-struct-application/build.sbt: -------------------------------------------------------------------------------- 1 | // ---- < 1 > ----------------------------------------------------------------- 2 | // Enter your application name, organization and version below 3 | // This information will be used when the binary jar is packaged 4 | name := "DissolveStructApplication" 5 | 6 | organization := "ch.ethz.dalab" 7 | 8 | version := "0.1-SNAPSHOT" // Keep this unchanged for development releases 9 | // ---- ---------------------------------------------------------------- 10 | 11 | scalaVersion := "2.10.4" 12 | 13 | libraryDependencies += "ch.ethz.dalab" %% "dissolvestruct" % "0.1-SNAPSHOT" 14 | 15 | libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.0" % "test" 16 | 17 | libraryDependencies += "org.apache.spark" %% "spark-core" % "1.4.1" 18 | 19 | libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.4.1" 20 | 21 | resolvers += "IESL Release" at "http://dev-iesl.cs.umass.edu/nexus/content/groups/public" 22 | 23 | libraryDependencies += "cc.factorie" % "factorie" % "1.0" 24 | 25 | libraryDependencies += "com.github.scopt" %% "scopt" % "3.3.0" 26 | 27 | // ---- < 2 > ----------------------------------------------------------------- 28 | // Add additional dependencies in the space provided below, like above. 29 | // Libraries often provide the exact line that needs to be added here on their 30 | // webpage. 31 | // PS: Keep your eyes peeled -- there is a difference between "%%" and "%" 32 | 33 | // libraryDependencies += "organization" %% "application_name" % "version" 34 | 35 | // ---- ---------------------------------------------------------------- 36 | 37 | resolvers += Resolver.sonatypeRepo("public") 38 | 39 | EclipseKeys.createSrc := EclipseCreateSrc.Default + EclipseCreateSrc.Resource 40 | 41 | mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) => 42 | { 43 | case PathList("javax", "servlet", xs @ _*) => MergeStrategy.first 44 | case PathList(ps @ _*) if ps.last endsWith ".html" => MergeStrategy.first 45 | case "application.conf" => MergeStrategy.concat 46 | case "reference.conf" => MergeStrategy.concat 47 | case "log4j.properties" => MergeStrategy.discard 48 | case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard 49 | case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard 50 | case _ => MergeStrategy.first 51 | } 52 | } 53 | 54 | test in assembly := {} 55 | -------------------------------------------------------------------------------- /dissolve-struct-application/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | logLevel := Level.Warn 2 | 3 | addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.5.0") 4 | 5 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0") 6 | -------------------------------------------------------------------------------- /dissolve-struct-application/src/main/scala/ch/ethz/dalab/dissolve/app/DSApp.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.app 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.SparkContext 5 | import 
org.apache.spark.rdd.RDD 6 | 7 | import breeze.linalg.Vector 8 | import ch.ethz.dalab.dissolve.classification.StructSVMModel 9 | import ch.ethz.dalab.dissolve.classification.StructSVMWithDBCFW 10 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 11 | import ch.ethz.dalab.dissolve.optimization.GapThresholdCriterion 12 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 13 | import ch.ethz.dalab.dissolve.regression.LabeledObject 14 | 15 | /** 16 | * This defines the x-part of the training example 17 | * For example, in the case of sequence OCR this would be a (d x n) matrix, with 18 | * each column containing the pixel representation of a character 19 | */ 20 | case class Pattern() { 21 | 22 | } 23 | 24 | /** 25 | * This defines the y-part of the training example 26 | * Once again, in the case of OCR this would be an n-dimensional vector, with the 27 | * i-th element containing the label for the i-th character of x. 28 | */ 29 | case class Label() { 30 | 31 | } 32 | 33 | /** 34 | * This is the core of your Structured SVM application. 35 | * Here, you'll find three functions and a driver program that you'll need 36 | * to fill in to get your application running. 37 | * 38 | * The interface is inspired by SVM^struct by Joachims et al. 39 | * (http://www.cs.cornell.edu/people/tj/svm_light/svm_struct.html) 40 | */ 41 | object DSApp extends DissolveFunctions[Pattern, Label] { 42 | 43 | /** 44 | * ============== Joint Feature Map: \phi(x, y) ============== 45 | * 46 | * This encodes the complex input-output (x, y) pair in a Vector space (done 47 | * using vectors from the Breeze library) 48 | */ 49 | def featureFn(x: Pattern, y: Label): Vector[Double] = { 50 | 51 | // Insert code for the joint feature map here 52 | 53 | ??? 54 | } 55 | 56 | /** 57 | * ============== Structured Loss Function: \Delta(y, y^m) ============== 58 | * 59 | * Loss for predicting yPredicted instead of the ground truth yTruth. 60 | * This needs to be 0 if yPredicted == yTruth 61 | */ 62 | def lossFn(yPredicted: Label, yTruth: Label): Double = { 63 | 64 | // Insert code for the loss function here 65 | 66 | ??? 67 | } 68 | 69 | /** 70 | * ============== Maximization Oracle: H^m(w) ============== 71 | * 72 | * Finds the most violating constraint by solving the loss-augmented decoding 73 | * subproblem. 74 | * This is equivalent to predicting 75 | * y* = argmax_{y} \Delta(y, y^m) + < w, \phi(x^m, y) > 76 | * for some training example (x^m, y^m) and parameters w. 77 | * 78 | * Make sure the loss-augmentation is consistent with the \Delta defined above. 79 | * 80 | * By default, the prediction function calls this oracle with y^m = null. 81 | * In that case, the loss-augmentation can be skipped using a simple check 82 | * on y^m. 83 | * 84 | * For examples of common oracle/decoding functions (like loopy BP, Viterbi, 85 | * or BP on a chain CRF), refer to the examples package. 86 | */ 87 | def oracleFn(model: StructSVMModel[Pattern, Label], x: Pattern, y: Label): Label = { 88 | 89 | val weightVec = model.weights 90 | 91 | // Insert code for the maximization oracle here 92 | 93 | ??? 94 | } 95 | 96 | /** 97 | * ============== Prediction Function ============== 98 | * 99 | * Finds the best output candidate for x, given parameters w. 100 | * This is equivalent to solving: 101 | * y* = argmax_{y} < w, \phi(x^m, y) > 102 | * 103 | * Note that this is very similar to the maximization oracle, but without 104 | * the loss-augmentation. So, by default, we call the oracle function by 105 | * setting y to null.
106 | */ 107 | def predictFn(model: StructSVMModel[Pattern, Label], x: Pattern): Label = 108 | oracleFn(model, x, null) 109 | 110 | /** 111 | * ============== Driver ============== 112 | * 113 | * This is the entry point into the program. 114 | * Here, we initialize the SparkContext, set the parameters, and call the 115 | * optimization routine. 116 | * 117 | * To begin training, we'll need three things: 118 | * a. A SparkContext instance (Defaults provided) 119 | * b. Solver Parameters (Defaults provided) 120 | * c. Data 121 | * 122 | * To execute, you should package this into a jar and provide it using 123 | * spark-submit (http://spark.apache.org/docs/latest/submitting-applications.html). 124 | * 125 | * Alternatively, you can right-click and select Run As -> Scala Application to run 126 | * within Eclipse. 127 | */ 128 | def main(args: Array[String]): Unit = { 129 | 130 | val appname = "DSApp" 131 | 132 | /** 133 | * ============== Initialize Spark ============== 134 | * 135 | * Alternatively, use: 136 | * val conf = new SparkConf().setAppName(appname).setMaster("local[4]") 137 | * if you're planning to execute within Eclipse using 4 cores 138 | */ 139 | val conf = new SparkConf().setAppName(appname) 140 | val sc = new SparkContext(conf) 141 | sc.setCheckpointDir("checkpoint-files") 142 | 143 | /** 144 | * ============== Set Solver Parameters ============== 145 | */ 146 | val solverOptions = new SolverOptions[Pattern, Label]() 147 | // Regularization parameter 148 | solverOptions.lambda = 0.01 149 | 150 | // Stopping criterion 151 | solverOptions.stoppingCriterion = GapThresholdCriterion 152 | solverOptions.gapThreshold = 1e-3 153 | solverOptions.gapCheck = 25 // Checks the gap every gapCheck rounds 154 | 155 | // Set the fraction of data to be used in training during each round 156 | // In this case, 50% of the data is uniformly sampled for training at the 157 | // beginning of each round 158 | solverOptions.sampleFrac = 0.5 159 | 160 | // Set how many partitions you want to split the data into. 161 | // These partitions will be local to each machine and the respective dual 162 | // variables associated with these partitions will reside locally. 163 | // Ideally, you want to set this to: #cores x #workers x 2. 164 | // If this is disabled, Spark decides on the partitioning, which may 165 | // be suboptimal. 166 | solverOptions.enableManualPartitionSize = true 167 | solverOptions.NUM_PART = 8 168 | 169 | // Optionally, you can enable obtaining additional statistics like the 170 | // training and test errors w.r.t. rounds, along with the gap. 171 | // This is expensive as it involves a complete pass through the data. 172 | solverOptions.debug = false 173 | // This computes the statistics every debugMultiplier^i rounds. 174 | // So, in this case, it does so at rounds 1, 2, 4, 8, ... 175 | // Beyond the 50th round, statistics are collected every 10 rounds. 176 | solverOptions.debugMultiplier = 2 177 | // Writes the statistics in CSV format to the provided path 178 | solverOptions.debugInfoPath = "path/to/statistics.csv" 179 | 180 | /** 181 | * ============== Provide Data ============== 182 | */ 183 | val trainDataRDD: RDD[LabeledObject[Pattern, Label]] = { 184 | 185 | // Insert code to load TRAIN data here 186 | 187 | ??? 188 | } 189 | val testDataRDD: RDD[LabeledObject[Pattern, Label]] = { 190 | 191 | // Insert code to load TEST data here 192 | 193 | ???
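// (Hedged sketch by the editor, not part of the original template.) One common
// way to fill this in is to parse each test example into a LabeledObject[Pattern, Label]
// and hand the resulting collection to Spark, e.g.
//   val testExamples: Seq[LabeledObject[Pattern, Label]] =
//     scala.io.Source.fromFile("path/to/test.data").getLines().map(parseExample).toSeq
//   sc.parallelize(testExamples, solverOptions.NUM_PART)
// Here `parseExample` is a hypothetical user-supplied parser that builds one
// LabeledObject per line; see the examples package (e.g. ChainDemo) for concrete loaders.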
194 | } 195 | // Optionally, set to None in case you don't want statistics on test data 196 | solverOptions.testDataRDD = Some(testDataRDD) 197 | 198 | /** 199 | * ============== Training ============== 200 | */ 201 | val trainer: StructSVMWithDBCFW[Pattern, Label] = 202 | new StructSVMWithDBCFW[Pattern, Label]( 203 | trainDataRDD, 204 | DSApp, 205 | solverOptions) 206 | 207 | val model: StructSVMModel[Pattern, Label] = trainer.trainModel() 208 | 209 | /** 210 | * ============== Store Model ============== 211 | * 212 | * Optionally, you can store the model's weight parameters. 213 | * 214 | * To load a model, you can use 215 | * val weights = breeze.linalg.csvread(new java.io.File(weightOutPath)) 216 | * val model = new StructSVMModel[Pattern, Label](weights, 0.0, null, DSApp) 217 | */ 218 | val weightOutPath = "path/to/weights.csv" 219 | val weights = model.weights.toDenseVector.toDenseMatrix 220 | breeze.linalg.csvwrite(new java.io.File(weightOutPath), weights) 221 | 222 | } 223 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/build.sbt: -------------------------------------------------------------------------------- 1 | name := "DissolveStructExample" 2 | 3 | organization := "ch.ethz.dalab" 4 | 5 | version := "0.1-SNAPSHOT" 6 | 7 | scalaVersion := "2.10.4" 8 | 9 | libraryDependencies += "ch.ethz.dalab" %% "dissolvestruct" % "0.1-SNAPSHOT" 10 | 11 | libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.2.4" % "test" 12 | 13 | libraryDependencies += "org.apache.spark" %% "spark-core" % "1.4.1" 14 | 15 | libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.4.1" 16 | 17 | resolvers += "IESL Release" at "http://dev-iesl.cs.umass.edu/nexus/content/groups/public" 18 | 19 | libraryDependencies += "cc.factorie" % "factorie" % "1.0" 20 | 21 | libraryDependencies += "com.github.scopt" %% "scopt" % "3.3.0" 22 | 23 | resolvers += Resolver.sonatypeRepo("public") 24 | 25 | EclipseKeys.createSrc := EclipseCreateSrc.Default + EclipseCreateSrc.Resource 26 | 27 | mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) => 28 | { 29 | case PathList("javax", "servlet", xs @ _*) => MergeStrategy.first 30 | case PathList(ps @ _*) if ps.last endsWith ".html" => MergeStrategy.first 31 | case "application.conf" => MergeStrategy.concat 32 | case "reference.conf" => MergeStrategy.concat 33 | case "log4j.properties" => MergeStrategy.discard 34 | case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard 35 | case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard 36 | case _ => MergeStrategy.first 37 | } 38 | } 39 | 40 | test in assembly := {} 41 | -------------------------------------------------------------------------------- /dissolve-struct-examples/conf/log4j.properties: -------------------------------------------------------------------------------- 1 | # Set everything to be logged to the console 2 | log4j.rootCategory=INFO, console, file 3 | 4 | log4j.appender.console=org.apache.log4j.ConsoleAppender 5 | log4j.appender.console.target=System.err 6 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 7 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 8 | log4j.appender.console.Threshold=WARN 9 | 10 | # Direct log messages to a log file 11 | log4j.appender.file=org.apache.log4j.RollingFileAppender 12 | log4j.appender.file.File=logging.log 13 | log4j.appender.file.MaxFileSize=10MB 14 | log4j.appender.file.MaxBackupIndex=10 15 | 
log4j.appender.file.layout=org.apache.log4j.PatternLayout 16 | log4j.appender.file.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n 17 | log4j.appender.file.Threshold=WARN -------------------------------------------------------------------------------- /dissolve-struct-examples/lib/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | 3 | !.gitignore -------------------------------------------------------------------------------- /dissolve-struct-examples/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | logLevel := Level.Warn 2 | 3 | addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.5.0") 4 | 5 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0") 6 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/resources/adj.txt: -------------------------------------------------------------------------------- 1 | able 2 | abnormal 3 | absent 4 | absolute 5 | abstract 6 | abundant 7 | academic 8 | acceptable 9 | accessible 10 | accurate 11 | active 12 | acute 13 | addicted 14 | adequate 15 | aesthetic 16 | afraid 17 | aggressive 18 | agile 19 | agricultural 20 | alert 21 | alive 22 | aloof 23 | amber 24 | ambiguous 25 | ambitious 26 | ample 27 | angry 28 | annual 29 | anonymous 30 | applied 31 | appropriate 32 | arbitrary 33 | archaeological 34 | arrogant 35 | artificial 36 | artistic 37 | ashamed 38 | asleep 39 | assertive 40 | astonishing 41 | attractive 42 | automatic 43 | available 44 | awake 45 | aware 46 | awful 47 | awkward 48 | bad 49 | balanced 50 | bald 51 | bare 52 | basic 53 | beautiful 54 | bitter 55 | black 56 | bland 57 | blank 58 | blind 59 | blonde 60 | bloody 61 | bold 62 | brave 63 | broken 64 | brown 65 | bureaucratic 66 | busy 67 | capable 68 | careful 69 | cautious 70 | central 71 | certain 72 | characteristic 73 | charismatic 74 | cheap 75 | cheerful 76 | childish 77 | chronic 78 | civic 79 | civilian 80 | classical 81 | clean 82 | clear 83 | close 84 | closed 85 | cold 86 | color-blind 87 | colourful 88 | comfortable 89 | commercial 90 | common 91 | comparable 92 | compatible 93 | competent 94 | competitive 95 | complete 96 | complex 97 | comprehensive 98 | concrete 99 | confident 100 | conscious 101 | conservative 102 | considerable 103 | consistent 104 | constant 105 | constitutional 106 | constructive 107 | content 108 | continental 109 | continuous 110 | controversial 111 | convenient 112 | conventional 113 | cool 114 | cooperative 115 | corporate 116 | critical 117 | crude 118 | cruel 119 | cultural 120 | curious 121 | current 122 | cute 123 | daily 124 | dangerous 125 | dark 126 | dead 127 | deadly 128 | deaf 129 | decisive 130 | decorative 131 | deep 132 | definite 133 | delicate 134 | democratic 135 | dependent 136 | desirable 137 | different 138 | difficult 139 | digital 140 | diplomatic 141 | direct 142 | dirty 143 | discreet 144 | distant 145 | distinct 146 | domestic 147 | dominant 148 | dramatic 149 | dry 150 | due 151 | dull 152 | dynamic 153 | eager 154 | early 155 | easy 156 | economic 157 | educational 158 | effective 159 | efficient 160 | electronic 161 | elegant 162 | eligible 163 | eloquent 164 | emotional 165 | empirical 166 | empty 167 | encouraging 168 | enjoyable 169 | enthusiastic 170 | environmental 171 | equal 172 | essential 173 | established 174 | eternal 175 | ethical 176 | ethnic 177 | even 178 | exact 179 | excited 180 | exciting 181 | 
exclusive 182 | exotic 183 | expected 184 | expensive 185 | experienced 186 | experimental 187 | explicit 188 | express 189 | external 190 | extinct 191 | extraordinary 192 | fair 193 | faithful 194 | false 195 | familiar 196 | far 197 | fashionable 198 | fast 199 | fastidious 200 | fat 201 | favorable 202 | federal 203 | feminine 204 | financial 205 | fine 206 | finished 207 | first 208 | firsthand 209 | flat 210 | flawed 211 | flexible 212 | foolish 213 | formal 214 | forward 215 | fragrant 216 | frank 217 | free 218 | frequent 219 | fresh 220 | friendly 221 | frozen 222 | full 223 | full-time 224 | functional 225 | funny 226 | general 227 | generous 228 | genetic 229 | genuine 230 | geological 231 | glad 232 | glorious 233 | good 234 | gradual 235 | grand 236 | graphic 237 | grateful 238 | great 239 | green 240 | gregarious 241 | handy 242 | happy 243 | hard 244 | harmful 245 | harsh 246 | healthy 247 | heavy 248 | helpful 249 | helpless 250 | high 251 | hilarious 252 | historical 253 | holy 254 | homosexual 255 | honest 256 | honorable 257 | horizontal 258 | hostile 259 | hot 260 | huge 261 | human 262 | hungry 263 | ignorant 264 | illegal 265 | immune 266 | imperial 267 | implicit 268 | important 269 | impossible 270 | impressive 271 | inadequate 272 | inappropriate 273 | incapable 274 | incongruous 275 | incredible 276 | independent 277 | indigenous 278 | indirect 279 | indoor 280 | industrial 281 | inevitable 282 | infinite 283 | influential 284 | informal 285 | inner 286 | innocent 287 | insufficient 288 | integrated 289 | intellectual 290 | intense 291 | interactive 292 | interesting 293 | intermediate 294 | internal 295 | international 296 | invisible 297 | irrelevant 298 | jealous 299 | joint 300 | judicial 301 | junior 302 | just 303 | kind 304 | large 305 | last 306 | late 307 | latest 308 | lazy 309 | left 310 | legal 311 | legislative 312 | liberal 313 | light 314 | likely 315 | limited 316 | linear 317 | liquid 318 | literary 319 | live 320 | lively 321 | logical 322 | lonely 323 | long 324 | loose 325 | lost 326 | loud 327 | low 328 | loyal 329 | lucky 330 | magnetic 331 | main 332 | major 333 | manual 334 | marine 335 | married 336 | mathematical 337 | mature 338 | maximum 339 | meaningful 340 | mechanical 341 | medieval 342 | memorable 343 | mental 344 | middle-class 345 | mild 346 | military 347 | minimum 348 | minor 349 | miserable 350 | mobile 351 | modern 352 | modest 353 | molecular 354 | monstrous 355 | monthly 356 | moral 357 | moving 358 | multiple 359 | municipal 360 | musical 361 | mutual 362 | narrow 363 | national 364 | native 365 | necessary 366 | negative 367 | nervous 368 | neutral 369 | new 370 | nice 371 | noble 372 | noisy 373 | normal 374 | notorious 375 | nuclear 376 | obese 377 | objective 378 | obscure 379 | obvious 380 | occupational 381 | odd 382 | offensive 383 | official 384 | old 385 | open 386 | operational 387 | opposed 388 | optimistic 389 | optional 390 | oral 391 | ordinary 392 | organic 393 | original 394 | orthodox 395 | other 396 | outer 397 | outside 398 | painful 399 | parallel 400 | paralyzed 401 | parental 402 | particular 403 | part-time 404 | passionate 405 | passive 406 | past 407 | patient 408 | peaceful 409 | perfect 410 | permanent 411 | persistent 412 | personal 413 | petty 414 | philosophical 415 | physical 416 | plain 417 | pleasant 418 | polite 419 | political 420 | poor 421 | popular 422 | portable 423 | positive 424 | possible 425 | powerful 426 | practical 427 | precise 428 | predictable 429 | pregnant 430 | premature 
431 | present 432 | presidential 433 | primary 434 | private 435 | privileged 436 | productive 437 | professional 438 | profound 439 | progressive 440 | prolonged 441 | proper 442 | proportional 443 | proud 444 | provincial 445 | public 446 | pure 447 | qualified 448 | quantitative 449 | quiet 450 | racial 451 | random 452 | rare 453 | rational 454 | raw 455 | ready 456 | real 457 | realistic 458 | reasonable 459 | reckless 460 | regional 461 | regular 462 | related 463 | relative 464 | relevant 465 | reliable 466 | religious 467 | representative 468 | resident 469 | residential 470 | respectable 471 | responsible 472 | restless 473 | restricted 474 | retired 475 | revolutionary 476 | rich 477 | right 478 | romantic 479 | rotten 480 | rough 481 | round 482 | rural 483 | sacred 484 | sad 485 | safe 486 | satisfactory 487 | satisfied 488 | scientific 489 | seasonal 490 | secondary 491 | secular 492 | secure 493 | senior 494 | sensitive 495 | separate 496 | serious 497 | sexual 498 | shallow 499 | sharp 500 | short 501 | shy 502 | sick 503 | similar 504 | single 505 | skilled 506 | slippery 507 | slow 508 | small 509 | smart 510 | smooth 511 | social 512 | socialist 513 | soft 514 | solar 515 | solid 516 | sophisticated 517 | sound 518 | sour 519 | spatial 520 | specified 521 | spontaneous 522 | square 523 | stable 524 | standard 525 | statistical 526 | steady 527 | steep 528 | sticky 529 | still 530 | straight 531 | strange 532 | strategic 533 | strict 534 | strong 535 | structural 536 | stubborn 537 | stunning 538 | stupid 539 | subjective 540 | subsequent 541 | successful 542 | sudden 543 | sufficient 544 | superior 545 | supplementary 546 | surprised 547 | surprising 548 | sweet 549 | sympathetic 550 | systematic 551 | talented 552 | talkative 553 | tall 554 | tasty 555 | technical 556 | temporary 557 | tender 558 | tense 559 | terminal 560 | thick 561 | thin 562 | thirsty 563 | thoughtful 564 | tidy 565 | tight 566 | tired 567 | tolerant 568 | tough 569 | toxic 570 | traditional 571 | transparent 572 | trivial 573 | tropical 574 | true 575 | typical 576 | ugly 577 | ultimate 578 | unanimous 579 | unaware 580 | uncomfortable 581 | uneasy 582 | unemployed 583 | unexpected 584 | unfair 585 | unfortunate 586 | uniform 587 | unique 588 | universal 589 | unlawful 590 | unlike 591 | unlikely 592 | unpleasant 593 | urban 594 | useful 595 | useless 596 | usual 597 | vacant 598 | vague 599 | vain 600 | valid 601 | valuable 602 | varied 603 | verbal 604 | vertical 605 | viable 606 | vicious 607 | vigorous 608 | violent 609 | visible 610 | visual 611 | vocational 612 | voluntary 613 | vulnerable 614 | warm 615 | weak 616 | weekly 617 | welcome 618 | well 619 | wet 620 | white 621 | whole 622 | wild 623 | wise 624 | written 625 | wrong 626 | young 627 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/resources/chain_test.csv: -------------------------------------------------------------------------------- 1 | 
214,150,602,408,267,440,253,191,67,338,111,364,118,114,10,611,415,342,8,82,411,141,117,102,281,386,544,204,409,42,622,27,594,47,153,272,596,325,606,53,570,344,162,203,341,13,477,502,474,107,146,569,113,311,275,260,14,470,308,432,236,583,232,259,464,418,32,567,453,462,181,592,459,609,626,244,575,241,531,194,263,103,530,410,94,334,268,18,549,139,65,37,235,105,525,358,615,62,161,589,333,121,86,254,424,206,288,230,74,4,492,603,258,573,387,414,202,447,448,185,557,207,97,585,31,55,506,100,476,367,584,234,158,439,467,540,261,483,518,610,271,189,226,331,50,522,197,147,362,149,283,222,345,72,382,12,78,108,501,310,302,402,122,473,23,552,558,237,217,363,163,328,46,346,51,61,239,520,126,19,266,306,240,297,179,106,223,95,457,625,20,400,265,535,456,417,543,209,136,145,24,144,485,542,252,129,227,6,322,93,152,335,243,554,425,290,96,182,154,351,279,172,286,138,354,379,478,127,555,532,164,56,577,284,3,348,388,228,39,463,416,200,446,454,605,143,578,495,488,368,216,445,356,365,250,620,422,572,427,372,187,16,498,273,77,564,623,312,79,291,330,539,370,595,49,270,287,134,303,229,429,514,505,378,87,212,238,534,196,249,529,443,195,507,504,561,30,9,517,607,225,170,7,112,437,513,304,546,99,489,233,277,289,198,70,374,384,460,282,52,71,523,475,391,83,503,190,455,579,321,465,208,135,593,452,43,140,123,614,496,451,441,300,116,280,40,213,180,436,493,177,101,508,524,376,500,481,586,15,613,420,617,242,90,480,22,66,419,218,115,307,34,317,479,245,199,91,574,377,119,98,73,357,430,339,324,110,619,45,571,527,278,148,537,399,315,92,175,512,515,178,431,151,44,393,309,392,255,582,547,167,587,215,526,406,381,327,482,128,450,433,60,556,423,89,76,41,326,292,171,336,600,188,299,25,320,486,497,545,131,156,26,251,624,169,28,359,124,295,301,519,159,125,490,5,219,201,510,298,405,142,347,565,494,63,469,458,360,566,487,269,337,395,349,516,421,444,548,396,428,68,413,247,581,580,468,88,536,183,604,394,366,407,257,314,355,521,616,438,401,155,601,296,21,11,350,404,211,369,559,80,57,84,176,560,2,205,48,157,133,383,426,294,551,33,316,186,120,568,285,332,130,35,192,305,81,385,193,319,329,590,598,618,221,373,313,621,434,412,472,1,293,165,397,361,389,36,390,264,442,375,343,173,323,85,541,109,248,435,449,380,256,553,599,276,491,59,64,563,484,184,168,246,403,38,398,274,220,528,75,340,262,132,318,104,591,461,588,160,576,471,466,29,612,17,511,597,210,166,353,54,608,58,538,509,69,499,174,231,352,224,533,371,550,137,562 2 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/resources/imageseg_cattle_test.txt: -------------------------------------------------------------------------------- 1 | 1_1_s.bmp 2 | 1_3_s.bmp 3 | 1_13_s.bmp 4 | 1_30_s.bmp 5 | 1_26_s.bmp 6 | 1_14_s.bmp 7 | 1_16_s.bmp 8 | 1_12_s.bmp 9 | 1_18_s.bmp 10 | 1_9_s.bmp 11 | 1_27_s.bmp 12 | 1_23_s.bmp 13 | 1_24_s.bmp 14 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/resources/imageseg_cattle_train.txt: -------------------------------------------------------------------------------- 1 | 1_6_s.bmp 2 | 1_10_s.bmp 3 | 1_4_s.bmp 4 | 1_22_s.bmp 5 | 1_11_s.bmp 6 | 1_21_s.bmp 7 | 1_19_s.bmp 8 | 1_8_s.bmp 9 | 1_7_s.bmp 10 | 1_15_s.bmp 11 | 1_5_s.bmp 12 | 1_20_s.bmp 13 | 1_29_s.bmp 14 | 1_17_s.bmp 15 | 1_2_s.bmp 16 | 1_25_s.bmp 17 | 1_28_s.bmp 18 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/resources/imageseg_colormap.txt: 
-------------------------------------------------------------------------------- 1 | 0 0 0 0 0 void 2 | 1 128 0 0 128 cow 3 | 2 16320 0 64 0 bird 4 | 3 32640 0 128 0 grass 5 | 4 32768 0 128 128 sheep 6 | 5 48960 0 192 0 chair 7 | 6 49088 0 192 128 cat 8 | 7 4161600 64 0 0 mountain 9 | 8 4161728 64 0 128 car 10 | 9 4177920 64 64 0 body 11 | 10 4194240 64 128 0 water 12 | 11 4194368 64 128 128 flower 13 | 12 8323200 128 0 0 building 14 | 13 8323328 128 0 128 horse 15 | 14 8339520 128 64 0 book 16 | 15 8339648 128 64 128 road 17 | 16 8355840 128 128 0 tree 18 | 17 8355968 128 128 128 sky 19 | 18 8372288 128 192 128 dog 20 | 19 12484800 192 0 0 aeroplane 21 | 20 12484928 192 0 128 bicycle 22 | 21 12501120 192 64 0 boat 23 | 22 12517440 192 128 0 face 24 | 23 12517568 192 128 128 sign 25 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/resources/imageseg_lab_freq.txt: -------------------------------------------------------------------------------- 1 | 0,0.269829 2 | 1,0.021016 3 | 2,0.009487 4 | 3,0.142275 5 | 4,0.016398 6 | 5,0.013469 7 | 6,0.011448 8 | 7,0.007311 9 | 8,0.023944 10 | 9,0.015565 11 | 10,0.055292 12 | 11,0.022215 13 | 12,0.076981 14 | 13,0.000735 15 | 14,0.038273 16 | 15,0.061869 17 | 16,0.064517 18 | 17,0.069084 19 | 18,0.012804 20 | 19,0.010377 21 | 20,0.020019 22 | 21,0.006815 23 | 22,0.012681 24 | 23,0.017599 25 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/resources/imageseg_label_color_map.txt: -------------------------------------------------------------------------------- 1 | 12 8323200 128 0 0 building 0 2 | 3 32640 0 128 0 grass 1 3 | 16 8355840 128 128 0 tree 2 4 | 1 128 0 0 128 cow 3 5 | 4 32768 0 128 128 sheep 4 6 | 17 8355968 128 128 128 sky 5 7 | 19 12484800 192 0 0 aeroplane 6 8 | 10 4194240 64 128 0 water 7 9 | 22 12517440 192 128 0 face 8 10 | 8 4161728 64 0 128 car 9 11 | 20 12484928 192 0 128 bicycle 10 12 | 11 4194368 64 128 128 flower 11 13 | 23 12517568 192 128 128 sign 12 14 | 2 16320 0 64 0 bird 13 15 | 14 8339520 128 64 0 book 14 16 | 5 48960 0 192 0 chair 15 17 | 15 8339648 128 64 128 road 16 18 | 6 49088 0 192 128 cat 17 19 | 18 8372288 128 192 128 dog 18 20 | 9 4177920 64 64 0 body 19 21 | 21 12501120 192 64 0 boat 20 22 | 0 0 0 0 0 void 21 23 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/resources/imageseg_test.txt: -------------------------------------------------------------------------------- 1 | 18_5_s.bmp 2 | 18_21_s.bmp 3 | 15_22_s.bmp 4 | 15_1_s.bmp 5 | 1_1_s.bmp 6 | 20_9_s.bmp 7 | 2_8_s.bmp 8 | 15_19_s.bmp 9 | 2_10_s.bmp 10 | 18_26_s.bmp 11 | 15_5_s.bmp 12 | 5_11_s.bmp 13 | 18_30_s.bmp 14 | 1_3_s.bmp 15 | 18_29_s.bmp 16 | 1_13_s.bmp 17 | 15_13_s.bmp 18 | 13_11_s.bmp 19 | 15_2_s.bmp 20 | 20_17_s.bmp 21 | 11_9_s.bmp 22 | 1_30_s.bmp 23 | 13_7_s.bmp 24 | 16_12_s.bmp 25 | 9_3_s.bmp 26 | 4_28_s.bmp 27 | 12_19_s.bmp 28 | 13_8_s.bmp 29 | 3_22_s.bmp 30 | 11_5_s.bmp 31 | 16_24_s.bmp 32 | 3_20_s.bmp 33 | 4_27_s.bmp 34 | 17_15_s.bmp 35 | 13_22_s.bmp 36 | 12_26_s.bmp 37 | 5_3_s.bmp 38 | 18_23_s.bmp 39 | 16_4_s.bmp 40 | 18_20_s.bmp 41 | 13_25_s.bmp 42 | 11_15_s.bmp 43 | 5_13_s.bmp 44 | 19_7_s.bmp 45 | 12_7_s.bmp 46 | 2_25_s.bmp 47 | 12_8_s.bmp 48 | 20_2_s.bmp 49 | 15_3_s.bmp 50 | 7_26_s.bmp 51 | 9_20_s.bmp 52 | 1_26_s.bmp 53 | 10_2_s.bmp 54 | 13_20_s.bmp 55 | 14_21_s.bmp 56 | 10_18_s.bmp 57 | 5_2_s.bmp 58 | 14_13_s.bmp 59 | 4_26_s.bmp 60 | 
15_7_s.bmp 61 | 17_19_s.bmp 62 | 18_24_s.bmp 63 | 9_24_s.bmp 64 | 10_1_s.bmp 65 | 1_14_s.bmp 66 | 10_22_s.bmp 67 | 10_25_s.bmp 68 | 12_33_s.bmp 69 | 3_16_s.bmp 70 | 4_16_s.bmp 71 | 9_7_s.bmp 72 | 4_1_s.bmp 73 | 12_9_s.bmp 74 | 4_13_s.bmp 75 | 19_3_s.bmp 76 | 16_16_s.bmp 77 | 17_5_s.bmp 78 | 17_30_s.bmp 79 | 20_1_s.bmp 80 | 9_1_s.bmp 81 | 20_16_s.bmp 82 | 17_8_s.bmp 83 | 8_19_s.bmp 84 | 9_9_s.bmp 85 | 8_6_s.bmp 86 | 19_12_s.bmp 87 | 3_5_s.bmp 88 | 8_18_s.bmp 89 | 18_7_s.bmp 90 | 16_5_s.bmp 91 | 10_29_s.bmp 92 | 7_29_s.bmp 93 | 20_20_s.bmp 94 | 8_22_s.bmp 95 | 2_7_s.bmp 96 | 13_9_s.bmp 97 | 16_10_s.bmp 98 | 2_14_s.bmp 99 | 9_13_s.bmp 100 | 14_6_s.bmp 101 | 5_14_s.bmp 102 | 18_12_s.bmp 103 | 13_1_s.bmp 104 | 6_16_s.bmp 105 | 7_5_s.bmp 106 | 3_9_s.bmp 107 | 10_12_s.bmp 108 | 3_12_s.bmp 109 | 3_8_s.bmp 110 | 4_24_s.bmp 111 | 2_29_s.bmp 112 | 13_4_s.bmp 113 | 11_13_s.bmp 114 | 20_8_s.bmp 115 | 10_4_s.bmp 116 | 10_31_s.bmp 117 | 3_26_s.bmp 118 | 3_7_s.bmp 119 | 19_23_s.bmp 120 | 8_27_s.bmp 121 | 6_22_s.bmp 122 | 14_15_s.bmp 123 | 7_12_s.bmp 124 | 7_21_s.bmp 125 | 2_18_s.bmp 126 | 11_4_s.bmp 127 | 7_14_s.bmp 128 | 8_29_s.bmp 129 | 4_30_s.bmp 130 | 14_17_s.bmp 131 | 8_10_s.bmp 132 | 1_16_s.bmp 133 | 2_26_s.bmp 134 | 14_9_s.bmp 135 | 18_28_s.bmp 136 | 5_23_s.bmp 137 | 14_10_s.bmp 138 | 19_30_s.bmp 139 | 12_31_s.bmp 140 | 16_14_s.bmp 141 | 12_14_s.bmp 142 | 14_18_s.bmp 143 | 11_14_s.bmp 144 | 18_3_s.bmp 145 | 16_13_s.bmp 146 | 3_30_s.bmp 147 | 6_5_s.bmp 148 | 19_24_s.bmp 149 | 17_17_s.bmp 150 | 19_15_s.bmp 151 | 11_25_s.bmp 152 | 2_22_s.bmp 153 | 2_30_s.bmp 154 | 17_28_s.bmp 155 | 6_13_s.bmp 156 | 9_2_s.bmp 157 | 11_18_s.bmp 158 | 19_16_s.bmp 159 | 8_17_s.bmp 160 | 2_28_s.bmp 161 | 16_23_s.bmp 162 | 1_12_s.bmp 163 | 4_20_s.bmp 164 | 7_15_s.bmp 165 | 11_8_s.bmp 166 | 11_6_s.bmp 167 | 3_18_s.bmp 168 | 1_18_s.bmp 169 | 1_9_s.bmp 170 | 19_29_s.bmp 171 | 12_2_s.bmp 172 | 17_26_s.bmp 173 | 16_1_s.bmp 174 | 19_8_s.bmp 175 | 13_2_s.bmp 176 | 16_6_s.bmp 177 | 8_24_s.bmp 178 | 5_28_s.bmp 179 | 15_6_s.bmp 180 | 13_17_s.bmp 181 | 17_23_s.bmp 182 | 9_19_s.bmp 183 | 19_13_s.bmp 184 | 9_18_s.bmp 185 | 14_20_s.bmp 186 | 17_4_s.bmp 187 | 6_11_s.bmp 188 | 7_25_s.bmp 189 | 2_16_s.bmp 190 | 12_16_s.bmp 191 | 12_23_s.bmp 192 | 6_24_s.bmp 193 | 12_27_s.bmp 194 | 16_8_s.bmp 195 | 10_17_s.bmp 196 | 5_17_s.bmp 197 | 12_12_s.bmp 198 | 20_14_s.bmp 199 | 12_30_s.bmp 200 | 9_17_s.bmp 201 | 17_25_s.bmp 202 | 14_16_s.bmp 203 | 4_15_s.bmp 204 | 14_28_s.bmp 205 | 7_10_s.bmp 206 | 5_19_s.bmp 207 | 10_3_s.bmp 208 | 5_27_s.bmp 209 | 8_11_s.bmp 210 | 5_4_s.bmp 211 | 5_10_s.bmp 212 | 10_28_s.bmp 213 | 6_7_s.bmp 214 | 4_23_s.bmp 215 | 11_3_s.bmp 216 | 10_7_s.bmp 217 | 13_21_s.bmp 218 | 7_7_s.bmp 219 | 16_27_s.bmp 220 | 13_27_s.bmp 221 | 6_25_s.bmp 222 | 14_14_s.bmp 223 | 17_6_s.bmp 224 | 12_13_s.bmp 225 | 18_6_s.bmp 226 | 19_28_s.bmp 227 | 14_5_s.bmp 228 | 8_2_s.bmp 229 | 3_29_s.bmp 230 | 6_12_s.bmp 231 | 4_3_s.bmp 232 | 1_27_s.bmp 233 | 19_20_s.bmp 234 | 8_4_s.bmp 235 | 7_22_s.bmp 236 | 4_10_s.bmp 237 | 6_2_s.bmp 238 | 6_15_s.bmp 239 | 8_8_s.bmp 240 | 3_13_s.bmp 241 | 9_5_s.bmp 242 | 9_26_s.bmp 243 | 5_12_s.bmp 244 | 1_23_s.bmp 245 | 15_23_s.bmp 246 | 1_24_s.bmp 247 | 20_6_s.bmp 248 | 17_12_s.bmp 249 | 6_10_s.bmp 250 | 2_17_s.bmp 251 | 11_7_s.bmp 252 | 11_19_s.bmp 253 | 7_23_s.bmp 254 | 10_16_s.bmp 255 | 7_3_s.bmp 256 | 6_3_s.bmp 257 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/resources/imageseg_train.txt: 
-------------------------------------------------------------------------------- 1 | 11_16_s.bmp 2 | 11_27_s.bmp 3 | 11_12_s.bmp 4 | 12_34_s.bmp 5 | 11_2_s.bmp 6 | 12_28_s.bmp 7 | 16_11_s.bmp 8 | 15_24_s.bmp 9 | 8_25_s.bmp 10 | 2_15_s.bmp 11 | 11_22_s.bmp 12 | 18_9_s.bmp 13 | 9_23_s.bmp 14 | 19_2_s.bmp 15 | 2_12_s.bmp 16 | 3_11_s.bmp 17 | 20_11_s.bmp 18 | 5_9_s.bmp 19 | 14_22_s.bmp 20 | 3_24_s.bmp 21 | 4_6_s.bmp 22 | 9_10_s.bmp 23 | 11_11_s.bmp 24 | 2_20_s.bmp 25 | 1_6_s.bmp 26 | 16_30_s.bmp 27 | 12_1_s.bmp 28 | 9_22_s.bmp 29 | 15_18_s.bmp 30 | 18_27_s.bmp 31 | 17_16_s.bmp 32 | 10_8_s.bmp 33 | 1_10_s.bmp 34 | 16_17_s.bmp 35 | 17_13_s.bmp 36 | 13_3_s.bmp 37 | 1_4_s.bmp 38 | 5_29_s.bmp 39 | 1_22_s.bmp 40 | 20_21_s.bmp 41 | 2_19_s.bmp 42 | 1_11_s.bmp 43 | 18_13_s.bmp 44 | 3_15_s.bmp 45 | 5_24_s.bmp 46 | 4_18_s.bmp 47 | 18_11_s.bmp 48 | 13_26_s.bmp 49 | 9_14_s.bmp 50 | 19_27_s.bmp 51 | 12_10_s.bmp 52 | 18_8_s.bmp 53 | 5_6_s.bmp 54 | 12_11_s.bmp 55 | 18_17_s.bmp 56 | 4_8_s.bmp 57 | 10_32_s.bmp 58 | 18_2_s.bmp 59 | 17_21_s.bmp 60 | 14_11_s.bmp 61 | 3_19_s.bmp 62 | 3_4_s.bmp 63 | 8_14_s.bmp 64 | 11_30_s.bmp 65 | 11_29_s.bmp 66 | 9_27_s.bmp 67 | 17_11_s.bmp 68 | 13_13_s.bmp 69 | 4_21_s.bmp 70 | 19_14_s.bmp 71 | 11_10_s.bmp 72 | 18_15_s.bmp 73 | 10_21_s.bmp 74 | 14_3_s.bmp 75 | 1_21_s.bmp 76 | 7_9_s.bmp 77 | 16_28_s.bmp 78 | 15_10_s.bmp 79 | 15_21_s.bmp 80 | 13_16_s.bmp 81 | 2_9_s.bmp 82 | 13_24_s.bmp 83 | 14_19_s.bmp 84 | 15_17_s.bmp 85 | 19_26_s.bmp 86 | 15_20_s.bmp 87 | 20_18_s.bmp 88 | 18_22_s.bmp 89 | 10_15_s.bmp 90 | 2_1_s.bmp 91 | 9_30_s.bmp 92 | 6_14_s.bmp 93 | 12_15_s.bmp 94 | 14_2_s.bmp 95 | 15_4_s.bmp 96 | 12_6_s.bmp 97 | 15_14_s.bmp 98 | 16_19_s.bmp 99 | 8_23_s.bmp 100 | 11_24_s.bmp 101 | 14_1_s.bmp 102 | 9_21_s.bmp 103 | 8_13_s.bmp 104 | 19_10_s.bmp 105 | 11_20_s.bmp 106 | 8_30_s.bmp 107 | 13_29_s.bmp 108 | 15_12_s.bmp 109 | 7_6_s.bmp 110 | 14_8_s.bmp 111 | 13_14_s.bmp 112 | 12_29_s.bmp 113 | 19_18_s.bmp 114 | 4_25_s.bmp 115 | 15_11_s.bmp 116 | 17_9_s.bmp 117 | 7_11_s.bmp 118 | 18_16_s.bmp 119 | 20_12_s.bmp 120 | 3_25_s.bmp 121 | 14_12_s.bmp 122 | 2_5_s.bmp 123 | 11_26_s.bmp 124 | 20_3_s.bmp 125 | 9_16_s.bmp 126 | 6_27_s.bmp 127 | 10_6_s.bmp 128 | 15_16_s.bmp 129 | 17_2_s.bmp 130 | 8_5_s.bmp 131 | 16_25_s.bmp 132 | 18_18_s.bmp 133 | 18_4_s.bmp 134 | 10_5_s.bmp 135 | 12_32_s.bmp 136 | 17_10_s.bmp 137 | 18_25_s.bmp 138 | 19_5_s.bmp 139 | 8_12_s.bmp 140 | 19_25_s.bmp 141 | 3_3_s.bmp 142 | 13_12_s.bmp 143 | 3_10_s.bmp 144 | 18_19_s.bmp 145 | 6_17_s.bmp 146 | 9_29_s.bmp 147 | 12_24_s.bmp 148 | 19_11_s.bmp 149 | 12_21_s.bmp 150 | 17_29_s.bmp 151 | 6_9_s.bmp 152 | 10_27_s.bmp 153 | 4_29_s.bmp 154 | 17_1_s.bmp 155 | 14_25_s.bmp 156 | 6_6_s.bmp 157 | 1_19_s.bmp 158 | 10_10_s.bmp 159 | 19_21_s.bmp 160 | 7_4_s.bmp 161 | 3_21_s.bmp 162 | 7_20_s.bmp 163 | 1_8_s.bmp 164 | 11_1_s.bmp 165 | 5_26_s.bmp 166 | 12_18_s.bmp 167 | 1_7_s.bmp 168 | 12_4_s.bmp 169 | 7_16_s.bmp 170 | 2_13_s.bmp 171 | 20_10_s.bmp 172 | 14_30_s.bmp 173 | 19_19_s.bmp 174 | 5_15_s.bmp 175 | 17_14_s.bmp 176 | 17_18_s.bmp 177 | 5_16_s.bmp 178 | 20_5_s.bmp 179 | 14_23_s.bmp 180 | 11_23_s.bmp 181 | 10_20_s.bmp 182 | 13_5_s.bmp 183 | 5_1_s.bmp 184 | 5_25_s.bmp 185 | 17_20_s.bmp 186 | 17_7_s.bmp 187 | 5_8_s.bmp 188 | 9_28_s.bmp 189 | 6_18_s.bmp 190 | 5_21_s.bmp 191 | 2_21_s.bmp 192 | 7_1_s.bmp 193 | 3_27_s.bmp 194 | 20_7_s.bmp 195 | 12_17_s.bmp 196 | 6_29_s.bmp 197 | 1_15_s.bmp 198 | 5_20_s.bmp 199 | 13_15_s.bmp 200 | 16_21_s.bmp 201 | 1_5_s.bmp 202 | 6_23_s.bmp 203 | 10_24_s.bmp 204 | 7_19_s.bmp 205 | 8_28_s.bmp 206 | 2_24_s.bmp 207 | 
4_4_s.bmp 208 | 2_4_s.bmp 209 | 16_26_s.bmp 210 | 1_20_s.bmp 211 | 5_18_s.bmp 212 | 20_13_s.bmp 213 | 5_22_s.bmp 214 | 12_3_s.bmp 215 | 13_18_s.bmp 216 | 16_7_s.bmp 217 | 7_8_s.bmp 218 | 2_6_s.bmp 219 | 2_23_s.bmp 220 | 1_29_s.bmp 221 | 6_1_s.bmp 222 | 1_17_s.bmp 223 | 20_4_s.bmp 224 | 16_22_s.bmp 225 | 10_11_s.bmp 226 | 8_26_s.bmp 227 | 7_27_s.bmp 228 | 6_8_s.bmp 229 | 12_22_s.bmp 230 | 3_17_s.bmp 231 | 2_3_s.bmp 232 | 4_5_s.bmp 233 | 4_14_s.bmp 234 | 4_12_s.bmp 235 | 16_18_s.bmp 236 | 4_7_s.bmp 237 | 9_4_s.bmp 238 | 19_17_s.bmp 239 | 3_2_s.bmp 240 | 8_21_s.bmp 241 | 7_28_s.bmp 242 | 14_27_s.bmp 243 | 9_12_s.bmp 244 | 17_22_s.bmp 245 | 13_19_s.bmp 246 | 7_18_s.bmp 247 | 14_29_s.bmp 248 | 19_6_s.bmp 249 | 4_9_s.bmp 250 | 4_2_s.bmp 251 | 10_9_s.bmp 252 | 14_26_s.bmp 253 | 10_13_s.bmp 254 | 13_28_s.bmp 255 | 8_3_s.bmp 256 | 4_17_s.bmp 257 | 16_3_s.bmp 258 | 16_15_s.bmp 259 | 3_1_s.bmp 260 | 13_10_s.bmp 261 | 16_9_s.bmp 262 | 10_30_s.bmp 263 | 19_4_s.bmp 264 | 7_17_s.bmp 265 | 6_26_s.bmp 266 | 9_11_s.bmp 267 | 9_6_s.bmp 268 | 6_4_s.bmp 269 | 8_1_s.bmp 270 | 7_2_s.bmp 271 | 8_9_s.bmp 272 | 10_14_s.bmp 273 | 6_28_s.bmp 274 | 3_14_s.bmp 275 | 6_20_s.bmp 276 | 8_16_s.bmp 277 | 18_14_s.bmp 278 | 15_8_s.bmp 279 | 18_1_s.bmp 280 | 18_10_s.bmp 281 | 15_9_s.bmp 282 | 11_21_s.bmp 283 | 15_15_s.bmp 284 | 1_2_s.bmp 285 | 5_5_s.bmp 286 | 11_28_s.bmp 287 | 1_25_s.bmp 288 | 1_28_s.bmp 289 | 5_30_s.bmp 290 | 2_11_s.bmp 291 | 5_7_s.bmp 292 | 2_27_s.bmp 293 | 20_19_s.bmp 294 | 20_15_s.bmp 295 | 17_24_s.bmp 296 | 2_2_s.bmp 297 | 12_20_s.bmp 298 | 17_27_s.bmp 299 | 13_30_s.bmp 300 | 13_6_s.bmp 301 | 4_22_s.bmp 302 | 14_7_s.bmp 303 | 12_25_s.bmp 304 | 16_2_s.bmp 305 | 9_15_s.bmp 306 | 10_23_s.bmp 307 | 11_17_s.bmp 308 | 13_23_s.bmp 309 | 17_3_s.bmp 310 | 16_20_s.bmp 311 | 7_24_s.bmp 312 | 3_23_s.bmp 313 | 3_28_s.bmp 314 | 10_19_s.bmp 315 | 16_29_s.bmp 316 | 14_24_s.bmp 317 | 9_8_s.bmp 318 | 12_5_s.bmp 319 | 9_25_s.bmp 320 | 4_19_s.bmp 321 | 19_9_s.bmp 322 | 4_11_s.bmp 323 | 19_1_s.bmp 324 | 7_30_s.bmp 325 | 3_6_s.bmp 326 | 10_26_s.bmp 327 | 19_22_s.bmp 328 | 14_4_s.bmp 329 | 7_13_s.bmp 330 | 8_20_s.bmp 331 | 8_7_s.bmp 332 | 6_19_s.bmp 333 | 8_15_s.bmp 334 | 6_30_s.bmp 335 | 6_21_s.bmp -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/binaryclassification/AdultBinary.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.examples.binaryclassification 2 | 3 | import ch.ethz.dalab.dissolve.regression.LabeledObject 4 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 5 | import ch.ethz.dalab.dissolve.classification.StructSVMWithDBCFW 6 | import ch.ethz.dalab.dissolve.classification.BinarySVMWithDBCFW 7 | import ch.ethz.dalab.dissolve.classification.StructSVMModel 8 | import ch.ethz.dalab.dissolve.optimization.SolverUtils 9 | import org.apache.spark.mllib.util.MLUtils 10 | import org.apache.spark.SparkConf 11 | import org.apache.spark.SparkContext 12 | import org.apache.spark.rdd.RDD 13 | import org.apache.spark.mllib.regression.LabeledPoint 14 | import org.apache.spark.mllib.classification.SVMWithSGD 15 | import breeze.linalg._ 16 | import breeze.numerics.abs 17 | 18 | /** 19 | * 20 | * Dataset: Adult (http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#a1a) 21 | * Type: Binary 22 | * 23 | * Created by tribhu on 12/11/14. 
24 | */ 25 | object AdultBinary { 26 | 27 | /** 28 | * DBCFW Implementation 29 | */ 30 | def dbcfwAdult() { 31 | val a1aPath = "../data/generated/a1a" 32 | 33 | // Fix seed for reproducibility 34 | util.Random.setSeed(1) 35 | 36 | val conf = new SparkConf().setAppName("Adult-example").setMaster("local") 37 | val sc = new SparkContext(conf) 38 | sc.setCheckpointDir("checkpoint-files") 39 | 40 | val solverOptions: SolverOptions[Vector[Double], Double] = new SolverOptions() 41 | 42 | solverOptions.roundLimit = 20 // After these many passes, each slice of the RDD returns a trained model 43 | solverOptions.debug = true 44 | solverOptions.lambda = 0.01 45 | solverOptions.doWeightedAveraging = false 46 | solverOptions.doLineSearch = true 47 | solverOptions.debug = false 48 | 49 | solverOptions.sampleWithReplacement = false 50 | 51 | solverOptions.enableManualPartitionSize = true 52 | solverOptions.NUM_PART = 1 53 | 54 | solverOptions.sample = "frac" 55 | solverOptions.sampleFrac = 0.5 56 | 57 | solverOptions.enableOracleCache = false 58 | 59 | solverOptions.debugInfoPath = "../debug/debugInfo-a1a-%d.csv".format(System.currentTimeMillis()) 60 | 61 | 62 | val data = MLUtils.loadLibSVMFile(sc, a1aPath) 63 | 64 | // Split data into training and test set 65 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L) 66 | val training = splits(0) 67 | val test = splits(1) 68 | 69 | val objectifiedTraining: RDD[LabeledObject[Vector[Double], Double]] = 70 | training.map { 71 | case x: LabeledPoint => 72 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray)) // Is the asInstanceOf required? 73 | } 74 | 75 | val objectifiedTest: RDD[LabeledObject[Vector[Double], Double]] = 76 | test.map { 77 | case x: LabeledPoint => 78 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray)) // Is the asInstanceOf required? 
79 | } 80 | 81 | solverOptions.testDataRDD = Some(objectifiedTest) 82 | val model = BinarySVMWithDBCFW.train(training, solverOptions) 83 | 84 | // Training Errors 85 | val trueTrainingPredictions = 86 | objectifiedTraining.map { 87 | case x: LabeledObject[Vector[Double], Double] => 88 | val prediction = model.predict(x.pattern) 89 | if (prediction == x.label) 90 | 1 91 | else 92 | 0 93 | }.fold(0)((acc, ele) => acc + ele) 94 | 95 | println("Accuracy on Training set = %d/%d = %.4f".format(trueTrainingPredictions, 96 | objectifiedTraining.count(), 97 | (trueTrainingPredictions.toDouble / objectifiedTraining.count().toDouble) * 100)) 98 | 99 | // Test Errors 100 | val trueTestPredictions = 101 | objectifiedTest.map { 102 | case x: LabeledObject[Vector[Double], Double] => 103 | val prediction = model.predict(x.pattern) 104 | if (prediction == x.label) 105 | 1 106 | else 107 | 0 108 | }.fold(0)((acc, ele) => acc + ele) 109 | 110 | println("Accuracy on Test set = %d/%d = %.4f".format(trueTestPredictions, 111 | objectifiedTest.count(), 112 | (trueTestPredictions.toDouble / objectifiedTest.count().toDouble) * 100)) 113 | } 114 | 115 | /** 116 | * MLLib's SVMWithSGD implementation 117 | */ 118 | def mllibAdult() { 119 | 120 | val conf = new SparkConf().setAppName("Adult-example").setMaster("local") 121 | val sc = new SparkContext(conf) 122 | sc.setCheckpointDir("checkpoint-files") 123 | 124 | val data = MLUtils.loadLibSVMFile(sc, "../data/a1a_mllib.txt") 125 | 126 | // Split data into training and test set 127 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L) 128 | val training = splits(0) 129 | val test = splits(1) 130 | 131 | // Run training algorithm to build the model 132 | val numIterations = 100 133 | val model = SVMWithSGD.train(training, numIterations) 134 | 135 | val trainError = training.map { point => 136 | val score = model.predict(point.features) 137 | score == point.label 138 | }.collect().toList.count(_ == true).toDouble / training.count().toDouble 139 | 140 | val testError = test.map { point => 141 | val score = model.predict(point.features) 142 | score == point.label 143 | }.collect().toList.count(_ == true).toDouble / test.count().toDouble 144 | 145 | println("Training accuracy = " + trainError) 146 | println("Test accuracy = " + testError) 147 | 148 | } 149 | 150 | def main(args: Array[String]): Unit = { 151 | // mllibAdult() 152 | 153 | dbcfwAdult() 154 | } 155 | 156 | } 157 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/binaryclassification/BinaryClassificationDemo.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.examples.binaryclassification 2 | 3 | import ch.ethz.dalab.dissolve.regression.LabeledObject 4 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 5 | import ch.ethz.dalab.dissolve.classification.BinarySVMWithDBCFW 6 | import ch.ethz.dalab.dissolve.classification.StructSVMModel 7 | import ch.ethz.dalab.dissolve.optimization.SolverUtils 8 | import org.apache.spark.mllib.util.MLUtils 9 | import org.apache.spark.SparkConf 10 | import org.apache.spark.SparkContext 11 | import org.apache.spark.rdd.RDD 12 | import org.apache.spark.mllib.regression.LabeledPoint 13 | import org.apache.spark.mllib.classification.SVMWithSGD 14 | import org.apache.log4j.PropertyConfigurator 15 | import org.apache.spark.mllib.linalg.Vectors 16 | import breeze.linalg._ 17 | 18 | /** 19 | * Runs the 
binary 20 | * 21 | * Created by tribhu on 12/11/14. 22 | */ 23 | object BinaryClassificationDemo { 24 | 25 | /** 26 | * Training Implementation 27 | */ 28 | def dissolveTrain(trainData: RDD[LabeledPoint], testData: RDD[LabeledPoint]) { 29 | 30 | val solverOptions: SolverOptions[Vector[Double], Double] = new SolverOptions() 31 | 32 | solverOptions.roundLimit = 20 // After these many passes, each slice of the RDD returns a trained model 33 | solverOptions.debug = true 34 | solverOptions.lambda = 0.01 35 | solverOptions.doWeightedAveraging = false 36 | solverOptions.doLineSearch = true 37 | solverOptions.debug = false 38 | 39 | solverOptions.sampleWithReplacement = false 40 | 41 | solverOptions.enableManualPartitionSize = true 42 | solverOptions.NUM_PART = 2 43 | 44 | solverOptions.sample = "frac" 45 | solverOptions.sampleFrac = 0.5 46 | 47 | solverOptions.enableOracleCache = false 48 | 49 | solverOptions.debugInfoPath = "../debug/debugInfo-%d.csv".format(System.currentTimeMillis()) 50 | 51 | val trainDataConverted: RDD[LabeledObject[Vector[Double], Double]] = 52 | trainData.map { 53 | case x: LabeledPoint => 54 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray)) 55 | } 56 | 57 | val testDataConverted: RDD[LabeledObject[Vector[Double], Double]] = 58 | testData.map { 59 | case x: LabeledPoint => 60 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray)) 61 | } 62 | 63 | solverOptions.testDataRDD = Some(testDataConverted) 64 | val model = BinarySVMWithDBCFW.train(trainData, solverOptions) 65 | 66 | // Training Errors 67 | val trueTrainingPredictions = 68 | trainDataConverted.map { 69 | case x: LabeledObject[Vector[Double], Double] => 70 | val prediction = model.predict(x.pattern) 71 | if (prediction == x.label) 72 | 1 73 | else 74 | 0 75 | }.fold(0)((acc, ele) => acc + ele) 76 | 77 | println("Accuracy on training set = %d/%d = %.4f".format(trueTrainingPredictions, 78 | trainDataConverted.count(), 79 | (trueTrainingPredictions.toDouble / trainDataConverted.count().toDouble) * 100)) 80 | 81 | // Test Errors 82 | val trueTestPredictions = 83 | testDataConverted.map { 84 | case x: LabeledObject[Vector[Double], Double] => 85 | val prediction = model.predict(x.pattern) 86 | if (prediction == x.label) 87 | 1 88 | else 89 | 0 90 | }.fold(0)((acc, ele) => acc + ele) 91 | 92 | println("Accuracy on test set = %d/%d = %.4f".format(trueTestPredictions, 93 | testDataConverted.count(), 94 | (trueTestPredictions.toDouble / testDataConverted.count().toDouble) * 100)) 95 | } 96 | 97 | /** 98 | * MLLib's SVMWithSGD implementation 99 | */ 100 | def mllibTrain(trainData: RDD[LabeledPoint], testData: RDD[LabeledPoint]) { 101 | println("running MLlib's standard gradient descent solver") 102 | 103 | // labels are assumed to be 0,1 for MLlib 104 | val trainDataConverted: RDD[LabeledPoint] = 105 | trainData.map { 106 | case x: LabeledPoint => 107 | new LabeledPoint(if (x.label > 0) 1 else 0, x.features) 108 | } 109 | 110 | val testDataConverted: RDD[LabeledPoint] = 111 | testData.map { 112 | case x: LabeledPoint => 113 | new LabeledPoint(if (x.label > 0) 1 else 0, x.features) 114 | } 115 | 116 | // Run training algorithm to build the model 117 | val model = SVMWithSGD.train(trainDataConverted, numIterations = 200, stepSize = 1.0, regParam = 0.01) 118 | 119 | // report accuracy on train and test 120 | val trainAcc = testDataConverted.map { point => 121 | val score = model.predict(point.features) 122 | score == point.label 123 | }.collect().toList.count(_ == 
true).toDouble / testDataConverted.count().toDouble 124 | 125 | val testAcc = testDataConverted.map { point => 126 | val score = model.predict(point.features) 127 | score == point.label 128 | }.collect().toList.count(_ == true).toDouble / testDataConverted.count().toDouble 129 | 130 | println("MLlib Training accuracy = " + trainAcc) 131 | println("MLlib Test accuracy = " + testAcc) 132 | 133 | } 134 | 135 | def main(args: Array[String]): Unit = { 136 | PropertyConfigurator.configure("conf/log4j.properties") 137 | 138 | val dataDir = "../data/generated" 139 | 140 | // Dataset: Adult (http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#a1a) 141 | val datasetFilename = "a1a" 142 | 143 | val conf = new SparkConf().setAppName("BinaryClassificationDemo").setMaster("local") 144 | val sc = new SparkContext(conf) 145 | sc.setCheckpointDir("checkpoint-files") 146 | 147 | // load dataset 148 | val data = MLUtils.loadLibSVMFile(sc, dataDir+"/"+datasetFilename) 149 | 150 | // Split data into training and test set 151 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L) 152 | val trainingData = splits(0) 153 | val testData = splits(1) 154 | 155 | // Fix seed for reproducibility 156 | util.Random.setSeed(1) 157 | 158 | 159 | mllibTrain(trainingData, testData) 160 | 161 | dissolveTrain(trainingData, testData) 162 | } 163 | 164 | } 165 | -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/binaryclassification/COVBinary.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.examples.binaryclassification 2 | 3 | import org.apache.log4j.PropertyConfigurator 4 | import org.apache.spark.SparkConf 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.mllib.classification.SVMWithSGD 7 | import org.apache.spark.mllib.regression.LabeledPoint 8 | import org.apache.spark.mllib.util.MLUtils 9 | import org.apache.spark.rdd.RDD 10 | import breeze.linalg.Vector 11 | import ch.ethz.dalab.dissolve.classification.BinarySVMWithDBCFW 12 | import ch.ethz.dalab.dissolve.examples.utils.ExampleUtils 13 | import ch.ethz.dalab.dissolve.optimization.GapThresholdCriterion 14 | import ch.ethz.dalab.dissolve.optimization.RoundLimitCriterion 15 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 16 | import ch.ethz.dalab.dissolve.optimization.TimeLimitCriterion 17 | import ch.ethz.dalab.dissolve.regression.LabeledObject 18 | import ch.ethz.dalab.dissolve.utils.cli.CLAParser 19 | import java.io.File 20 | 21 | object COVBinary { 22 | 23 | /** 24 | * MLLib's classifier 25 | */ 26 | def mllibCov() { 27 | val conf = new SparkConf().setAppName("Adult-example").setMaster("local") 28 | val sc = new SparkContext(conf) 29 | sc.setCheckpointDir("checkpoint-files") 30 | 31 | val data = MLUtils.loadLibSVMFile(sc, "../data/generated/covtype.libsvm.binary.scale.head.mllib") 32 | 33 | // Split data into training and test set 34 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L) 35 | val training = splits(0) 36 | val test = splits(1) 37 | 38 | // Run training algorithm to build the model 39 | val numIterations = 1000 40 | val model = SVMWithSGD.train(training, numIterations) 41 | 42 | val trainError = training.map { point => 43 | val score = model.predict(point.features) 44 | score == point.label 45 | }.collect().toList.count(_ == true).toDouble / training.count().toDouble 46 | 47 | val testError = test.map { point => 48 | val score = 
model.predict(point.features) 49 | score == point.label 50 | }.collect().toList.count(_ == true).toDouble / test.count().toDouble 51 | 52 | println("Training accuracy = " + trainError) 53 | println("Test accuracy = " + testError) 54 | } 55 | 56 | /** 57 | * DBCFW classifier 58 | */ 59 | def dbcfwCov(args: Array[String]) { 60 | /** 61 | * Load all options 62 | */ 63 | val (solverOptions, kwargs) = CLAParser.argsToOptions[Vector[Double], Double](args) 64 | val covPath = kwargs.getOrElse("input_path", "../data/generated/covtype.libsvm.binary.scale") 65 | val appname = kwargs.getOrElse("appname", "cov_binary") 66 | val debugPath = kwargs.getOrElse("debug_file", "cov_binary-%d.csv".format(System.currentTimeMillis() / 1000)) 67 | solverOptions.debugInfoPath = debugPath 68 | 69 | println(covPath) 70 | println(kwargs) 71 | 72 | println("Current directory:" + new File(".").getAbsolutePath) 73 | 74 | // Fix seed for reproducibility 75 | util.Random.setSeed(1) 76 | 77 | val conf = new SparkConf().setAppName(appname) 78 | val sc = new SparkContext(conf) 79 | sc.setCheckpointDir("checkpoint-files") 80 | 81 | // Labels needs to be in a +1/-1 format 82 | val data = MLUtils 83 | .loadLibSVMFile(sc, covPath) 84 | .map { 85 | case x: LabeledPoint => 86 | val label = 87 | if (x.label == 1) 88 | +1.00 89 | else 90 | -1.00 91 | LabeledPoint(label, x.features) 92 | } 93 | 94 | // Split data into training and test set 95 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L) 96 | val training = splits(0) 97 | val test = splits(1) 98 | 99 | val objectifiedTest: RDD[LabeledObject[Vector[Double], Double]] = 100 | test.map { 101 | case x: LabeledPoint => 102 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray)) // Is the asInstanceOf required? 
103 | } 104 | 105 | solverOptions.testDataRDD = Some(objectifiedTest) 106 | val model = BinarySVMWithDBCFW.train(training, solverOptions) 107 | 108 | // Test Errors 109 | val trueTestPredictions = 110 | objectifiedTest.map { 111 | case x: LabeledObject[Vector[Double], Double] => 112 | val prediction = model.predict(x.pattern) 113 | if (prediction == x.label) 114 | 1 115 | else 116 | 0 117 | }.fold(0)((acc, ele) => acc + ele) 118 | 119 | println("Accuracy on Test set = %d/%d = %.4f".format(trueTestPredictions, 120 | objectifiedTest.count(), 121 | (trueTestPredictions.toDouble / objectifiedTest.count().toDouble) * 100)) 122 | } 123 | 124 | def main(args: Array[String]): Unit = { 125 | 126 | PropertyConfigurator.configure("conf/log4j.properties") 127 | 128 | System.setProperty("spark.akka.frameSize", "512") 129 | 130 | dbcfwCov(args) 131 | } 132 | 133 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/binaryclassification/RCV1Binary.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.examples.binaryclassification 2 | 3 | import org.apache.log4j.PropertyConfigurator 4 | import org.apache.spark.SparkConf 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.mllib.regression.LabeledPoint 7 | import org.apache.spark.mllib.util.MLUtils 8 | import org.apache.spark.rdd.RDD 9 | import breeze.linalg.SparseVector 10 | import breeze.linalg.Vector 11 | import ch.ethz.dalab.dissolve.classification.BinarySVMWithDBCFW 12 | import ch.ethz.dalab.dissolve.examples.utils.ExampleUtils 13 | import ch.ethz.dalab.dissolve.optimization.GapThresholdCriterion 14 | import ch.ethz.dalab.dissolve.optimization.RoundLimitCriterion 15 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 16 | import ch.ethz.dalab.dissolve.optimization.TimeLimitCriterion 17 | import ch.ethz.dalab.dissolve.regression.LabeledObject 18 | import ch.ethz.dalab.dissolve.utils.cli.CLAParser 19 | 20 | object RCV1Binary { 21 | 22 | def dbcfwRcv1(args: Array[String]) { 23 | /** 24 | * Load all options 25 | */ 26 | val (solverOptions, kwargs) = CLAParser.argsToOptions[Vector[Double], Double](args) 27 | val rcv1Path = kwargs.getOrElse("input_path", "../data/generated/rcv1_train.binary") 28 | val appname = kwargs.getOrElse("appname", "rcv1_binary") 29 | val debugPath = kwargs.getOrElse("debug_file", "rcv1_binary-%d.csv".format(System.currentTimeMillis() / 1000)) 30 | solverOptions.debugInfoPath = debugPath 31 | 32 | println(rcv1Path) 33 | println(kwargs) 34 | 35 | // Fix seed for reproducibility 36 | util.Random.setSeed(1) 37 | 38 | println(solverOptions.toString()) 39 | 40 | val conf = new SparkConf().setAppName(appname) 41 | val sc = new SparkContext(conf) 42 | sc.setCheckpointDir("checkpoint-files") 43 | 44 | // Labels needs to be in a +1/-1 format 45 | val data = MLUtils.loadLibSVMFile(sc, rcv1Path) 46 | 47 | // Split data into training and test set 48 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L) 49 | val training = splits(0) 50 | val test = splits(1) 51 | 52 | val objectifiedTest: RDD[LabeledObject[Vector[Double], Double]] = 53 | test.map { 54 | case x: LabeledPoint => 55 | new LabeledObject[Vector[Double], Double](x.label, SparseVector(x.features.toArray)) // Is the asInstanceOf required? 
56 | } 57 | 58 | solverOptions.testDataRDD = Some(objectifiedTest) 59 | val model = BinarySVMWithDBCFW.train(training, solverOptions) 60 | 61 | // Test Errors 62 | val trueTestPredictions = 63 | objectifiedTest.map { 64 | case x: LabeledObject[Vector[Double], Double] => 65 | val prediction = model.predict(x.pattern) 66 | if (prediction == x.label) 67 | 1 68 | else 69 | 0 70 | }.fold(0)((acc, ele) => acc + ele) 71 | 72 | println("Accuracy on Test set = %d/%d = %.4f".format(trueTestPredictions, 73 | objectifiedTest.count(), 74 | (trueTestPredictions.toDouble / objectifiedTest.count().toDouble) * 100)) 75 | } 76 | 77 | def main(args: Array[String]): Unit = { 78 | 79 | PropertyConfigurator.configure("conf/log4j.properties") 80 | 81 | System.setProperty("spark.akka.frameSize", "512") 82 | 83 | dbcfwRcv1(args) 84 | 85 | } 86 | 87 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/imageseg/ImageSeg.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.examples.imageseg 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | 5 | import breeze.linalg.DenseMatrix 6 | import breeze.linalg.DenseVector 7 | import breeze.linalg.Vector 8 | import breeze.linalg.normalize 9 | import cc.factorie.infer.MaximizeByBPLoopy 10 | import cc.factorie.la.DenseTensor1 11 | import cc.factorie.la.Tensor 12 | import cc.factorie.model._ 13 | import cc.factorie.singleFactorIterable 14 | import cc.factorie.variable.DiscreteDomain 15 | import cc.factorie.variable.DiscreteVariable 16 | import ch.ethz.dalab.dissolve.classification.StructSVMModel 17 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSegTypes.AdjacencyList 18 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSegTypes.Label 19 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSegTypes.RGB_INT 20 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 21 | 22 | /** 23 | * `data` is a `d` x `N` matrix; each column contains the features of super-pixel 24 | * `transitions` is an adjacency list. 
transitions(i) contains neighbours of super-pixel `i+1`'th super-pixel 25 | * `pixelMapping` contains mapping of super-pixel to corresponding pixel in the original image (column-major) 26 | */ 27 | case class QuantizedImage(unaries: DenseMatrix[Double], 28 | pairwise: Array[Array[Int]], 29 | pixelMapping: Array[Int], 30 | width: Int, 31 | height: Int, 32 | filename: String = "NA", 33 | unaryFeatures: DenseMatrix[Double] = null, 34 | rgbArray: Array[RGB_INT] = null, 35 | globalFeatures: Vector[Double] = null) 36 | 37 | /** 38 | * labels(i) contains label for `i`th super-pixel 39 | */ 40 | case class QuantizedLabel(labels: Array[Int], 41 | filename: String = "NA") 42 | 43 | /** 44 | * Functions for dissolve^struct 45 | * 46 | * Designed for the MSRC-21 dataset 47 | */ 48 | object ImageSeg 49 | extends DissolveFunctions[QuantizedImage, QuantizedLabel] { 50 | 51 | val NUM_CLASSES: Int = 22 // # Classes (0-indexed) 52 | val BACKGROUND_CLASS: Int = 21 // The last label 53 | 54 | val INTENSITY_LEVELS: Int = 8 55 | val NUM_BINS = INTENSITY_LEVELS * INTENSITY_LEVELS * INTENSITY_LEVELS // Size of feature vector x_i 56 | 57 | var DISABLE_PAIRWISE = false 58 | 59 | /** 60 | * ======= Joint Feature Map ======= 61 | */ 62 | def featureFn(x: QuantizedImage, y: QuantizedLabel): Vector[Double] = { 63 | 64 | val numSuperpixels = x.unaries.cols // # Super-pixels 65 | val classifierScore = x.unaries.rows // Score of SP per label 66 | 67 | assert(numSuperpixels == x.pairwise.length, 68 | "numSuperpixels == x.pairwise.length") 69 | 70 | /** 71 | * Unary Features 72 | */ 73 | val d = x.unaryFeatures.rows 74 | assert(d == NUM_BINS, "d == NUM_BINS") 75 | 76 | val dCombined = 77 | if (x.globalFeatures != null) 78 | d + x.globalFeatures.size 79 | else 80 | d 81 | 82 | val unaryFeatures = DenseMatrix.zeros[Double](dCombined, NUM_CLASSES) 83 | for (superIdx <- 0 until numSuperpixels) { 84 | val x_i = x.unaryFeatures(::, superIdx) 85 | val x_global = x.globalFeatures 86 | 87 | val x_comb = 88 | if (x_global == null) 89 | x_i 90 | else 91 | Vector(Array.concat(x_i.toArray, x_global.toArray)) 92 | val label = y.labels(superIdx) 93 | unaryFeatures(::, label) += x_comb 94 | } 95 | 96 | if (DISABLE_PAIRWISE) 97 | unaryFeatures.toDenseVector 98 | else { 99 | /** 100 | * Pairwise features 101 | */ 102 | val transitions = DenseMatrix.zeros[Double](NUM_CLASSES, NUM_CLASSES) 103 | for (superIdx <- 0 until numSuperpixels) { 104 | val thisLabel = y.labels(superIdx) 105 | 106 | x.pairwise(superIdx).foreach { 107 | case adjacentSuperIdx => 108 | val nextLabel = y.labels(adjacentSuperIdx) 109 | 110 | transitions(thisLabel, nextLabel) += 1.0 111 | transitions(nextLabel, thisLabel) += 1.0 112 | } 113 | } 114 | DenseVector.vertcat(unaryFeatures.toDenseVector, 115 | normalize(transitions.toDenseVector, 2)) 116 | } 117 | } 118 | 119 | /** 120 | * Per-label Hamming loss 121 | */ 122 | def perLabelLoss(labTruth: Label, labPredict: Label): Double = 123 | if (labTruth == labPredict) 124 | 0.0 125 | else 126 | 1.0 127 | 128 | /** 129 | * ======= Structured Error Function ======= 130 | */ 131 | def lossFn(yTruth: QuantizedLabel, yPredict: QuantizedLabel): Double = { 132 | 133 | assert(yTruth.labels.size == yPredict.labels.size, 134 | "Failed: yTruth.labels.size == yPredict.labels.size") 135 | 136 | val stuctHammingLoss = yTruth.labels 137 | .zip(yPredict.labels) 138 | .map { 139 | case (labTruth, labPredict) => 140 | perLabelLoss(labTruth, labPredict) 141 | } 142 | 143 | // Return normalized hamming loss 144 | stuctHammingLoss.sum / 
stuctHammingLoss.length 145 | } 146 | 147 | /** 148 | * Construct Factor graph and run MAP 149 | * (Max-product using Loopy Belief Propogation) 150 | */ 151 | def decode(unaryPot: DenseMatrix[Double], 152 | pairwisePot: DenseMatrix[Double], 153 | adj: AdjacencyList): Array[Label] = { 154 | 155 | val nSuperpixels = unaryPot.cols 156 | val nClasses = unaryPot.rows 157 | 158 | assert(nClasses == NUM_CLASSES) 159 | if (!DISABLE_PAIRWISE) 160 | assert(pairwisePot.rows == NUM_CLASSES) 161 | 162 | object PixelDomain extends DiscreteDomain(nClasses) 163 | 164 | class Pixel(i: Int) extends DiscreteVariable(i) { 165 | def domain = PixelDomain 166 | } 167 | 168 | def getUnaryFactor(yi: Pixel, idx: Int): Factor = { 169 | new Factor1(yi) { 170 | val weights: DenseTensor1 = new DenseTensor1(unaryPot(::, idx).toArray) 171 | def score(k: Pixel#Value) = unaryPot(k.intValue, idx) 172 | override def valuesScore(tensor: Tensor): Double = { 173 | weights dot tensor 174 | } 175 | } 176 | } 177 | 178 | def getPairwiseFactor(yi: Pixel, yj: Pixel): Factor = { 179 | new Factor2(yi, yj) { 180 | val weights: DenseTensor1 = new DenseTensor1(pairwisePot.toArray) 181 | def score(i: Pixel#Value, j: Pixel#Value) = pairwisePot(i.intValue, j.intValue) 182 | override def valuesScore(tensor: Tensor): Double = { 183 | weights dot tensor 184 | } 185 | } 186 | } 187 | 188 | val pixelSeq: IndexedSeq[Pixel] = 189 | (0 until nSuperpixels).map(x => new Pixel(12)) 190 | 191 | val unaryFactors: IndexedSeq[Factor] = 192 | (0 until nSuperpixels).map { 193 | case idx => 194 | getUnaryFactor(pixelSeq(idx), idx) 195 | } 196 | 197 | val model = new ItemizedModel 198 | model ++= unaryFactors 199 | 200 | if (!DISABLE_PAIRWISE) { 201 | val pairwiseFactors = 202 | (0 until nSuperpixels).flatMap { 203 | case thisIdx => 204 | val thisFactors = new ArrayBuffer[Factor] 205 | 206 | adj(thisIdx).foreach { 207 | case nextIdx => 208 | thisFactors ++= 209 | getPairwiseFactor(pixelSeq(thisIdx), pixelSeq(nextIdx)) 210 | } 211 | thisFactors 212 | } 213 | model ++= pairwiseFactors 214 | } 215 | 216 | MaximizeByBPLoopy.maximize(pixelSeq, model) 217 | 218 | val mapLabels: Array[Label] = (0 until nSuperpixels).map { 219 | idx => 220 | pixelSeq(idx).intValue 221 | }.toArray 222 | 223 | mapLabels 224 | } 225 | 226 | /** 227 | * Unpack weight vector to Unary and Pairwise weights 228 | */ 229 | 230 | def unpackWeightVec(weightv: DenseVector[Double], d: Int): (DenseMatrix[Double], DenseMatrix[Double]) = { 231 | 232 | assert(weightv.size >= (NUM_CLASSES * d)) 233 | 234 | val unaryWeights = weightv(0 until NUM_CLASSES * d) 235 | val unaryWeightMat = unaryWeights.toDenseMatrix.reshape(d, NUM_CLASSES) 236 | 237 | val pairwisePot = 238 | if (!DISABLE_PAIRWISE) { 239 | assert(weightv.size == (NUM_CLASSES * d) + (NUM_CLASSES * NUM_CLASSES)) 240 | val pairwiseWeights = weightv((NUM_CLASSES * d) until weightv.size) 241 | pairwiseWeights.toDenseMatrix.reshape(NUM_CLASSES, NUM_CLASSES) 242 | } else null 243 | 244 | (unaryWeightMat, pairwisePot) 245 | } 246 | 247 | /** 248 | * ======= Maximization Oracle ======= 249 | */ 250 | override def oracleFn(model: StructSVMModel[QuantizedImage, QuantizedLabel], 251 | xi: QuantizedImage, 252 | yi: QuantizedLabel): QuantizedLabel = { 253 | 254 | val nSuperpixels = xi.unaryFeatures.cols 255 | val d = xi.unaryFeatures.rows 256 | val dComb = 257 | if (xi.globalFeatures == null) 258 | d 259 | else 260 | d + xi.globalFeatures.size 261 | 262 | assert(xi.pairwise.length == nSuperpixels, 263 | "xi.pairwise.length == nSuperpixels") 264 | 
assert(xi.unaryFeatures.cols == nSuperpixels, 265 | "xi.unaryFeatures.cols == nSuperpixels") 266 | 267 | val (unaryWeights, pairwisePot) = unpackWeightVec(model.weights.toDenseVector, dComb) 268 | val localFeatures = 269 | if (xi.globalFeatures == null) 270 | xi.unaryFeatures 271 | else { 272 | // Concatenate global features to local features 273 | // The order is : local || global 274 | val dGlob = xi.globalFeatures.size 275 | val glob = xi.globalFeatures.toDenseVector 276 | val globalFeatures = DenseMatrix.zeros[Double](dGlob, nSuperpixels) 277 | for (superIdx <- 0 until nSuperpixels) { 278 | globalFeatures(::, superIdx) := glob 279 | } 280 | DenseMatrix.vertcat(xi.unaryFeatures, globalFeatures) 281 | } 282 | val unaryPot = unaryWeights.t * localFeatures 283 | 284 | if (yi != null) { 285 | assert(yi.labels.length == xi.pairwise.length, 286 | "yi.labels.length == xi.pairwise.length") 287 | 288 | // Loss augment the scores 289 | for (superIdx <- 0 until nSuperpixels) { 290 | val trueLabel = yi.labels(superIdx) 291 | // FIXME Use \delta here 292 | unaryPot(::, superIdx) += (1.0 / nSuperpixels) 293 | unaryPot(trueLabel, superIdx) -= (1.0 / nSuperpixels) 294 | } 295 | } 296 | 297 | val t0 = System.currentTimeMillis() 298 | val decodedLabels = decode(unaryPot, pairwisePot, xi.pairwise) 299 | val oracleSolution = QuantizedLabel(decodedLabels, xi.filename) 300 | val t1 = System.currentTimeMillis() 301 | 302 | oracleSolution 303 | } 304 | 305 | def predictFn(model: StructSVMModel[QuantizedImage, QuantizedLabel], 306 | xi: QuantizedImage): QuantizedLabel = { 307 | oracleFn(model, xi, null) 308 | } 309 | 310 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/imageseg/ImageSegRunner.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.examples.imageseg 2 | 3 | import java.nio.file.Paths 4 | import org.apache.log4j.PropertyConfigurator 5 | import org.apache.spark.SparkConf 6 | import org.apache.spark.SparkContext 7 | import ch.ethz.dalab.dissolve.classification.StructSVMModel 8 | import ch.ethz.dalab.dissolve.classification.StructSVMWithDBCFW 9 | import ch.ethz.dalab.dissolve.utils.cli.CLAParser 10 | import javax.imageio.ImageIO 11 | 12 | /** 13 | * @author torekond 14 | */ 15 | object ImageSegRunner { 16 | 17 | def main(args: Array[String]): Unit = { 18 | 19 | PropertyConfigurator.configure("conf/log4j.properties") 20 | System.setProperty("spark.akka.frameSize", "512") 21 | 22 | val startTime = System.currentTimeMillis() / 1000 23 | 24 | /** 25 | * Load all options 26 | */ 27 | val (solverOptions, kwargs) = CLAParser.argsToOptions[QuantizedImage, QuantizedLabel](args) 28 | val dataDir = kwargs.getOrElse("input_path", "../data/generated/msrc") 29 | val appname = kwargs.getOrElse("appname", "imageseg-%d".format(startTime)) 30 | val debugPath = kwargs.getOrElse("debug_file", "imageseg-%d.csv".format(startTime)) 31 | 32 | val unariesOnly = kwargs.getOrElse("unaries", "true").toBoolean 33 | 34 | val trainFile = kwargs.getOrElse("train", "Train.txt") 35 | val validationFile = kwargs.getOrElse("validation", "Validation.txt") 36 | solverOptions.debugInfoPath = debugPath 37 | 38 | println(dataDir) 39 | println(kwargs) 40 | 41 | solverOptions.doLineSearch = true 42 | 43 | if (unariesOnly) 44 | ImageSeg.DISABLE_PAIRWISE = true 45 | 46 | /** 47 | * Setup Spark 48 | */ 49 | val conf = new SparkConf().setAppName(appname).setMaster("local") 
50 | val sc = new SparkContext(conf) 51 | sc.setCheckpointDir("checkpoint-files") 52 | 53 | val trainFilePath = Paths.get(dataDir, trainFile) 54 | val valFilePath = Paths.get(dataDir, validationFile) 55 | 56 | val trainDataSeq = ImageSegUtils.loadData(dataDir, trainFilePath) 57 | val valDataSeq = ImageSegUtils.loadData(dataDir, valFilePath) 58 | 59 | val trainData = sc.parallelize(trainDataSeq, 1).cache 60 | val valData = sc.parallelize(valDataSeq, 1).cache 61 | 62 | solverOptions.testDataRDD = Some(valData) 63 | 64 | println(solverOptions) 65 | 66 | val trainer: StructSVMWithDBCFW[QuantizedImage, QuantizedLabel] = 67 | new StructSVMWithDBCFW[QuantizedImage, QuantizedLabel]( 68 | trainData, 69 | ImageSeg, 70 | solverOptions) 71 | 72 | val model: StructSVMModel[QuantizedImage, QuantizedLabel] = trainer.trainModel() 73 | 74 | // Create directories for image out, if it doesn't exist 75 | val imageOutDir = Paths.get(dataDir, "debug", appname) 76 | if (!imageOutDir.toFile().exists()) 77 | imageOutDir.toFile().mkdirs() 78 | 79 | println("Test time!") 80 | for (lo <- trainDataSeq) { 81 | val t0 = System.currentTimeMillis() 82 | val prediction = model.predict(lo.pattern) 83 | val t1 = System.currentTimeMillis() 84 | 85 | val filename = lo.pattern.filename 86 | val format = "bmp" 87 | val outPath = Paths.get(imageOutDir.toString(), "train-%s.%s".format(filename, format)) 88 | 89 | // Image 90 | val imgPath = Paths.get(dataDir.toString(), "All", "%s.bmp".format(filename)) 91 | val img = ImageIO.read(imgPath.toFile()) 92 | 93 | val width = img.getWidth() 94 | val height = img.getHeight() 95 | 96 | // Write loss info 97 | val predictTime: Long = t1 - t0 98 | val loss: Double = ImageSeg.lossFn(prediction, lo.label) 99 | val text = "filename = %s\nprediction time = %d ms\nerror = %f\n#spx = %d".format(filename, predictTime, loss, lo.pattern.unaries.cols) 100 | val textInfoImg = ImageSegUtils.getImageWithText(width, height, text) 101 | 102 | // GT 103 | val gtImage = ImageSegUtils.getQuantizedLabelImage(lo.label, 104 | lo.pattern.pixelMapping, 105 | lo.pattern.width, 106 | lo.pattern.height) 107 | 108 | // Prediction 109 | val predImage = ImageSegUtils.getQuantizedLabelImage(prediction, 110 | lo.pattern.pixelMapping, 111 | lo.pattern.width, 112 | lo.pattern.height) 113 | 114 | val prettyOut = ImageSegUtils.printImageTile(img, gtImage, textInfoImg, predImage) 115 | ImageSegUtils.writeImage(prettyOut, outPath.toString()) 116 | 117 | } 118 | 119 | for (lo <- valDataSeq) { 120 | val t0 = System.currentTimeMillis() 121 | val prediction = model.predict(lo.pattern) 122 | val t1 = System.currentTimeMillis() 123 | 124 | val filename = lo.pattern.filename 125 | val format = "bmp" 126 | val outPath = Paths.get(imageOutDir.toString(), "val-%s.%s".format(filename, format)) 127 | 128 | // Image 129 | val imgPath = Paths.get(dataDir.toString(), "All", "%s.bmp".format(filename)) 130 | val img = ImageIO.read(imgPath.toFile()) 131 | 132 | val width = img.getWidth() 133 | val height = img.getHeight() 134 | 135 | // Write loss info 136 | val predictTime: Long = t1 - t0 137 | val loss: Double = ImageSeg.lossFn(prediction, lo.label) 138 | val text = "filename = %s\nprediction time = %d ms\nerror = %f\n#spx = %d".format(filename, predictTime, loss, lo.pattern.unaries.cols) 139 | val textInfoImg = ImageSegUtils.getImageWithText(width, height, text) 140 | 141 | // GT 142 | val gtImage = ImageSegUtils.getQuantizedLabelImage(lo.label, 143 | lo.pattern.pixelMapping, 144 | lo.pattern.width, 145 | lo.pattern.height) 146 | 147 | // 
Prediction 148 | val predImage = ImageSegUtils.getQuantizedLabelImage(prediction, 149 | lo.pattern.pixelMapping, 150 | lo.pattern.width, 151 | lo.pattern.height) 152 | 153 | val prettyOut = ImageSegUtils.printImageTile(img, gtImage, textInfoImg, predImage) 154 | ImageSegUtils.writeImage(prettyOut, outPath.toString()) 155 | 156 | } 157 | 158 | val unaryDebugPath = 159 | Paths.get("/home/torekond/dev-local/dissolve-struct/data/generated/msrc/debug", 160 | "%s-unary.csv".format(appname)) 161 | val transDebugPath = 162 | Paths.get("/home/torekond/dev-local/dissolve-struct/data/generated/msrc/debug", 163 | "%s-trans.csv".format(appname)) 164 | val weights = model.getWeights().toDenseVector 165 | 166 | val xi = trainDataSeq(0).pattern 167 | val d = 168 | if (xi.globalFeatures == null) 169 | xi.unaryFeatures.rows 170 | else 171 | xi.unaryFeatures.rows + xi.globalFeatures.size 172 | val (unaryMat, transMat) = ImageSeg.unpackWeightVec(weights, d) 173 | 174 | breeze.linalg.csvwrite(unaryDebugPath.toFile(), unaryMat) 175 | if (transMat != null) 176 | breeze.linalg.csvwrite(transDebugPath.toFile(), transMat) 177 | 178 | } 179 | 180 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/imageseg/ImageSegTypes.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.examples.imageseg 2 | 3 | /** 4 | * @author torekond 5 | */ 6 | object ImageSegTypes { 7 | type Label = Int 8 | type Index = Int 9 | type SuperIndex = Int 10 | type RGB_INT = Int 11 | type AdjacencyList = Array[Array[SuperIndex]] 12 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/multiclass/COVMulticlass.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.examples.multiclass 2 | 3 | import org.apache.log4j.PropertyConfigurator 4 | import org.apache.spark.SparkConf 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.mllib.regression.LabeledPoint 7 | import org.apache.spark.mllib.util.MLUtils 8 | import org.apache.spark.rdd.RDD 9 | 10 | import breeze.linalg.Vector 11 | import ch.ethz.dalab.dissolve.classification.MultiClassLabel 12 | import ch.ethz.dalab.dissolve.classification.MultiClassSVMWithDBCFW 13 | import ch.ethz.dalab.dissolve.regression.LabeledObject 14 | import ch.ethz.dalab.dissolve.utils.cli.CLAParser 15 | 16 | object COVMulticlass { 17 | 18 | def dissoveCovMulti(args: Array[String]) { 19 | 20 | /** 21 | * Load all options 22 | */ 23 | val (solverOptions, kwargs) = CLAParser.argsToOptions[Vector[Double], MultiClassLabel](args) 24 | val covPath = kwargs.getOrElse("input_path", "../data/generated/covtype.scale") 25 | val appname = kwargs.getOrElse("appname", "cov_multi") 26 | val debugPath = kwargs.getOrElse("debug_file", "cov_multi-%d.csv".format(System.currentTimeMillis() / 1000)) 27 | solverOptions.debugInfoPath = debugPath 28 | 29 | println(covPath) 30 | println(kwargs) 31 | 32 | // Fix seed for reproducibility 33 | util.Random.setSeed(1) 34 | 35 | val conf = new SparkConf().setAppName(appname) 36 | 37 | val sc = new SparkContext(conf) 38 | sc.setCheckpointDir("checkpoint-files") 39 | 40 | // Needs labels \in [0, numClasses) 41 | val data: RDD[LabeledPoint] = MLUtils 42 | .loadLibSVMFile(sc, covPath) 43 | .map { 44 | case x: LabeledPoint => 45 | val label = x.label - 1 
46 | LabeledPoint(label, x.features) 47 | } 48 | 49 | val minlabel = data.map(_.label).min() 50 | val maxlabel = data.map(_.label).max() 51 | println("min = %f, max = %f".format(minlabel, maxlabel)) 52 | 53 | // Split data into training and test set 54 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L) 55 | val training = splits(0) 56 | val test = splits(1) 57 | 58 | val numClasses = 7 59 | 60 | val objectifiedTest: RDD[LabeledObject[Vector[Double], MultiClassLabel]] = 61 | test.map { 62 | case x: LabeledPoint => 63 | new LabeledObject[Vector[Double], MultiClassLabel](MultiClassLabel(x.label, numClasses), 64 | Vector(x.features.toArray)) 65 | } 66 | 67 | solverOptions.testDataRDD = Some(objectifiedTest) 68 | val model = MultiClassSVMWithDBCFW.train(data, numClasses, solverOptions) 69 | 70 | // Test Errors 71 | val trueTestPredictions = 72 | objectifiedTest.map { 73 | case x: LabeledObject[Vector[Double], MultiClassLabel] => 74 | val prediction = model.predict(x.pattern) 75 | if (prediction == x.label) 76 | 1 77 | else 78 | 0 79 | }.fold(0)((acc, ele) => acc + ele) 80 | 81 | println("Accuracy on Test set = %d/%d = %.4f".format(trueTestPredictions, 82 | objectifiedTest.count(), 83 | (trueTestPredictions.toDouble / objectifiedTest.count().toDouble) * 100)) 84 | 85 | } 86 | 87 | def main(args: Array[String]): Unit = { 88 | 89 | PropertyConfigurator.configure("conf/log4j.properties") 90 | 91 | System.setProperty("spark.akka.frameSize", "512") 92 | 93 | dissoveCovMulti(args) 94 | } 95 | 96 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/utils/ExampleUtils.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.examples.utils 2 | 3 | import scala.io.Source 4 | import scala.util.Random 5 | 6 | object ExampleUtils { 7 | 8 | /** 9 | * Obtained adj and noun from: https://github.com/kohsuke/wordnet-random-name 10 | */ 11 | val adjFilename = "adj.txt" 12 | val nounFilename = "noun.txt" 13 | 14 | val adjList = Source.fromURL(getClass.getResource("/adj.txt")).getLines().toArray 15 | val nounList = Source.fromURL(getClass.getResource("/noun.txt")).getLines().toArray 16 | 17 | def getRandomElement[T](lst: Seq[T]): T = lst(Random.nextInt(lst.size)) 18 | 19 | def generateExperimentName(prefix: Seq[String] = List.empty, suffix: Seq[String] = List.empty, separator: String = "-"): String = { 20 | 21 | val nameList: Seq[String] = prefix ++ List(getRandomElement(adjList), getRandomElement(nounList)) ++ suffix 22 | 23 | val separatedNameList = nameList 24 | .flatMap { x => x :: separator :: Nil } // Juxtapose with separators 25 | .dropRight(1) // Drop the last separator 26 | 27 | separatedNameList.reduce(_ + _) 28 | } 29 | 30 | def main(args: Array[String]): Unit = { 31 | println(generateExperimentName()) 32 | } 33 | 34 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/test/scala/ch/ethz/dalab/dissolve/diagnostics/FeatureFnSpec.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.diagnostics 2 | 3 | /** 4 | * @author torekond 5 | */ 6 | class FeatureFnSpec extends UnitSpec { 7 | 8 | val NUM_ATTEMPTS = 100 // = # times each test case is attempted 9 | 10 | // A sample datapoint 11 | val lo = data(0) 12 | // Size of joint feature map 13 | val d = phi(lo.pattern, lo.label).size 14 | // No. 
of data examples 15 | val M = data.length 16 | 17 | "dim( ϕ(x_m, y_m) )" should "be fixed for all GIVEN (x_m, y_m)" in { 18 | 19 | val dimDiffSeq: Seq[Int] = for (k <- 0 until NUM_ATTEMPTS) yield { 20 | 21 | // Choose a random example 22 | val m = scala.util.Random.nextInt(M) 23 | val lo = data(m) 24 | val x_m = lo.pattern 25 | val y_m = lo.label 26 | 27 | phi(x_m, y_m).size - d 28 | 29 | } 30 | 31 | // This should be empty 32 | val uneqDimDiffSeq: Seq[Int] = dimDiffSeq.filter(_ != 0) 33 | 34 | assert(uneqDimDiffSeq.length == 0, 35 | "%d / %d cases failed".format(uneqDimDiffSeq.length, dimDiffSeq.length)) 36 | 37 | } 38 | 39 | it should "be fixed for all PERTURBED (x_m, y_m)" in { 40 | 41 | val dimDiffSeq: Seq[Int] = for (k <- 0 until NUM_ATTEMPTS) yield { 42 | 43 | // Choose a random example 44 | val m = scala.util.Random.nextInt(M) 45 | val lo = data(m) 46 | val x_m = lo.pattern 47 | val y_m = perturb(lo.label, 0.5) 48 | 49 | phi(x_m, y_m).size - d 50 | 51 | } 52 | 53 | // This should be empty 54 | val uneqDimDiffSeq: Seq[Int] = dimDiffSeq.filter(_ != 0) 55 | 56 | assert(uneqDimDiffSeq.length == 0, 57 | "%d / %d cases failed".format(uneqDimDiffSeq.length, dimDiffSeq.length)) 58 | 59 | } 60 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/test/scala/ch/ethz/dalab/dissolve/diagnostics/OracleSpec.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.diagnostics 2 | 3 | import org.scalatest.FlatSpec 4 | import breeze.linalg._ 5 | import scala.collection.mutable.ArrayBuffer 6 | import ch.ethz.dalab.dissolve.regression.LabeledObject 7 | 8 | /** 9 | * @author torekond 10 | */ 11 | class OracleSpec extends UnitSpec { 12 | 13 | val NUM_WEIGHT_VECS = 500 // = # times each test case is attempted 14 | val EPSILON = 2.2204E-16 15 | 16 | // A sample datapoint 17 | val lo = data(0) 18 | // Size of joint feature map 19 | val d = phi(lo.pattern, lo.label).size 20 | // No. 
of data examples 21 | val M = data.length 22 | 23 | /** 24 | * Initialize a bunch of weight vectors 25 | */ 26 | type WeightVector = Vector[Double] 27 | val weightVectors: Array[WeightVector] = { 28 | val seq = new ArrayBuffer[WeightVector] 29 | 30 | // All 1's 31 | seq += DenseVector.ones[Double](d) 32 | 33 | // All 0's 34 | seq += DenseVector.zeros[Double](d) 35 | 36 | // Few random vectors 37 | for (k <- 0 until NUM_WEIGHT_VECS - seq.size) 38 | seq += DenseVector.rand(d) 39 | 40 | seq.toArray 41 | } 42 | 43 | /** 44 | * Tests - Structured Hinge Loss 45 | */ 46 | 47 | "Structured Hinge Loss [Δ(y_m, y*) - < w, ψ(x_m, y*) >]" should "be >= 0 for GIVEN (x_m, y_m) pairs" in { 48 | 49 | val shlSeq: Seq[Double] = for (k <- 0 until NUM_WEIGHT_VECS) yield { 50 | 51 | // Set weight vector 52 | val w: WeightVector = weightVectors(k) 53 | model.setWeights(w) 54 | 55 | // Choose a random example 56 | val m = scala.util.Random.nextInt(M) 57 | val lo = data(m) 58 | val x_m = lo.pattern 59 | val y_m = lo.label 60 | 61 | // Get loss-augmented argmax prediction 62 | val ystar = maxoracle(model, x_m, y_m) 63 | val shl = delta(y_m, ystar) - deltaF(lo, ystar, w) 64 | 65 | shl 66 | 67 | } 68 | 69 | // This should be empty 70 | val negShlSeq: Seq[Double] = shlSeq.filter(_ < -EPSILON) 71 | 72 | assert(negShlSeq.length == 0, 73 | "%d / %d cases failed".format(negShlSeq.length, shlSeq.length)) 74 | 75 | } 76 | 77 | it should "be >= 0 for PERTURBED (x_m, y_m) pairs" in { 78 | 79 | val shlSeq: Seq[Double] = for (k <- 0 until NUM_WEIGHT_VECS) yield { 80 | 81 | // Set weight vector 82 | val w: WeightVector = weightVectors(k) 83 | model.setWeights(w) 84 | 85 | // Sample a random (x, y) pair 86 | val m = scala.util.Random.nextInt(M) 87 | val x_m = data(m).pattern 88 | val y_m = perturb(data(m).label, 0.1) // Perturbed copy 89 | val lo = LabeledObject(y_m, x_m) 90 | 91 | // Get loss-augmented argmax prediction 92 | val ystar = maxoracle(model, x_m, y_m) 93 | val shl = delta(y_m, ystar) - deltaF(lo, ystar, w) 94 | 95 | shl 96 | 97 | } 98 | 99 | // This should be empty 100 | val negShlSeq: Seq[Double] = shlSeq.filter(_ < -EPSILON) 101 | 102 | assert(negShlSeq.length == 0, 103 | "%d / %d cases failed".format(negShlSeq.length, shlSeq.length)) 104 | 105 | } 106 | 107 | /** 108 | * Tests - Discriminant function 109 | */ 110 | "F(x_m, y*)" should "be >= F(x_m, y_m)" in { 111 | 112 | val diffSeq = for (k <- 0 until NUM_WEIGHT_VECS) yield { 113 | // Set weight vector 114 | val w: WeightVector = weightVectors(k) 115 | model.setWeights(w) 116 | 117 | // Choose a random example 118 | val m = scala.util.Random.nextInt(M) 119 | val lo = data(m) 120 | val x_m = lo.pattern 121 | val y_m = lo.label 122 | 123 | // Get argmax prediction 124 | val ystar = predict(model, x_m) 125 | 126 | val F_ystar = F(x_m, ystar, w) 127 | val F_gt = F(x_m, y_m, w) 128 | 129 | val diff = F_ystar - F_gt 130 | 131 | diff 132 | } 133 | 134 | /*println(diffSeq)*/ 135 | 136 | // This should be empty 137 | val negDiffSeq: Seq[Double] = diffSeq.filter(_ < -EPSILON) 138 | 139 | assert(negDiffSeq.length == 0, 140 | "%d / %d cases failed".format(negDiffSeq.length, diffSeq.length)) 141 | 142 | } 143 | 144 | it should "be >= PERTURBED F(x_m, y_m)" in { 145 | 146 | val diffSeq = for (k <- 0 until NUM_WEIGHT_VECS) yield { 147 | // Set weight vector 148 | val w: WeightVector = weightVectors(k) 149 | model.setWeights(w) 150 | 151 | // Choose a random example 152 | val m = scala.util.Random.nextInt(M) 153 | val lo = data(m) 154 | val x_m = lo.pattern 155 | val y_m = 
perturb(lo.label, 0.1) 156 | 157 | // Get argmax prediction 158 | val ystar = predict(model, x_m) 159 | 160 | val F_ystar = F(x_m, ystar, w) 161 | val F_gt = F(x_m, y_m, w) 162 | 163 | F_ystar - F_gt 164 | } 165 | 166 | /*println(diffSeq)*/ 167 | 168 | // This should be empty 169 | val negDiffSeq: Seq[Double] = diffSeq.filter(_ < -EPSILON) 170 | 171 | assert(negDiffSeq.length == 0, 172 | "%d / %d cases failed".format(negDiffSeq.length, diffSeq.length)) 173 | 174 | } 175 | 176 | "H(w; x_m, y_m)" should "be >= Δ(y_m, y_m) + F(x_m, y_m)" in { 177 | 178 | val diffSeq = for (k <- 0 until NUM_WEIGHT_VECS) yield { 179 | // Set weight vector 180 | val w: WeightVector = weightVectors(k) 181 | model.setWeights(w) 182 | 183 | // Choose a random example 184 | val m = scala.util.Random.nextInt(M) 185 | val lo = data(m) 186 | val x_m = lo.pattern 187 | val y_m = lo.label 188 | 189 | // Get loss-augmented argmax prediction 190 | val ystar = maxoracle(model, x_m, y_m) 191 | 192 | val H = delta(y_m, ystar) - deltaF(lo, ystar, w) 193 | val F_loss_aug = delta(y_m, y_m) - deltaF(lo, y_m, w) 194 | 195 | H - F_loss_aug 196 | } 197 | 198 | /*println(diffSeq)*/ 199 | 200 | // This should be empty 201 | val negDiffSeq: Seq[Double] = diffSeq.filter(_ < -EPSILON) 202 | 203 | assert(negDiffSeq.length == 0, 204 | "%d / %d cases failed".format(negDiffSeq.length, diffSeq.length)) 205 | } 206 | 207 | it should "be >= PERTURBED Δ(y_m, y_m) + F(x_m, y_m)" in { 208 | 209 | val diffSeq = for (k <- 0 until NUM_WEIGHT_VECS) yield { 210 | // Set weight vector 211 | val w: WeightVector = weightVectors(k) 212 | model.setWeights(w) 213 | 214 | // Choose a random example 215 | val m = scala.util.Random.nextInt(M) 216 | val lo = data(m) 217 | val x_m = lo.pattern 218 | val y_m = perturb(lo.label, 0.1) 219 | 220 | // Get loss-augmented argmax prediction 221 | val ystar = maxoracle(model, x_m, y_m) 222 | 223 | val H = delta(y_m, ystar) - deltaF(lo, ystar, w) 224 | val F_loss_aug = delta(y_m, y_m) - deltaF(lo, y_m, w) 225 | 226 | H - F_loss_aug 227 | } 228 | 229 | /*println(diffSeq)*/ 230 | 231 | // This should be empty 232 | val negDiffSeq: Seq[Double] = diffSeq.filter(_ < -EPSILON) 233 | 234 | assert(negDiffSeq.length == 0, 235 | "%d / %d cases failed".format(negDiffSeq.length, diffSeq.length)) 236 | } 237 | 238 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/test/scala/ch/ethz/dalab/dissolve/diagnostics/StructLossSpec.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.diagnostics 2 | 3 | /** 4 | * @author torekond 5 | */ 6 | class StructLossSpec extends UnitSpec { 7 | 8 | val NUM_ATTEMPTS = 100 // = # times each test case is attempted 9 | 10 | // A sample datapoint 11 | val lo = data(0) 12 | // Size of joint feature map 13 | val d = phi(lo.pattern, lo.label).size 14 | // No. 
of data examples 15 | val M = data.length 16 | 17 | "Δ(y, y)" should "= 0" in { 18 | 19 | val lossSeq: Seq[Double] = for (k <- 0 until NUM_ATTEMPTS) yield { 20 | 21 | // Choose a random example 22 | val m = scala.util.Random.nextInt(M) 23 | val lo = data(m) 24 | val x_m = lo.pattern 25 | val y_m = lo.label 26 | 27 | delta(y_m, y_m) 28 | 29 | } 30 | 31 | // This should be empty 32 | val uneqLossSeq: Seq[Double] = lossSeq.filter(_ != 0.0) 33 | 34 | assert(uneqLossSeq.length == 0, 35 | "%d / %d cases failed".format(uneqLossSeq.length, lossSeq.length)) 36 | } 37 | 38 | "Δ(y, y')" should ">= 0" in { 39 | 40 | val lossSeq: Seq[Double] = for (k <- 0 until NUM_ATTEMPTS) yield { 41 | 42 | // Choose a random example 43 | val m = scala.util.Random.nextInt(M) 44 | val lo = data(m) 45 | val x_m = lo.pattern 46 | val degree = scala.util.Random.nextDouble() 47 | val y_m = perturb(lo.label, degree) 48 | 49 | delta(y_m, y_m) 50 | 51 | } 52 | 53 | // This should be empty 54 | val uneqLossSeq: Seq[Double] = lossSeq.filter(_ < 0.0) 55 | 56 | assert(uneqLossSeq.length == 0, 57 | "%d / %d cases failed".format(uneqLossSeq.length, lossSeq.length)) 58 | } 59 | 60 | } -------------------------------------------------------------------------------- /dissolve-struct-examples/src/test/scala/ch/ethz/dalab/dissolve/diagnostics/UnitSpec.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.diagnostics 2 | 3 | import java.nio.file.Paths 4 | import org.scalatest.FlatSpec 5 | import org.scalatest.Inside 6 | import org.scalatest.Inspectors 7 | import org.scalatest.Matchers 8 | import org.scalatest.OptionValues 9 | import breeze.linalg.DenseVector 10 | import breeze.linalg.Matrix 11 | import breeze.linalg.Vector 12 | import ch.ethz.dalab.dissolve.classification.StructSVMModel 13 | import ch.ethz.dalab.dissolve.examples.chain.ChainDemo 14 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSeg 15 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSegUtils 16 | import ch.ethz.dalab.dissolve.examples.imageseg.QuantizedImage 17 | import ch.ethz.dalab.dissolve.examples.imageseg.QuantizedLabel 18 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 19 | import ch.ethz.dalab.dissolve.regression.LabeledObject 20 | import breeze.linalg.max 21 | 22 | object ImageTestAdapter { 23 | type X = QuantizedImage 24 | type Y = QuantizedLabel 25 | 26 | /** 27 | * Dissolve Functions 28 | */ 29 | val dissolveFunctions: DissolveFunctions[X, Y] = ImageSeg 30 | /** 31 | * Some Data 32 | */ 33 | val data = { 34 | val dataDir = "../data/generated/msrc" 35 | val trainFilePath = Paths.get(dataDir, "Train.txt") 36 | val trainDataSeq = ImageSegUtils.loadData(dataDir, trainFilePath, limit = 50) 37 | 38 | trainDataSeq 39 | } 40 | /** 41 | * A dummy model 42 | */ 43 | val lo = data(0) 44 | val numd = ImageSeg.featureFn(lo.pattern, lo.label).size 45 | val model: StructSVMModel[X, Y] = 46 | new StructSVMModel[X, Y](DenseVector.zeros(numd), 0.0, 47 | DenseVector.zeros(numd), dissolveFunctions, 1) 48 | 49 | def perturb(y: Y, degree: Double = 0.1): Y = { 50 | val d = y.labels.size 51 | val numSwaps = max(1, (degree * d).toInt) 52 | 53 | for (swapNo <- 0 until numSwaps) { 54 | // Swap two random values in y 55 | val (i, j) = (scala.util.Random.nextInt(d), scala.util.Random.nextInt(d)) 56 | val temp = y.labels(i) 57 | y.labels(i) = y.labels(j) 58 | y.labels(j) = temp 59 | } 60 | 61 | y 62 | 63 | } 64 | } 65 | 66 | object ChainTestAdapter { 67 | type X = Matrix[Double] 68 | type Y = 
Vector[Double] 69 | 70 | /** 71 | * Dissolve Functions 72 | */ 73 | val dissolveFunctions: DissolveFunctions[X, Y] = ChainDemo 74 | /** 75 | * Some Data 76 | */ 77 | val data = { 78 | val dataDir = "../data/generated" 79 | val trainDataSeq: Vector[LabeledObject[Matrix[Double], Vector[Double]]] = 80 | ChainDemo.loadData(dataDir + "/patterns_train.csv", 81 | dataDir + "/labels_train.csv", 82 | dataDir + "/folds_train.csv") 83 | 84 | trainDataSeq.toArray 85 | } 86 | /** 87 | * A dummy model 88 | */ 89 | val lo = data(0) 90 | val numd = ChainDemo.featureFn(lo.pattern, lo.label).size 91 | val model: StructSVMModel[X, Y] = 92 | new StructSVMModel[X, Y](DenseVector.zeros(numd), 0.0, 93 | DenseVector.zeros(numd), dissolveFunctions, 1) 94 | 95 | /** 96 | * Perturb 97 | * Return a compatible perturbed Y 98 | * Higher the degree, more perturbed y is 99 | * 100 | * This function perturbs `degree` of the values by swapping 101 | */ 102 | def perturb(y: Y, degree: Double = 0.1): Y = { 103 | val d = y.size 104 | val numSwaps = max(1, (degree * d).toInt) 105 | 106 | for (swapNo <- 0 until numSwaps) { 107 | // Swap two random values in y 108 | val (i, j) = (scala.util.Random.nextInt(d), scala.util.Random.nextInt(d)) 109 | val temp = y(i) 110 | y(i) = y(j) 111 | y(j) = temp 112 | } 113 | 114 | y 115 | 116 | } 117 | } 118 | 119 | /** 120 | * @author torekond 121 | */ 122 | abstract class UnitSpec extends FlatSpec with Matchers with OptionValues with Inside with Inspectors { 123 | 124 | val DissolveAdapter = ImageTestAdapter 125 | 126 | type X = DissolveAdapter.X 127 | type Y = DissolveAdapter.Y 128 | 129 | val dissolveFunctions = DissolveAdapter.dissolveFunctions 130 | val data = DissolveAdapter.data 131 | val model = DissolveAdapter.model 132 | 133 | /** 134 | * Helper functions 135 | */ 136 | def perturb = DissolveAdapter.perturb _ 137 | 138 | // Joint Feature Map 139 | def phi = dissolveFunctions.featureFn _ 140 | def delta = dissolveFunctions.lossFn _ 141 | def maxoracle = dissolveFunctions.oracleFn _ 142 | def predict = dissolveFunctions.predictFn _ 143 | 144 | def psi(lo: LabeledObject[X, Y], ymap: Y) = 145 | phi(lo.pattern, lo.label) - phi(lo.pattern, ymap) 146 | 147 | def F(x: X, y: Y, w: Vector[Double]) = 148 | w dot phi(x, y) 149 | def deltaF(lo: LabeledObject[X, Y], ystar: Y, w: Vector[Double]) = 150 | w dot psi(lo, ystar) 151 | 152 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/.gitignore: -------------------------------------------------------------------------------- 1 | /bin/ 2 | -------------------------------------------------------------------------------- /dissolve-struct-lib/build.sbt: -------------------------------------------------------------------------------- 1 | name := "DissolveStruct" 2 | 3 | organization := "ch.ethz.dalab" 4 | 5 | version := "0.1-SNAPSHOT" 6 | 7 | scalaVersion := "2.10.4" 8 | 9 | libraryDependencies += "org.apache.spark" %% "spark-core" % "1.4.1" 10 | 11 | libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.4.1" 12 | 13 | libraryDependencies += "org.scalanlp" %% "breeze" % "0.11.1" 14 | 15 | libraryDependencies += "org.scalanlp" %% "breeze-natives" % "0.11.1" 16 | 17 | libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.0" % "test" 18 | 19 | libraryDependencies += "com.github.scopt" %% "scopt" % "3.3.0" 20 | 21 | resolvers += Resolver.sonatypeRepo("public") 22 | 23 | mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) => 24 | { 25 | case PathList("javax", "servlet", 
xs @ _*) => MergeStrategy.first 26 | case PathList(ps @ _*) if ps.last endsWith ".html" => MergeStrategy.first 27 | case "application.conf" => MergeStrategy.concat 28 | case "reference.conf" => MergeStrategy.concat 29 | case "log4j.properties" => MergeStrategy.discard 30 | case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard 31 | case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard 32 | case _ => MergeStrategy.first 33 | } 34 | } 35 | 36 | test in assembly := {} 37 | -------------------------------------------------------------------------------- /dissolve-struct-lib/lib/jython.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dalab/dissolve-struct/67f37377b74c32cf05d8f43a0e3658a10864f9bf/dissolve-struct-lib/lib/jython.jar -------------------------------------------------------------------------------- /dissolve-struct-lib/project/assembly.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0") 2 | -------------------------------------------------------------------------------- /dissolve-struct-lib/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "3.0.0") 2 | -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/BinarySVMWithDBCFW.scala: -------------------------------------------------------------------------------- 1 | 2 | /** 3 | * 4 | */ 5 | package ch.ethz.dalab.dissolve.classification 6 | 7 | import java.io.FileWriter 8 | import org.apache.spark.mllib.regression.LabeledPoint 9 | import org.apache.spark.rdd.RDD 10 | import breeze.linalg.DenseVector 11 | import breeze.linalg.SparseVector 12 | import breeze.linalg.Vector 13 | import ch.ethz.dalab.dissolve.optimization.DBCFWSolverTuned 14 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 15 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 16 | import ch.ethz.dalab.dissolve.optimization.SolverUtils 17 | import ch.ethz.dalab.dissolve.regression.LabeledObject 18 | import breeze.linalg.VectorBuilder 19 | import scala.collection.mutable.HashMap 20 | import org.apache.spark.rdd.PairRDDFunctions 21 | 22 | /** 23 | * @author tribhu 24 | * 25 | */ 26 | object BinarySVMWithDBCFW extends DissolveFunctions[Vector[Double], Double] { 27 | 28 | val labelToWeight = HashMap[Double, Double]() 29 | 30 | override def classWeights(label: Double): Double = { 31 | labelToWeight.get(label).getOrElse(3.0) 32 | } 33 | 34 | def generateClassWeights(data: RDD[LabeledPoint]): Unit = { 35 | val labels: Array[Double] = data.map { x => x.label }.distinct().collect() 36 | 37 | val classOccur: PairRDDFunctions[Double, Double] = data.map(x => (x.label, 1.0)) 38 | val labelOccur: PairRDDFunctions[Double, Double] = classOccur.reduceByKey((x, y) => x + y) 39 | val labelWeight: PairRDDFunctions[Double, Double] = labelOccur.mapValues { x => 1 / x } 40 | 41 | val weightSum: Double = labelWeight.values.sum() 42 | val nClasses: Int = 2 43 | val scaleValue: Double = nClasses / weightSum 44 | 45 | for ((label, weight) <- labelWeight.collectAsMap()) { 46 | labelToWeight.put(label, scaleValue * weight) 47 | } 48 | } 49 | 50 | /** 51 | * Feature function 52 | * 53 | * Analogous to phi(y) in (2) 54 | * Returns y_i * x_i 55 | * 56 | */ 57 | def 
featureFn(x: Vector[Double], y: Double): Vector[Double] = { 58 | x * y 59 | } 60 | 61 | /** 62 | * Loss function 63 | * 64 | * Returns 0 if yTruth == yPredict, 1 otherwise 65 | * Equivalent to max(0, 1 - y w^T x) 66 | */ 67 | def lossFn(yTruth: Double, yPredict: Double): Double = 68 | if (yTruth == yPredict) 69 | 0.0 70 | else 71 | 1.0 72 | 73 | /** 74 | * Maximization Oracle 75 | * 76 | * Want: max L(y_i, y) - 77 | * This returns the most violating (Loss-augmented) label. 78 | */ 79 | override def oracleFn(model: StructSVMModel[Vector[Double], Double], xi: Vector[Double], yi: Double): Double = { 80 | 81 | val weights = model.getWeights() 82 | 83 | var score_neg1 = weights dot featureFn(xi, -1.0) 84 | var score_pos1 = weights dot featureFn(xi, 1.0) 85 | 86 | // Loss augment the scores 87 | score_neg1 += 1.0 88 | score_pos1 += 1.0 89 | 90 | if (yi == -1.0) 91 | score_neg1 -= 1.0 92 | else if (yi == 1.0) 93 | score_pos1 -= 1.0 94 | else 95 | throw new IllegalArgumentException("yi not in [-1, 1], yi = " + yi) 96 | 97 | if (score_neg1 > score_pos1) 98 | -1.0 99 | else 100 | 1.0 101 | } 102 | 103 | /** 104 | * Prediction function 105 | */ 106 | def predictFn(model: StructSVMModel[Vector[Double], Double], xi: Vector[Double]): Double = { 107 | 108 | val weights = model.getWeights() 109 | 110 | val score_neg1 = weights dot featureFn(xi, -1.0) 111 | val score_pos1 = weights dot featureFn(xi, 1.0) 112 | 113 | if (score_neg1 > score_pos1) 114 | -1.0 115 | else 116 | +1.0 117 | 118 | } 119 | 120 | /** 121 | * Classifying with in-built functions 122 | */ 123 | def train( 124 | data: RDD[LabeledPoint], 125 | solverOptions: SolverOptions[Vector[Double], Double]): StructSVMModel[Vector[Double], Double] = { 126 | 127 | train(data, this, solverOptions) 128 | 129 | } 130 | 131 | /** 132 | * Classifying with user-submitted functions 133 | */ 134 | def train( 135 | data: RDD[LabeledPoint], 136 | dissolveFunctions: DissolveFunctions[Vector[Double], Double], 137 | solverOptions: SolverOptions[Vector[Double], Double]): StructSVMModel[Vector[Double], Double] = { 138 | 139 | if (solverOptions.classWeights) { 140 | generateClassWeights(data) 141 | } 142 | 143 | // Convert the RDD[LabeledPoint] to RDD[LabeledObject] 144 | val objectifiedData: RDD[LabeledObject[Vector[Double], Double]] = 145 | data.map { 146 | case x: LabeledPoint => 147 | new LabeledObject[Vector[Double], Double](x.label, 148 | if (solverOptions.sparse) { 149 | val features: Vector[Double] = x.features match { 150 | case features: org.apache.spark.mllib.linalg.SparseVector => 151 | val builder: VectorBuilder[Double] = new VectorBuilder(features.indices, features.values, features.indices.length, x.features.size) 152 | builder.toSparseVector 153 | case _ => SparseVector(x.features.toArray) 154 | } 155 | features 156 | } else 157 | DenseVector(x.features.toArray)) 158 | } 159 | 160 | val repartData = 161 | if (solverOptions.enableManualPartitionSize) 162 | objectifiedData.repartition(solverOptions.NUM_PART) 163 | else 164 | objectifiedData 165 | 166 | println("Running BinarySVMWithDBCFW solver") 167 | println(solverOptions) 168 | 169 | val (trainedModel, debugInfo) = new DBCFWSolverTuned[Vector[Double], Double]( 170 | repartData, 171 | dissolveFunctions, 172 | solverOptions, 173 | miniBatchEnabled = false).optimize() 174 | 175 | // Dump debug information into a file 176 | val fw = new FileWriter(solverOptions.debugInfoPath) 177 | // Write the current parameters being used 178 | fw.write(solverOptions.toString()) 179 | fw.write("\n") 180 | 181 | // Write 
spark-specific parameters 182 | fw.write(SolverUtils.getSparkConfString(data.context.getConf)) 183 | fw.write("\n") 184 | 185 | // Write values noted from the run 186 | fw.write(debugInfo) 187 | fw.close() 188 | 189 | println(debugInfo) 190 | 191 | trainedModel 192 | 193 | } 194 | 195 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/BinarySVMWithSSG.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.classification 2 | 3 | import java.io.FileWriter 4 | 5 | import org.apache.spark.mllib.regression.LabeledPoint 6 | import org.apache.spark.rdd.RDD 7 | 8 | import breeze.linalg.DenseVector 9 | import breeze.linalg.SparseVector 10 | import breeze.linalg.Vector 11 | import ch.ethz.dalab.dissolve.optimization.SSGSolver 12 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 13 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 14 | import ch.ethz.dalab.dissolve.optimization.SolverUtils 15 | import ch.ethz.dalab.dissolve.regression.LabeledObject 16 | 17 | /** 18 | * @author thijs 19 | * Adapted from BinarySVMWithDBCFW 20 | * 21 | */ 22 | object BinarySVMWithSSG extends DissolveFunctions[Vector[Double], Double] { 23 | 24 | /** 25 | * Feature function 26 | * 27 | * Analogous to phi(y) in (2) 28 | * Returns y_i * x_i 29 | * 30 | */ 31 | def featureFn(x: Vector[Double], y: Double): Vector[Double] = { 32 | x * y 33 | } 34 | 35 | /** 36 | * Loss function 37 | * 38 | * Returns 0 if yTruth == yPredict, 1 otherwise 39 | * Equivalent to max(0, 1 - y w^T x) 40 | */ 41 | def lossFn(yTruth: Double, yPredict: Double): Double = 42 | if (yTruth == yPredict) 43 | 0.0 44 | else 45 | 1.0 46 | 47 | /** 48 | * Maximization Oracle 49 | * 50 | * Want: max L(y_i, y) - 51 | * This returns the most violating (Loss-augmented) label. 
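   *
   * In other words, with the 0/1 loss and the feature map phi(x, y) = y * x used
   * in this object, the loss-augmented decoding below amounts to
   *
   *   y* = argmax over y in {-1, +1} of [ Delta(y_i, y) + y * (w dot x_i) ]
   *
   * both candidate scores receive the +1 loss bonus and the ground-truth label
   * has it removed again, which is what the score_neg1 / score_pos1 updates in
   * the body compute.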
52 | */ 53 | override def oracleFn(model: StructSVMModel[Vector[Double], Double], xi: Vector[Double], yi: Double): Double = { 54 | 55 | val weights = model.getWeights() 56 | 57 | var score_neg1 = weights dot featureFn(xi, -1.0) 58 | var score_pos1 = weights dot featureFn(xi, 1.0) 59 | 60 | // Loss augment the scores 61 | score_neg1 += 1.0 62 | score_pos1 += 1.0 63 | 64 | if (yi == -1.0) 65 | score_neg1 -= 1.0 66 | else if (yi == 1.0) 67 | score_pos1 -= 1.0 68 | else 69 | throw new IllegalArgumentException("yi not in [-1, 1], yi = " + yi) 70 | 71 | if (score_neg1 > score_pos1) 72 | -1.0 73 | else 74 | 1.0 75 | } 76 | 77 | /** 78 | * Prediction function 79 | */ 80 | def predictFn(model: StructSVMModel[Vector[Double], Double], xi: Vector[Double]): Double = { 81 | 82 | val weights = model.getWeights() 83 | 84 | val score_neg1 = weights dot featureFn(xi, -1.0) 85 | val score_pos1 = weights dot featureFn(xi, 1.0) 86 | 87 | if (score_neg1 > score_pos1) 88 | -1.0 89 | else 90 | +1.0 91 | 92 | } 93 | 94 | /** 95 | * Classifying with in-built functions 96 | */ 97 | def train( 98 | data: Seq[LabeledPoint], 99 | solverOptions: SolverOptions[Vector[Double], Double]): StructSVMModel[Vector[Double], Double] = { 100 | 101 | train(data,this,solverOptions) 102 | 103 | } 104 | 105 | /** 106 | * Classifying with user-submitted functions 107 | */ 108 | def train( 109 | data: Seq[LabeledPoint], 110 | dissolveFunctions: DissolveFunctions[Vector[Double], Double], 111 | solverOptions: SolverOptions[Vector[Double], Double]): StructSVMModel[Vector[Double], Double] = { 112 | 113 | // Convert the RDD[LabeledPoint] to RDD[LabeledObject] 114 | val objectifiedData: Seq[LabeledObject[Vector[Double], Double]] = 115 | data.map { 116 | case x: LabeledPoint => 117 | new LabeledObject[Vector[Double], Double](x.label, 118 | if (solverOptions.sparse) 119 | SparseVector(x.features.toArray) 120 | else 121 | DenseVector(x.features.toArray)) 122 | } 123 | 124 | println("Running BinarySVMWithSSGsolver") 125 | println(solverOptions) 126 | 127 | val trainedModel = new SSGSolver[Vector[Double], Double]( 128 | objectifiedData, 129 | dissolveFunctions, 130 | solverOptions).optimize() 131 | 132 | // Dump debug information into a file 133 | val fw = new FileWriter(solverOptions.debugInfoPath) 134 | // Write the current parameters being used 135 | fw.write(solverOptions.toString()) 136 | fw.write("\n") 137 | 138 | 139 | // Write values noted from the run 140 | fw.close() 141 | 142 | 143 | trainedModel 144 | 145 | } 146 | 147 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/ClassificationUtils.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.classification 2 | 3 | import org.apache.spark.rdd.RDD 4 | import org.apache.spark.rdd.PairRDDFunctions 5 | import ch.ethz.dalab.dissolve.regression.LabeledObject 6 | import scala.collection.mutable.HashMap 7 | import ch.ethz.dalab.dissolve.regression.LabeledObject 8 | import scala.reflect.ClassTag 9 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 10 | import ch.ethz.dalab.dissolve.regression.LabeledObject 11 | import scala.collection.mutable.MutableList 12 | import ch.ethz.dalab.dissolve.regression.LabeledObject 13 | import org.apache.spark.mllib.regression.LabeledPoint 14 | 15 | object ClassificationUtils { 16 | 17 | /** 18 | * Generates class weights. If classWeights is flase the default value of 1.0 is used. 
Otherwise if the user submitted a custom weight array 19 | * the weights in there will be used. If the user did not submit a custom array of weights the inverse class freq. will be used 20 | */ 21 | def generateClassWeights[X, Y: ClassTag](data: RDD[LabeledObject[X, Y]], classWeights: Boolean = true, customWeights: Option[HashMap[Y,Double]] = None): HashMap[Y, Double] = { 22 | val map = HashMap[Y, Double]() 23 | val labels: Array[Y] = data.map { x: LabeledObject[X, Y] => x.label }.distinct().collect() 24 | if (classWeights) { 25 | if (customWeights.getOrElse(null) == null) { 26 | //inverse class frequency as weight 27 | val classOccur: PairRDDFunctions[Y, Double] = data.map(x => (x.label, 1.0)) 28 | val labelOccur: PairRDDFunctions[Y, Double] = classOccur.reduceByKey((x, y) => x + y) 29 | val labelWeight: PairRDDFunctions[Y, Double] = labelOccur.mapValues { x => 1 / x } 30 | 31 | val weightSum: Double = labelWeight.values.sum() 32 | val nClasses: Int = labels.length 33 | val scaleValue: Double = nClasses / weightSum 34 | 35 | var sum: Double = 0.0 36 | for ((label, weight) <- labelWeight.collectAsMap()) { 37 | val clWeight = scaleValue * weight 38 | sum += clWeight 39 | map.put(label, clWeight) 40 | } 41 | 42 | assert(sum == nClasses) 43 | } else { 44 | //use custom weights 45 | assert(labels.length == customWeights.get.size) 46 | for (label <- labels) { 47 | map.put(label, customWeights.get(label)) 48 | } 49 | } 50 | } else { 51 | // default weight of 1.0 52 | for (label <- labels) { 53 | map.put(label, 1.0) 54 | } 55 | } 56 | map 57 | } 58 | 59 | def resample[X,Y:ClassTag](data: RDD[LabeledObject[X, Y]],nSamples:HashMap[Y,Int],nSlices:Int): RDD[LabeledObject[X, Y]] = { 60 | val buckets: HashMap[Y, RDD[LabeledObject[X, Y]]] = HashMap() 61 | val newData = MutableList[LabeledObject[X, Y]]() 62 | 63 | val labels: Array[Y] = data.map { x => x.label }.distinct().collect() 64 | 65 | labels.foreach { x => buckets.put(x, data.filter { point => point.label == x }) } 66 | 67 | for (cls <- buckets.keySet) { 68 | val sampledData = buckets.get(cls).get.takeSample(true, nSamples.get(cls).get) 69 | for (x: LabeledObject[X, Y] <- sampledData) { 70 | newData.+=(x) 71 | } 72 | } 73 | data.context.parallelize(newData, nSlices) 74 | } 75 | 76 | def resample(data: RDD[LabeledPoint],nSamples:HashMap[Double,Int],nSlices:Int): RDD[LabeledPoint] = { 77 | val buckets: HashMap[Double, RDD[LabeledPoint]] = HashMap() 78 | val newData = MutableList[LabeledPoint]() 79 | 80 | val labels: Array[Double] = data.map { x => x.label }.distinct().collect() 81 | 82 | labels.foreach { x => buckets.put(x, data.filter { point => point.label == x }) } 83 | 84 | for (cls <- buckets.keySet) { 85 | val sampledData = buckets.get(cls).get.takeSample(true, nSamples.get(cls).get) 86 | for (x: LabeledPoint <- sampledData) { 87 | newData.+=(x) 88 | } 89 | } 90 | data.context.parallelize(newData, nSlices) 91 | } 92 | 93 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/MultiClassSVMWithDBCFW.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.classification 2 | 3 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 4 | import org.apache.spark.rdd.RDD 5 | import java.io.FileWriter 6 | import ch.ethz.dalab.dissolve.regression.LabeledObject 7 | import org.apache.spark.mllib.regression.LabeledPoint 8 | import breeze.linalg._ 9 | import 
ch.ethz.dalab.dissolve.optimization.SolverUtils 10 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 11 | import ch.ethz.dalab.dissolve.optimization.DBCFWSolverTuned 12 | import scala.collection.mutable.HashMap 13 | import org.apache.spark.rdd.PairRDDFunctions 14 | import ch.ethz.dalab.dissolve.optimization.SSGSolver 15 | import ch.ethz.dalab.dissolve.optimization.SSGSolver 16 | import ch.ethz.dalab.dissolve.optimization.UseDBCFWSolver 17 | import ch.ethz.dalab.dissolve.optimization.UseSSGSolver 18 | 19 | case class MultiClassLabel(label: Double, numClasses: Int) 20 | 21 | object MultiClassSVMWithDBCFW extends DissolveFunctions[Vector[Double], MultiClassLabel] { 22 | 23 | var labelToWeight = HashMap[MultiClassLabel, Double]() 24 | 25 | 26 | override def classWeights(label: MultiClassLabel): Double = { 27 | labelToWeight.get(label).getOrElse(1.0) 28 | } 29 | 30 | /** 31 | * Feature function 32 | * 33 | * Analogous to phi(y) in (2) 34 | * Returns y_i * x_i 35 | * 36 | */ 37 | def featureFn(x: Vector[Double], y: MultiClassLabel): Vector[Double] = { 38 | assert(y.label.toInt < y.numClasses, 39 | "numClasses = %d. Found y_i.label = %d" 40 | .format(y.numClasses, y.label.toInt)) 41 | 42 | val featureVector = Vector.zeros[Double](x.size * y.numClasses) 43 | val numDims = x.size 44 | 45 | // Populate the featureVector in blocks [ ...]. 46 | val startIdx = y.label.toInt * numDims 47 | val endIdx = startIdx + numDims 48 | 49 | featureVector(startIdx until endIdx) := x 50 | 51 | featureVector 52 | } 53 | 54 | /** 55 | * Loss function 56 | * 57 | * Returns 0 if yTruth == yPredict, 1 otherwise 58 | * Equivalent to max(0, 1 - y w^T x) 59 | */ 60 | def lossFn(yTruth: MultiClassLabel, yPredict: MultiClassLabel): Double = 61 | if (yTruth.label == yPredict.label) 62 | 0.0 63 | else 64 | 1.0 65 | 66 | /** 67 | * Maximization Oracle 68 | * 69 | * Want: argmax L(y_i, y) - 70 | * This returns the most violating (Loss-augmented) label. 
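   *
   * With the block feature map defined above (x copied into the block of the
   * candidate class, zeros elsewhere), the loss-augmented decoding below is
   *
   *   y* = argmax over c in {0, ..., numClasses - 1} of
   *          [ Delta(y_i, c) + (w dot featureFn(x_i, c)) ]
   *
   * where Delta is the 0/1 loss, so every class except the ground truth gets a
   * +1 bonus before the maximum is taken.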
71 | */ 72 | override def oracleFn(model: StructSVMModel[Vector[Double], MultiClassLabel], xi: Vector[Double], yi: MultiClassLabel): MultiClassLabel = { 73 | 74 | val weights = model.getWeights() 75 | val numClasses = yi.numClasses 76 | 77 | // Obtain a list of scores for each class 78 | val mostViolatedContraint: (Double, Double) = 79 | (0 until numClasses).map { 80 | case cl => 81 | (cl, weights dot featureFn(xi, MultiClassLabel(cl, numClasses))) 82 | }.map { 83 | case (cl, score) => 84 | (cl.toDouble, score + 1.0) 85 | }.map { // Loss-augment the scores 86 | case (cl, score) => 87 | if (yi.label == cl) 88 | (cl, score - 1.0) 89 | else 90 | (cl, score) 91 | }.maxBy { // Obtain the class with the maximum value 92 | case (cl, score) => score 93 | } 94 | 95 | MultiClassLabel(mostViolatedContraint._1, numClasses) 96 | } 97 | 98 | /** 99 | * Prediction function 100 | */ 101 | def predictFn(model: StructSVMModel[Vector[Double], MultiClassLabel], xi: Vector[Double]): MultiClassLabel = { 102 | 103 | val weights = model.getWeights() 104 | val numClasses = model.numClasses 105 | 106 | assert(numClasses > 1) 107 | 108 | val prediction = 109 | (0 until numClasses).map { 110 | case cl => 111 | (cl.toDouble, weights dot featureFn(xi, MultiClassLabel(cl, numClasses))) 112 | }.maxBy { // Obtain the class with the maximum value 113 | case (cl, score) => score 114 | } 115 | 116 | MultiClassLabel(prediction._1, numClasses) 117 | 118 | } 119 | 120 | /** 121 | * Classifying with in-built functions 122 | * 123 | * data needs to be 0-indexed 124 | */ 125 | def train( 126 | data: RDD[LabeledPoint], 127 | numClasses: Int, 128 | solverOptions: SolverOptions[Vector[Double], MultiClassLabel], 129 | customWeights:Option[HashMap[MultiClassLabel,Double]]=None): StructSVMModel[Vector[Double], MultiClassLabel] = { 130 | 131 | solverOptions.numClasses = numClasses 132 | 133 | // Convert the RDD[LabeledPoint] to RDD[LabeledObject] 134 | val objectifiedData: RDD[LabeledObject[Vector[Double], MultiClassLabel]] = 135 | data.map { 136 | case x: LabeledPoint => 137 | val features: Vector[Double] = x.features match { 138 | case features: org.apache.spark.mllib.linalg.SparseVector => 139 | val builder: VectorBuilder[Double] = new VectorBuilder(features.indices, features.values, features.indices.length, x.features.size) 140 | builder.toSparseVector 141 | case _ => SparseVector(x.features.toArray) 142 | } 143 | new LabeledObject[Vector[Double], MultiClassLabel](MultiClassLabel(x.label, numClasses), features) 144 | } 145 | 146 | labelToWeight = ClassificationUtils.generateClassWeights(objectifiedData,solverOptions.classWeights,customWeights) 147 | 148 | val repartData = 149 | if (solverOptions.enableManualPartitionSize) 150 | objectifiedData.repartition(solverOptions.NUM_PART) 151 | else 152 | objectifiedData 153 | 154 | println(solverOptions) 155 | 156 | 157 | val (trainedModel,debugInfo) = solverOptions.solver match { 158 | case UseDBCFWSolver => new DBCFWSolverTuned[Vector[Double], MultiClassLabel]( 159 | repartData, 160 | this, 161 | solverOptions, 162 | miniBatchEnabled = false).optimize() 163 | case UseSSGSolver => (new SSGSolver[Vector[Double], MultiClassLabel]( 164 | repartData.collect(), 165 | this, 166 | solverOptions 167 | ).optimize(),"") 168 | } 169 | 170 | println(debugInfo) 171 | 172 | // Dump debug information into a file 173 | val fw = new FileWriter(solverOptions.debugInfoPath) 174 | // Write the current parameters being used 175 | fw.write(solverOptions.toString()) 176 | fw.write("\n") 177 | 178 | // Write 
spark-specific parameters 179 | fw.write(SolverUtils.getSparkConfString(data.context.getConf)) 180 | fw.write("\n") 181 | 182 | // Write values noted from the run 183 | fw.write(debugInfo) 184 | fw.close() 185 | 186 | trainedModel 187 | 188 | } 189 | 190 | /** 191 | * Classifying with user-submitted functions 192 | */ 193 | def train( 194 | data: RDD[LabeledPoint], 195 | dissolveFunctions: DissolveFunctions[Vector[Double], MultiClassLabel], 196 | solverOptions: SolverOptions[Vector[Double], MultiClassLabel]): StructSVMModel[Vector[Double], MultiClassLabel] = { 197 | 198 | val numClasses = solverOptions.numClasses 199 | assert(numClasses > 1) 200 | 201 | val minlabel = data.map(_.label).min() 202 | val maxlabel = data.map(_.label).max() 203 | assert(minlabel == 0, "Label classes need to be 0-indexed") 204 | assert(maxlabel - minlabel + 1 == numClasses, 205 | "Number of classes in data do not tally with passed argument") 206 | 207 | // Convert the RDD[LabeledPoint] to RDD[LabeledObject] 208 | val objectifiedData: RDD[LabeledObject[Vector[Double], MultiClassLabel]] = 209 | data.map { 210 | case x: LabeledPoint => 211 | new LabeledObject[Vector[Double], MultiClassLabel](MultiClassLabel(x.label, numClasses), 212 | if (solverOptions.sparse) { 213 | val features: Vector[Double] = x.features match { 214 | case features: org.apache.spark.mllib.linalg.SparseVector => 215 | val builder: VectorBuilder[Double] = new VectorBuilder(features.indices, features.values, features.indices.length, x.features.size) 216 | builder.toSparseVector 217 | case _ => SparseVector(x.features.toArray) 218 | } 219 | features 220 | } else 221 | Vector(x.features.toArray)) 222 | } 223 | 224 | val repartData = 225 | if (solverOptions.enableManualPartitionSize) 226 | objectifiedData.repartition(solverOptions.NUM_PART) 227 | else 228 | objectifiedData 229 | 230 | println(solverOptions) 231 | 232 | //choose optimizer 233 | val (trainedModel,debugInfo) = solverOptions.solver match { 234 | case UseDBCFWSolver => new DBCFWSolverTuned[Vector[Double], MultiClassLabel]( 235 | repartData, 236 | dissolveFunctions, 237 | solverOptions, 238 | miniBatchEnabled = false).optimize() 239 | case UseSSGSolver => (new SSGSolver[Vector[Double], MultiClassLabel]( 240 | repartData.collect(), 241 | dissolveFunctions, 242 | solverOptions 243 | ).optimize(),"") 244 | } 245 | 246 | // Dump debug information into a file 247 | val fw = new FileWriter(solverOptions.debugInfoPath) 248 | // Write the current parameters being used 249 | fw.write(solverOptions.toString()) 250 | fw.write("\n") 251 | 252 | // Write spark-specific parameters 253 | fw.write(SolverUtils.getSparkConfString(data.context.getConf)) 254 | fw.write("\n") 255 | 256 | // Write values noted from the run 257 | fw.write(debugInfo) 258 | fw.close() 259 | 260 | println(debugInfo) 261 | 262 | trainedModel 263 | 264 | } 265 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMModel.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * 3 | */ 4 | package ch.ethz.dalab.dissolve.classification 5 | 6 | import breeze.linalg._ 7 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 8 | 9 | /** 10 | * This is analogous to the model: 11 | * a) returned by Andrea Vedaldi's svm-struct-matlab 12 | * b) used by BCFWStruct. 
See (Lacoste-Julien, Jaggi, Schmidt, Pletscher; ICML 2013) 13 | * 14 | * @constructor Create a new StructSVM model 15 | * @param weights Primal variable. Corresponds to w in Algorithm 4 16 | * @param ell Corresponds to l in Algorithm 4 17 | * @param ellMat Corresponds to l_i in Algorithm 4 18 | * @param pred Prediction function 19 | */ 20 | class StructSVMModel[X, Y]( 21 | var weights: Vector[Double], 22 | var ell: Double, 23 | val ellMat: Vector[Double], 24 | val dissolveFunctions: DissolveFunctions[X, Y], 25 | val numClasses: Int) extends Serializable { 26 | 27 | def this( 28 | weights: Vector[Double], 29 | ell: Double, 30 | ellMat: Vector[Double], 31 | dissolveFunctions: DissolveFunctions[X, Y]) = 32 | this(weights, ell, ellMat, dissolveFunctions, -1) 33 | 34 | 35 | def getWeights(): Vector[Double] = { 36 | weights 37 | } 38 | 39 | def setWeights(newWeights: Vector[Double]) = { 40 | weights = newWeights 41 | } 42 | 43 | def getEll(): Double = 44 | ell 45 | 46 | def setEll(newEll: Double) = 47 | ell = newEll 48 | 49 | def predict(pattern: X): Y = { 50 | dissolveFunctions.predictFn(this, pattern) 51 | } 52 | 53 | override def clone(): StructSVMModel[X, Y] = { 54 | new StructSVMModel(this.weights.copy, 55 | ell, 56 | this.ellMat.copy, 57 | dissolveFunctions, 58 | numClasses) 59 | } 60 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMWithBCFW.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * 3 | */ 4 | package ch.ethz.dalab.dissolve.classification 5 | 6 | import java.io.FileWriter 7 | 8 | import scala.reflect.ClassTag 9 | 10 | import ch.ethz.dalab.dissolve.optimization.BCFWSolver 11 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 12 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 13 | import ch.ethz.dalab.dissolve.regression.LabeledObject 14 | 15 | /** 16 | * Analogous to BCFWSolver 17 | * 18 | * 19 | */ 20 | class StructSVMWithBCFW[X, Y]( 21 | val data: Seq[LabeledObject[X, Y]], 22 | val dissolveFunctions: DissolveFunctions[X, Y], 23 | val solverOptions: SolverOptions[X, Y]) { 24 | 25 | def trainModel()(implicit m: ClassTag[Y]): StructSVMModel[X, Y] = { 26 | val (trainedModel, debugInfo) = new BCFWSolver(data, 27 | dissolveFunctions, 28 | solverOptions).optimize() 29 | 30 | // Dump debug information into a file 31 | val fw = new FileWriter(solverOptions.debugInfoPath) 32 | // Write the current parameters being used 33 | fw.write("# BCFW\n") 34 | fw.write(solverOptions.toString()) 35 | fw.write("\n") 36 | 37 | // Write values noted from the run 38 | fw.write(debugInfo) 39 | fw.close() 40 | 41 | trainedModel 42 | } 43 | 44 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMWithDBCFW.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.classification 2 | 3 | import ch.ethz.dalab.dissolve.regression.LabeledObject 4 | import breeze.linalg._ 5 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.rdd.RDD 8 | import java.io.FileWriter 9 | import ch.ethz.dalab.dissolve.optimization.SolverUtils 10 | import scala.reflect.ClassTag 11 | import ch.ethz.dalab.dissolve.optimization.DBCFWSolverTuned 12 | import 
ch.ethz.dalab.dissolve.optimization.DissolveFunctions 13 | 14 | class StructSVMWithDBCFW[X, Y]( 15 | val data: RDD[LabeledObject[X, Y]], 16 | val dissolveFunctions: DissolveFunctions[X, Y], 17 | val solverOptions: SolverOptions[X, Y]) { 18 | 19 | def trainModel()(implicit m: ClassTag[Y]): StructSVMModel[X, Y] = { 20 | val (trainedModel, debugInfo) = new DBCFWSolverTuned[X, Y]( 21 | data, 22 | dissolveFunctions, 23 | solverOptions, 24 | miniBatchEnabled = false).optimize() 25 | 26 | // Dump debug information into a file 27 | val fw = new FileWriter(solverOptions.debugInfoPath) 28 | // Write the current parameters being used 29 | fw.write(solverOptions.toString()) 30 | fw.write("\n") 31 | 32 | // Write spark-specific parameters 33 | fw.write(SolverUtils.getSparkConfString(data.context.getConf)) 34 | fw.write("\n") 35 | 36 | // Write values noted from the run 37 | fw.write(debugInfo) 38 | fw.close() 39 | 40 | print(debugInfo) 41 | 42 | // Return the trained model 43 | trainedModel 44 | } 45 | } 46 | 47 | object StructSVMWithDBCFW { 48 | def train[X, Y](data: RDD[LabeledObject[X, Y]], 49 | dissolveFunctions: DissolveFunctions[X, Y], 50 | solverOptions: SolverOptions[X, Y])(implicit m: ClassTag[Y]): StructSVMModel[X, Y] = { 51 | val (trainedModel, debugInfo) = new DBCFWSolverTuned[X, Y]( 52 | data, 53 | dissolveFunctions, 54 | solverOptions, 55 | miniBatchEnabled = false).optimize() 56 | 57 | // Dump debug information into a file 58 | val fw = new FileWriter(solverOptions.debugInfoPath) 59 | // Write the current parameters being used 60 | fw.write(solverOptions.toString()) 61 | fw.write("\n") 62 | 63 | // Write spark-specific parameters 64 | fw.write(SolverUtils.getSparkConfString(data.context.getConf)) 65 | fw.write("\n") 66 | 67 | // Write values noted from the run 68 | fw.write(debugInfo) 69 | fw.close() 70 | 71 | print(debugInfo) 72 | 73 | // Return the trained model 74 | trainedModel 75 | 76 | } 77 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMWithMiniBatch.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.classification 2 | 3 | import scala.reflect.ClassTag 4 | 5 | import org.apache.spark.rdd.RDD 6 | 7 | import ch.ethz.dalab.dissolve.optimization.DBCFWSolverTuned 8 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 9 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 10 | import ch.ethz.dalab.dissolve.regression.LabeledObject 11 | 12 | class StructSVMWithMiniBatch[X, Y]( 13 | val data: RDD[LabeledObject[X, Y]], 14 | val dissolveFunctions: DissolveFunctions[X, Y], 15 | val solverOptions: SolverOptions[X, Y]) { 16 | 17 | def trainModel()(implicit m: ClassTag[Y]): StructSVMModel[X, Y] = 18 | new DBCFWSolverTuned( 19 | data, 20 | dissolveFunctions, 21 | solverOptions, 22 | miniBatchEnabled = true).optimize()._1 23 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMWithSSG.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * 3 | */ 4 | package ch.ethz.dalab.dissolve.classification 5 | 6 | import scala.reflect.ClassTag 7 | 8 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions 9 | import ch.ethz.dalab.dissolve.optimization.SSGSolver 10 | import ch.ethz.dalab.dissolve.optimization.SolverOptions 11 | import 
ch.ethz.dalab.dissolve.regression.LabeledObject 12 | 13 | /** 14 | * 15 | */ 16 | class StructSVMWithSSG[X, Y]( 17 | val data: Seq[LabeledObject[X, Y]], 18 | val dissolveFunctions: DissolveFunctions[X, Y], 19 | val solverOptions: SolverOptions[X, Y]) { 20 | 21 | def trainModel()(implicit m: ClassTag[Y]): StructSVMModel[X, Y] = 22 | new SSGSolver(data, 23 | dissolveFunctions, 24 | solverOptions).optimize() 25 | 26 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/Types.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.classification 2 | 3 | import breeze.linalg.{Vector, DenseVector} 4 | import scala.collection.mutable.MutableList 5 | 6 | object Types { 7 | 8 | type Index = Int 9 | type Level = Int 10 | type PrimalInfo = Tuple2[Vector[Double], Double] 11 | type BoundedCacheList[Y] = MutableList[Y] 12 | 13 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/optimization/DissolveFunctions.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.optimization 2 | 3 | import breeze.linalg.Vector 4 | import ch.ethz.dalab.dissolve.classification.StructSVMModel 5 | 6 | trait DissolveFunctions[X, Y] extends Serializable { 7 | 8 | def featureFn(x: X, y: Y): Vector[Double] 9 | 10 | def lossFn(yPredicted: Y, yTruth: Y): Double 11 | 12 | // Override either `oracleFn` or `oracleCandidateStream` 13 | def oracleFn(model: StructSVMModel[X, Y], x: X, y: Y): Y = 14 | oracleCandidateStream(model, x, y).head 15 | 16 | def oracleCandidateStream(model: StructSVMModel[X, Y], x: X, y: Y, initLevel: Int = 0): Stream[Y] = 17 | oracleFn(model, x, y) #:: Stream.empty 18 | 19 | def predictFn(model: StructSVMModel[X, Y], x: X): Y 20 | 21 | def classWeights(y:Y): Double = 1.0 22 | 23 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/optimization/SSGSolver.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.optimization 2 | 3 | import java.io.File 4 | import java.io.PrintWriter 5 | 6 | import breeze.linalg._ 7 | import breeze.linalg.DenseVector 8 | import breeze.linalg.Vector 9 | import breeze.linalg.csvwrite 10 | import breeze.numerics._ 11 | import ch.ethz.dalab.dissolve.classification.StructSVMModel 12 | import ch.ethz.dalab.dissolve.regression.LabeledObject 13 | 14 | /** 15 | * Train a structured SVM using standard Stochastic (Sub)Gradient Descent (SGD). 16 | * The implementation here is single machine, not distributed. 
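 *
 * Each inner step of optimize() below performs the subgradient update
 *
 *   w_{k+1} = (1 - gamma_k) * w_k + (gamma_k / lambda) * psi_i ,
 *   gamma_k = 1 / (gamma0 * (k + 1)) ,
 *
 * where psi_i = [ phi(x_i, y_i) - phi(x_i, y*_i) ] * classWeight(y_i) and y*_i
 * is the label returned by the loss-augmented oracle for the sampled example.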
17 | * 18 | * Input: 19 | * Each data point (x_i, y_i) is composed of: 20 | * x_i, the data example 21 | * y_i, the label 22 | * 23 | * @param type for the data examples 24 | * @param type for the labels of each example 25 | */ 26 | class SSGSolver[X, Y]( 27 | val data: Seq[LabeledObject[X, Y]], 28 | val dissolveFunctions: DissolveFunctions[X, Y], 29 | val solverOptions: SolverOptions[X, Y]) { 30 | 31 | val roundLimit = solverOptions.roundLimit 32 | val lambda = solverOptions.lambda 33 | val debugOn: Boolean = solverOptions.debug 34 | val gamma0 = solverOptions.ssg_gamma0 35 | 36 | val maxOracle = dissolveFunctions.oracleFn _ 37 | val phi = dissolveFunctions.featureFn _ 38 | val lossFn = dissolveFunctions.lossFn _ 39 | val cWeight = dissolveFunctions.classWeights _ 40 | // Number of dimensions of \phi(x, y) 41 | val ndims: Int = phi(data(0).pattern, data(0).label).size 42 | 43 | // Filenames 44 | val lossWriterFileName = "data/debug/ssg-loss.csv" 45 | 46 | /** 47 | * SSG optimizer 48 | */ 49 | def optimize(): StructSVMModel[X, Y] = { 50 | 51 | var k: Integer = 0 52 | val n: Int = data.length 53 | val d: Int = phi(data(0).pattern, data(0).label).size 54 | // Use first example to determine dimension of w 55 | val model: StructSVMModel[X, Y] = new StructSVMModel(DenseVector.zeros(phi(data(0).pattern, data(0).label).size), 56 | 0.0, 57 | DenseVector.zeros(ndims), 58 | dissolveFunctions) 59 | 60 | // Initialization in case of Weighted Averaging 61 | var wAvg: DenseVector[Double] = 62 | if (solverOptions.doWeightedAveraging) 63 | DenseVector.zeros(d) 64 | else null 65 | 66 | var debugIter = if (solverOptions.debugMultiplier == 0) { 67 | solverOptions.debugMultiplier = 100 68 | n 69 | } else { 70 | 1 71 | } 72 | val debugModel: StructSVMModel[X, Y] = new StructSVMModel(DenseVector.zeros(d), 0.0, DenseVector.zeros(ndims), dissolveFunctions) 73 | 74 | val lossWriter = if (solverOptions.debug) new PrintWriter(new File(lossWriterFileName)) else null 75 | if (solverOptions.debug) { 76 | if (solverOptions.testData != null) 77 | lossWriter.write("pass_num,iter,primal,dual,duality_gap,train_error,test_error\n") 78 | else 79 | lossWriter.write("pass_num,iter,primal,dual,duality_gap,train_error\n") 80 | } 81 | 82 | if (debugOn) { 83 | println("Beginning training of %d data points in %d passes with lambda=%f".format(n, roundLimit, lambda)) 84 | } 85 | 86 | for (passNum <- 0 until roundLimit) { 87 | 88 | if (debugOn) 89 | println("Starting pass #%d".format(passNum)) 90 | 91 | for (dummy <- 0 until n) { 92 | // 1) Pick example 93 | val i: Int = dummy 94 | val pattern: X = data(i).pattern 95 | val label: Y = data(i).label 96 | 97 | // 2) Solve loss-augmented inference for point i 98 | val ystar_i: Y = maxOracle(model, pattern, label) 99 | 100 | // 3) Get the subgradient 101 | val psi_i: Vector[Double] = (phi(pattern, label) - phi(pattern, ystar_i))*cWeight(label) 102 | val w_s: Vector[Double] = psi_i :* (1 / (n * lambda)) 103 | 104 | if (debugOn && dummy == (n - 1)) 105 | csvwrite(new File("data/debug/scala-w-%d.csv".format(passNum + 1)), w_s.toDenseVector.toDenseMatrix) 106 | 107 | // 4) Step size gamma 108 | val gamma: Double = 1.0 / (gamma0*(k + 1.0)) 109 | 110 | // 5) Update the weights of the model 111 | val newWeights: Vector[Double] = (model.getWeights() :* (1 - gamma)) + (w_s :* (gamma * n)) 112 | model.setWeights(newWeights) 113 | 114 | k = k + 1 115 | 116 | if (solverOptions.doWeightedAveraging) { 117 | val rho: Double = 2.0 / (k + 2.0) 118 | wAvg = wAvg * (1.0 - rho) + model.getWeights() * rho 119 | 
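          // The rho = 2/(k+2) recursion above keeps a weighted running average of
          // the iterates, in which the weight of iterate t grows linearly with t,
          // so later (typically better) iterates dominate wAvg (the weighted
          // averaging scheme of Lacoste-Julien et al., ICML 2013).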
} 120 | 121 | if (debugOn && k >= debugIter) { 122 | 123 | if (solverOptions.doWeightedAveraging) { 124 | debugModel.setWeights(wAvg) 125 | } else { 126 | debugModel.setWeights(model.getWeights) 127 | } 128 | 129 | val primal = SolverUtils.primalObjective(data, dissolveFunctions, debugModel, lambda) 130 | val trainError = SolverUtils.averageLoss(data, dissolveFunctions, debugModel)._1 131 | 132 | if (solverOptions.testData != null) { 133 | val testError = 134 | if (solverOptions.testData.isDefined) 135 | SolverUtils.averageLoss(solverOptions.testData.get, dissolveFunctions, debugModel)._1 136 | else 137 | 0.00 138 | println("Pass %d Iteration %d, SVM primal = %f, Train error = %f, Test error = %f" 139 | .format(passNum + 1, k, primal, trainError, testError)) 140 | 141 | if (solverOptions.debug) 142 | lossWriter.write("%d,%d,%f,%f,%f\n".format(passNum + 1, k, primal, trainError, testError)) 143 | } else { 144 | println("Pass %d Iteration %d, SVM primal = %f, Train error = %f" 145 | .format(passNum + 1, k, primal, trainError)) 146 | if (solverOptions.debug) 147 | lossWriter.write("%d,%d,%f,%f,\n".format(passNum + 1, k, primal, trainError)) 148 | } 149 | 150 | debugIter = min(debugIter + n, ceil(debugIter * (1 + solverOptions.debugMultiplier / 100))) 151 | 152 | } 153 | 154 | } 155 | if (debugOn) 156 | println("Completed pass #%d".format(passNum)) 157 | 158 | } 159 | 160 | return model 161 | } 162 | 163 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/optimization/SolverOptions.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.optimization 2 | 3 | import ch.ethz.dalab.dissolve.regression.LabeledObject 4 | import breeze.linalg.Vector 5 | import org.apache.spark.rdd.RDD 6 | import java.io.File 7 | 8 | sealed trait StoppingCriterion 9 | 10 | // Option A - Limit number of communication rounds 11 | case object RoundLimitCriterion extends StoppingCriterion { 12 | override def toString(): String = { "RoundLimitCriterion" } 13 | } 14 | 15 | // Option B - Check gap 16 | case object GapThresholdCriterion extends StoppingCriterion { 17 | override def toString(): String = { "GapThresholdCriterion" } 18 | } 19 | 20 | // Option C - Run for this amount of time (in secs) 21 | case object TimeLimitCriterion extends StoppingCriterion { 22 | override def toString(): String = { "TimeLimitCriterion" } 23 | } 24 | 25 | sealed trait Solver 26 | 27 | case object UseDBCFWSolver extends Solver{ 28 | override def toString(): String = {"DBCFWSolver"} 29 | } 30 | 31 | case object UseSSGSolver extends Solver{ 32 | override def toString(): String = {"SSGSolver"} 33 | } 34 | 35 | class SolverOptions[X, Y] extends Serializable { 36 | var doWeightedAveraging: Boolean = false 37 | 38 | var randSeed: Int = 42 39 | /** 40 | * BCFW - "uniform", "perm" or "iter" 41 | * DBCFW - "count", "frac" 42 | */ 43 | var sample: String = "frac" 44 | var lambda: Double = 0.01 // FIXME This is 1/n in Matlab code 45 | 46 | var testData: Option[Seq[LabeledObject[X, Y]]] = Option.empty[Seq[LabeledObject[X, Y]]] 47 | var testDataRDD: Option[RDD[LabeledObject[X, Y]]] = Option.empty[RDD[LabeledObject[X, Y]]] 48 | 49 | var doLineSearch: Boolean = true 50 | 51 | // Checkpoint once in these many rounds 52 | var checkpointFreq: Int = 50 53 | 54 | // In case of multi-class 55 | var numClasses = -1 56 | 57 | var classWeights:Boolean = true 58 | 59 | // Cache params 60 | var 
enableOracleCache: Boolean = false 61 | var oracleCacheSize: Int = 10 62 | 63 | // DBCFW specific params 64 | var H: Int = 5 // Number of data points to sample in each round of CoCoA (= number of local coordinate updates) 65 | var sampleFrac: Double = 0.5 66 | var sampleWithReplacement: Boolean = false 67 | 68 | var enableManualPartitionSize: Boolean = false 69 | var NUM_PART: Int = 1 // Number of partitions of the RDD 70 | 71 | // SSG specific params 72 | var ssg_gamma0: Int = 1000 73 | 74 | // For debugging/Testing purposes 75 | // Basic debugging flag 76 | var debug: Boolean = false 77 | // Obtain statistics (primal value, duality gap, train error, test error, etc.) once in these many rounds. 78 | // If 1, obtains statistics in each round 79 | var debugMultiplier: Int = 1 80 | 81 | // Option A - Limit number of communication rounds 82 | var roundLimit: Int = 25 83 | 84 | // Option B - Check gap 85 | var gapThreshold: Double = 0.1 86 | var gapCheck: Int = 1 // Check for once these many rounds 87 | 88 | // Option C - Run for this amount of time (in secs) 89 | var timeLimit: Int = 300 90 | 91 | var stoppingCriterion: StoppingCriterion = RoundLimitCriterion 92 | 93 | // Sparse representation of w_i's 94 | var sparse: Boolean = false 95 | 96 | // Path to write the CSVs 97 | var debugInfoPath: String = new File(".").getCanonicalPath() + "/debugInfo-%d.csv".format(System.currentTimeMillis()) 98 | 99 | var solver: Solver = UseDBCFWSolver 100 | 101 | override def toString(): String = { 102 | val sb: StringBuilder = new StringBuilder() 103 | 104 | sb ++= "# numRounds=%s\n".format(roundLimit) 105 | sb ++= "# doWeightedAveraging=%s\n".format(doWeightedAveraging) 106 | 107 | sb ++= "# randSeed=%d\n".format(randSeed) 108 | 109 | sb ++= "# sample=%s\n".format(sample) 110 | sb ++= "# lambda=%f\n".format(lambda) 111 | sb ++= "# doLineSearch=%s\n".format(doLineSearch) 112 | 113 | sb ++= "# enableManualPartitionSize=%s\n".format(enableManualPartitionSize) 114 | sb ++= "# NUM_PART=%s\n".format(NUM_PART) 115 | 116 | sb ++= "# enableOracleCache=%s\n".format(enableOracleCache) 117 | sb ++= "# oracleCacheSize=%d\n".format(oracleCacheSize) 118 | 119 | sb ++= "# H=%d\n".format(H) 120 | sb ++= "# sampleFrac=%f\n".format(sampleFrac) 121 | sb ++= "# sampleWithReplacement=%s\n".format(sampleWithReplacement) 122 | 123 | sb ++= "# debugInfoPath=%s\n".format(debugInfoPath) 124 | 125 | sb ++= "# checkpointFreq=%d\n".format(checkpointFreq) 126 | 127 | sb ++= "# stoppingCriterion=%s\n".format(stoppingCriterion) 128 | this.stoppingCriterion match { 129 | case RoundLimitCriterion => sb ++= "# roundLimit=%d\n".format(roundLimit) 130 | case GapThresholdCriterion => sb ++= "# gapThreshold=%f\n".format(gapThreshold) 131 | case TimeLimitCriterion => sb ++= "# timeLimit=%d\n".format(timeLimit) 132 | case _ => throw new Exception("Unrecognized Stopping Criterion") 133 | } 134 | 135 | sb ++="# solver=%s\n".format(solver) 136 | 137 | sb ++="# class weighting=%s\n".format(classWeights) 138 | 139 | sb ++= "# debugMultiplier=%d\n".format(debugMultiplier) 140 | 141 | sb.toString() 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/optimization/SolverUtils.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.optimization 2 | 3 | import ch.ethz.dalab.dissolve.classification.StructSVMModel 4 | import breeze.linalg._ 5 | import 
ch.ethz.dalab.dissolve.regression.LabeledObject 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.SparkConf 8 | import org.apache.spark.rdd.RDD 9 | import scala.reflect.ClassTag 10 | import org.apache.spark.broadcast.Broadcast 11 | 12 | object SolverUtils { 13 | 14 | /** 15 | * Average loss 16 | */ 17 | def averageLoss[X, Y](data: Seq[LabeledObject[X, Y]], 18 | dissolveFunctions: DissolveFunctions[X, Y], 19 | model: StructSVMModel[X, Y]): (Double, Double) = { 20 | 21 | var errorTerm: Double = 0.0 22 | var structuredHingeLoss: Double = 0.0 23 | 24 | for (i <- 0 until data.size) { 25 | val ystar_i = dissolveFunctions.predictFn(model, data(i).pattern) 26 | val loss = dissolveFunctions.lossFn(data(i).label, ystar_i) 27 | errorTerm += loss 28 | 29 | val wFeatureDotProduct = model.getWeights().t * dissolveFunctions.featureFn(data(i).pattern, data(i).label) 30 | val structuredHingeLoss: Double = loss - wFeatureDotProduct 31 | } 32 | 33 | // Return average of loss terms 34 | (errorTerm / (data.size.toDouble), structuredHingeLoss / (data.size.toDouble)) 35 | } 36 | 37 | def averageLoss[X, Y](data: RDD[LabeledObject[X, Y]], 38 | dissolveFunctions: DissolveFunctions[X, Y], 39 | model: StructSVMModel[X, Y], 40 | dataSize: Int): (Double, Double) = { 41 | 42 | val (loss, hloss) = 43 | data.map { 44 | case datapoint => 45 | val ystar_i = dissolveFunctions.predictFn(model, datapoint.pattern) 46 | val loss = dissolveFunctions.lossFn(ystar_i, datapoint.label) 47 | val wFeatureDotProduct = model.getWeights().t * (dissolveFunctions.featureFn(datapoint.pattern, datapoint.label) 48 | - dissolveFunctions.featureFn(datapoint.pattern, ystar_i)) 49 | val structuredHingeLoss: Double = loss - wFeatureDotProduct 50 | 51 | (loss, structuredHingeLoss) 52 | }.fold((0.0, 0.0)) { 53 | case ((lossAccum, hlossAccum), (loss, hloss)) => 54 | (lossAccum + loss, hlossAccum + hloss) 55 | } 56 | 57 | (loss / dataSize, hloss / dataSize) 58 | } 59 | 60 | /** 61 | * Objective function (SVM dual, assuming we know the vector b of all losses. See BCFW paper) 62 | */ 63 | def objectiveFunction(w: Vector[Double], 64 | b_alpha: Double, 65 | lambda: Double): Double = { 66 | // Return the value of f(alpha) 67 | 0.5 * lambda * (w.t * w) - b_alpha 68 | } 69 | 70 | /** 71 | * Compute Duality gap 72 | * Requires one full pass of decoding over all data examples. 73 | */ 74 | def dualityGap[X, Y](data: Seq[LabeledObject[X, Y]], 75 | featureFn: (X, Y) => Vector[Double], 76 | lossFn: (Y, Y) => Double, 77 | oracleFn: (StructSVMModel[X, Y], X, Y) => Y, 78 | model: StructSVMModel[X, Y], 79 | lambda: Double)(implicit m: ClassTag[Y]): (Double, Vector[Double], Double) = { 80 | 81 | val phi = featureFn 82 | val maxOracle = oracleFn 83 | 84 | val w: Vector[Double] = model.getWeights() 85 | val ell: Double = model.getEll() 86 | 87 | val n: Int = data.size 88 | val d: Int = model.getWeights().size 89 | val yStars = new Array[Y](n) 90 | 91 | for (i <- 0 until n) { 92 | yStars(i) = maxOracle(model, data(i).pattern, data(i).label) 93 | } 94 | 95 | var w_s: DenseVector[Double] = DenseVector.zeros[Double](d) 96 | var ell_s: Double = 0.0 97 | for (i <- 0 until n) { 98 | w_s += phi(data(i).pattern, data(i).label) - phi(data(i).pattern, yStars(i)) 99 | ell_s += lossFn(yStars(i), data(i).label) 100 | } 101 | 102 | w_s = w_s / (lambda * n) 103 | ell_s = ell_s / n 104 | 105 | val gap: Double = w.t * (w - w_s) * lambda - ell + ell_s 106 | 107 | (gap, w_s, ell_s) 108 | } 109 | 110 | /** 111 | * Alternative implementation, using fold. 
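   * The gap value returned is
   *
   *   gap(w) = lambda * (w.t * (w - w_s)) - ell + ell_s ,
   *
   * with w_s = (1 / (lambda * n)) * sum_i classWeight(y_i) * [ phi(x_i, y_i) - phi(x_i, y*_i) ]
   * and ell_s = (1 / n) * sum_i classWeight(y_i) * Delta(y*_i, y_i), both
   * accumulated from a single loss-augmented decoding pass, exactly as in the
   * fold at the end of this method.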
TODO: delete this or the above 112 | * Requires one full pass of decoding over all data examples. 113 | */ 114 | def dualityGap[X, Y](data: RDD[LabeledObject[X, Y]], 115 | dissolveFunctions: DissolveFunctions[X, Y], 116 | model: StructSVMModel[X, Y], 117 | lambda: Double, 118 | dataSize: Int)(implicit m: ClassTag[Y]): (Double, Vector[Double], Double) = { 119 | 120 | val phi = dissolveFunctions.featureFn _ 121 | val maxOracle = dissolveFunctions.oracleFn _ 122 | val lossFn = dissolveFunctions.lossFn _ 123 | val classWeight = dissolveFunctions.classWeights _ 124 | 125 | val w: Vector[Double] = model.getWeights() 126 | val ell: Double = model.getEll() 127 | 128 | val n: Int = dataSize.toInt 129 | val d: Int = model.getWeights().size 130 | 131 | var (w_s, ell_s) = data.map { 132 | case datapoint => 133 | val yStar = maxOracle(model, datapoint.pattern, datapoint.label) 134 | val w_s = (phi(datapoint.pattern, datapoint.label) - phi(datapoint.pattern, yStar))*classWeight(datapoint.label) 135 | val ell_s = lossFn(yStar, datapoint.label)*classWeight(datapoint.label) 136 | 137 | (w_s, ell_s) 138 | }.fold((Vector.zeros[Double](d), 0.0)) { 139 | case ((w_acc, ell_acc), (w_i, ell_i)) => 140 | (w_acc + w_i, ell_acc + ell_i) 141 | } 142 | 143 | w_s = w_s / (lambda * n) 144 | ell_s = ell_s / n 145 | 146 | val gap: Double = w.t * (w - w_s) * lambda - ell + ell_s 147 | 148 | (gap, w_s, ell_s) 149 | } 150 | 151 | /** 152 | * Primal objective. 153 | * Requires one full pass of decoding over all data examples. 154 | */ 155 | def primalObjective[X, Y](data: Seq[LabeledObject[X, Y]], 156 | dissolveFunctions: DissolveFunctions[X, Y], 157 | model: StructSVMModel[X, Y], 158 | lambda: Double): Double = { 159 | 160 | val featureFn = dissolveFunctions.featureFn _ 161 | val oracleFn = dissolveFunctions.oracleFn _ 162 | val lossFn = dissolveFunctions.lossFn _ 163 | val classWeight = dissolveFunctions.classWeights _ 164 | 165 | var hingeLosses: Double = 0.0 166 | for (i <- 0 until data.size) { 167 | val yStar_i = oracleFn(model, data(i).pattern, data(i).label) 168 | val loss_i = lossFn(yStar_i, data(i).label)*classWeight(data(i).label) 169 | val psi_i = featureFn(data(i).pattern, data(i).label) - featureFn(data(i).pattern, yStar_i)*classWeight(data(i).label) 170 | 171 | val hingeloss_i = loss_i - model.getWeights().t * psi_i 172 | // println("loss_i = %f, other_loss = %f".format(loss_i, model.getWeights().t * psi_i)) 173 | // assert(hingeloss_i >= 0.0) 174 | 175 | hingeLosses += hingeloss_i 176 | } 177 | 178 | // Compute the primal and return it 179 | 0.5 * lambda * (model.getWeights.t * model.getWeights) + hingeLosses / data.size 180 | 181 | } 182 | 183 | case class DataEval(gap: Double, 184 | avgDelta: Double, 185 | avgHLoss: Double) 186 | 187 | case class PartialTrainDataEval(sum_w_s: Vector[Double], 188 | sum_ell_s: Double, 189 | sum_Delta: Double, 190 | sum_HLoss: Double) { 191 | 192 | def +(that: PartialTrainDataEval): PartialTrainDataEval = { 193 | 194 | val sum_w_s = this.sum_w_s + that.sum_w_s 195 | val sum_ell_s: Double = this.sum_ell_s + that.sum_ell_s 196 | val sum_Delta: Double = this.sum_Delta + that.sum_Delta 197 | val sum_HLoss: Double = this.sum_HLoss + that.sum_HLoss 198 | 199 | PartialTrainDataEval(sum_w_s, 200 | sum_ell_s, 201 | sum_Delta, 202 | sum_HLoss) 203 | 204 | } 205 | } 206 | /** 207 | * Makes an additional pass over the data to compute the following: 208 | * 1. Duality Gap 209 | * 2. Average \Delta 210 | * 3. Average Structured Hinge Loss 211 | * 4. Average Per-Class pixel-loss 212 | * 5. 
Global loss 213 | */ 214 | def trainDataEval[X, Y](data: RDD[LabeledObject[X, Y]], 215 | dissolveFunctions: DissolveFunctions[X, Y], 216 | model: StructSVMModel[X, Y], 217 | lambda: Double, 218 | dataSize: Int)(implicit m: ClassTag[Y]): DataEval = { 219 | 220 | val phi = dissolveFunctions.featureFn _ 221 | val maxOracle = dissolveFunctions.oracleFn _ 222 | val lossFn = dissolveFunctions.lossFn _ 223 | val predictFn = dissolveFunctions.predictFn _ 224 | val classWeight = dissolveFunctions.classWeights _ 225 | 226 | val n: Int = dataSize.toInt 227 | val d: Int = model.getWeights().size 228 | 229 | val bcModel: Broadcast[StructSVMModel[X, Y]] = data.context.broadcast(model) 230 | 231 | val initEval = 232 | PartialTrainDataEval(DenseVector.zeros[Double](d), 233 | 0.0, 234 | 0.0, 235 | 0.0) 236 | 237 | val partialEval = data.map { 238 | case datapoint => 239 | /** 240 | * Gap and Structured HingeLoss 241 | */ 242 | val lossAug_yStar = maxOracle(bcModel.value, datapoint.pattern, datapoint.label) 243 | val w_s = (phi(datapoint.pattern, datapoint.label) - phi(datapoint.pattern, lossAug_yStar))*classWeight(datapoint.label) 244 | val ell_s = lossFn(datapoint.label, lossAug_yStar)*classWeight(datapoint.label) 245 | val lossAug_wFeatureDotProduct = lossFn(datapoint.label, lossAug_yStar) - 246 | (bcModel.value.getWeights().t * (phi(datapoint.pattern, datapoint.label) 247 | - phi(datapoint.pattern, lossAug_yStar))) 248 | val structuredHingeLoss: Double = lossAug_wFeatureDotProduct 249 | 250 | /** 251 | * \Delta 252 | */ 253 | val predict_yStar = predictFn(bcModel.value, datapoint.pattern) 254 | val loss = lossFn(datapoint.label, predict_yStar) 255 | 256 | /** 257 | * Per-class loss 258 | */ 259 | val y_truth = datapoint.label 260 | val y_predicted = predict_yStar 261 | 262 | PartialTrainDataEval(w_s, 263 | ell_s, 264 | loss, 265 | structuredHingeLoss) 266 | 267 | }.reduce(_ + _) 268 | 269 | val w: Vector[Double] = model.getWeights() 270 | val ell: Double = model.getEll() 271 | 272 | // Gap 273 | val sum_w_s = partialEval.sum_w_s 274 | val w_s = sum_w_s / (lambda * n) 275 | val ell_s = partialEval.sum_ell_s / n 276 | val gap: Double = w.t * (w - w_s) * lambda - ell + ell_s 277 | 278 | // Loss 279 | val avgLoss = partialEval.sum_Delta / n 280 | val avgHLoss = partialEval.sum_HLoss / n 281 | 282 | DataEval(gap, 283 | avgLoss, 284 | avgHLoss) 285 | } 286 | 287 | /** 288 | * Get Spark's properties 289 | */ 290 | def getSparkConfString(sc: SparkConf): String = { 291 | val keys = List("spark.app.name", "spark.executor.memory", "spark.task.cpus", "spark.local.dir", "spark.default.parallelism") 292 | val sb: StringBuilder = new StringBuilder() 293 | 294 | for (key <- keys) 295 | sb ++= "# %s=%s\n".format(key, sc.get(key, "NA")) 296 | 297 | sb.toString() 298 | } 299 | 300 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/regression/LabeledObject.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.regression 2 | 3 | import org.apache.spark.mllib.regression.LabeledPoint 4 | 5 | import breeze.linalg._ 6 | 7 | case class LabeledObject[X, Y]( 8 | val label: Y, 9 | val pattern: X) extends Serializable { 10 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/utils/cli/CLAParser.scala: -------------------------------------------------------------------------------- 1 | 
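// Usage sketch (illustrative only; the type parameters, sample values and the
// "input_path" key are placeholders chosen by the calling application):
//
//   val args = Array("--lambda", "0.01", "--roundlimit", "50", "--debug",
//                    "--kwargs", "input_path=data/generated/a1a")
//   val (solverOptions, kwargs) =
//     CLAParser.argsToOptions[Vector[Double], Double](args)
//   println(solverOptions)        // the "# key=value" summary from SolverOptions.toString
//   val inputPath = kwargs("input_path")
//
// Alternatively, a SolverOptions[X, Y] instance can be configured directly by
// assigning its public vars (lambda, roundLimit, stoppingCriterion, debug, ...),
// exactly as argsToOptions does below.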
package ch.ethz.dalab.dissolve.utils.cli 2 | 3 | import ch.ethz.dalab.dissolve.optimization._ 4 | import scopt.OptionParser 5 | 6 | object CLAParser { 7 | 8 | def getParser(): OptionParser[Config] = 9 | new scopt.OptionParser[Config]("spark-submit ... ") { 10 | head("dissolve^struct", "0.1-SNAPSHOT") 11 | 12 | help("help") text ("prints this usage text") 13 | 14 | opt[Double]("lambda") action { (x, c) => 15 | c.copy(lambda = x) 16 | } text ("Regularization constant. Default = 0.01") 17 | 18 | opt[Int]("randseed") action { (x, c) => 19 | c.copy(randSeed = x) 20 | } text ("Random Seed. Default = 42") 21 | 22 | opt[Unit]("linesearch") action { (_, c) => 23 | c.copy(lineSearch = true) 24 | } text ("Enable Line Search. Default = false") 25 | 26 | opt[Double]("samplefrac") action { (x, c) => 27 | c.copy(sampleFrac = x) 28 | } text ("Fraction of original dataset to be sampled in each round. Default = 0.5") 29 | 30 | opt[Int]("minpart") action { (x, c) => 31 | c.copy(minPartitions = x) 32 | } text ("Repartition data RDD to given number of partitions before training. Default = Auto") 33 | 34 | opt[Unit]("sparse") action { (_, c) => 35 | c.copy(sparse = true) 36 | } text ("Maintain vectors as sparse vectors. Default = Dense.") 37 | 38 | opt[Int]("oraclecachesize") action { (x, c) => 39 | c.copy(oracleCacheSize = x) 40 | } text ("Oracle Cache Size (caching answers of the maximization oracle for this datapoint). Default = Disabled") 41 | 42 | opt[Int]("cpfreq") action { (x, c) => 43 | c.copy(checkpointFreq = x) 44 | } text ("Checkpoint Frequency (in rounds). Default = 50") 45 | 46 | opt[String]("stopcrit") action { (x, c) => 47 | x match { 48 | case "round" => 49 | c.copy(stoppingCriterion = RoundLimitCriterion) 50 | case "gap" => 51 | c.copy(stoppingCriterion = GapThresholdCriterion) 52 | case "time" => 53 | c.copy(stoppingCriterion = TimeLimitCriterion) 54 | } 55 | } validate { x => 56 | x match { 57 | case "round" => success 58 | case "gap" => success 59 | case "time" => success 60 | case _ => failure("Stopping criterion has to be one of: round | gap | time") 61 | } 62 | } text ("Stopping Criterion. (round | gap | time). Default = round") 63 | 64 | opt[Int]("roundlimit") action { (x, c) => 65 | c.copy(roundLimit = x) 66 | } text ("Round Limit. Default = 25") 67 | 68 | opt[Double]("gapthresh") action { (x, c) => 69 | c.copy(gapThreshold = x) 70 | } text ("Gap Threshold. Default = 0.1") 71 | 72 | opt[Int]("gapcheck") action { (x, c) => 73 | c.copy(gapCheck = x) 74 | } text ("Checks for gap every these many rounds. Default = 25") 75 | 76 | opt[Int]("timelimit") action { (x, c) => 77 | c.copy(timeLimit = x) 78 | } text ("Time Limit (in secs). Default = 300 secs") 79 | 80 | opt[Unit]("debug") action { (_, c) => 81 | c.copy(debug = true) 82 | } text ("Enable debugging. Default = false") 83 | 84 | opt[Int]("debugmult") action { (x, c) => 85 | c.copy(debugMultiplier = x) 86 | } text ("Frequency of debugging. Obtains gap, train and test errors. Default = 1") 87 | 88 | opt[String]("debugfile") action { (x, c) => 89 | c.copy(debugPath = x) 90 | } text ("Path to debug file. 
Default = current-dir") 91 | 92 | opt[Map[String, String]]("kwargs") valueName ("k1=v1,k2=v2...") action { (x, c) => 93 | c.copy(kwargs = x) 94 | } text ("other arguments") 95 | } 96 | 97 | def argsToOptions[X, Y](args: Array[String]): (SolverOptions[X, Y], Map[String, String]) = 98 | 99 | getParser().parse(args, Config()) match { 100 | case Some(config) => 101 | val solverOptions: SolverOptions[X, Y] = new SolverOptions[X, Y]() 102 | // Copy all config parameters to a Solver Options instance 103 | solverOptions.lambda = config.lambda 104 | solverOptions.randSeed = config.randSeed 105 | solverOptions.doLineSearch = config.lineSearch 106 | solverOptions.doWeightedAveraging = config.wavg 107 | 108 | solverOptions.sampleFrac = config.sampleFrac 109 | if (config.minPartitions > 0) { 110 | solverOptions.enableManualPartitionSize = true 111 | solverOptions.NUM_PART = config.minPartitions 112 | } 113 | solverOptions.sparse = config.sparse 114 | 115 | if (config.oracleCacheSize > 0) { 116 | solverOptions.enableOracleCache = true 117 | solverOptions.oracleCacheSize = config.oracleCacheSize 118 | } 119 | 120 | solverOptions.checkpointFreq = config.checkpointFreq 121 | 122 | solverOptions.stoppingCriterion = config.stoppingCriterion 123 | solverOptions.roundLimit = config.roundLimit 124 | solverOptions.gapCheck = config.gapCheck 125 | solverOptions.gapThreshold = config.gapThreshold 126 | solverOptions.timeLimit = config.timeLimit 127 | 128 | solverOptions.debug = config.debug 129 | solverOptions.debugMultiplier = config.debugMultiplier 130 | solverOptions.debugInfoPath = config.debugPath 131 | 132 | (solverOptions, config.kwargs) 133 | 134 | case None => 135 | // No options passed. Do nothing. 136 | val solverOptions: SolverOptions[X, Y] = new SolverOptions[X, Y]() 137 | val kwargs = Map[String, String]() 138 | 139 | (solverOptions, kwargs) 140 | } 141 | 142 | def main(args: Array[String]): Unit = { 143 | val foo = argsToOptions(args) 144 | println(foo._1.toString()) 145 | println(foo._2) 146 | } 147 | 148 | } -------------------------------------------------------------------------------- /dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/utils/cli/Config.scala: -------------------------------------------------------------------------------- 1 | package ch.ethz.dalab.dissolve.utils.cli 2 | 3 | import ch.ethz.dalab.dissolve.optimization._ 4 | import java.io.File 5 | 6 | case class Config( 7 | 8 | // BCFW parameters 9 | lambda: Double = 0.01, 10 | randSeed: Int = 42, 11 | lineSearch: Boolean = false, 12 | wavg: Boolean = false, 13 | 14 | // dissolve^struct parameters 15 | sampleFrac: Double = 0.5, 16 | minPartitions: Int = 0, 17 | sparse: Boolean = false, 18 | 19 | // Oracle 20 | oracleCacheSize: Int = 0, 21 | 22 | // Spark 23 | checkpointFreq: Int = 50, 24 | 25 | // Stopping criteria 26 | stoppingCriterion: StoppingCriterion = RoundLimitCriterion, 27 | // A - RoundLimit 28 | roundLimit: Int = 25, 29 | // B - Gap Check 30 | gapThreshold: Double = 0.1, 31 | gapCheck: Int = 10, 32 | // C - Time Limit 33 | timeLimit: Int = 300, // (In seconds) 34 | 35 | // Debug parameters 36 | debug: Boolean = false, 37 | debugMultiplier: Int = 1, 38 | debugPath: String = new File(".", "debug-%d.csv".format(System.currentTimeMillis())).getAbsolutePath, 39 | 40 | // Other parameters 41 | kwargs: Map[String, String] = Map()) -------------------------------------------------------------------------------- /helpers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dalab/dissolve-struct/67f37377b74c32cf05d8f43a0e3658a10864f9bf/helpers/__init__.py -------------------------------------------------------------------------------- /helpers/benchmark_runner.py: -------------------------------------------------------------------------------- 1 | """Given experimental parameters, runs the required experiments and obtains the data 2 | """ 3 | import argparse 4 | import ConfigParser 5 | import datetime 6 | import os 7 | 8 | from benchmark_utils import * 9 | 10 | VALID_PARAMS = {"lambda", "minpart", "samplefrac", "oraclesize"} 11 | VAL_PARAMS = {"lambda", "minpart", "samplefrac", "oraclesize", "stopcrit", "roundlimit", "gaplimit", "gapcheck", 12 | "timelimit", "debugmult"} 13 | BOOL_PARAMS = {"sparse", "debug", "linesearch"} 14 | 15 | WDIR = "/home/ec2-user" # Working directory 16 | 17 | 18 | def str_to_bool(s): 19 | if s in ['True', 'true']: 20 | return True 21 | elif s in ['False', 'false']: 22 | return False 23 | else: 24 | raise ValueError("Boolean value in config '%s' unrecognized") 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser(description='Run benchmark') 29 | parser.add_argument("identity_file", help="SSH private key to log into spark nodes") 30 | parser.add_argument("master_uri", help="URI of master node") 31 | parser.add_argument("expt_config", help="Experimental config file") 32 | args = parser.parse_args() 33 | 34 | master_host = args.master_uri 35 | identity_file = args.identity_file 36 | 37 | def ssh_spark(command, user="root", cwd=WDIR): 38 | command = "source /root/.bash_profile; cd %s; %s" % (cwd, command) 39 | ssh(master_host, user, command, identity_file) 40 | 41 | # Check if setup has been executed 42 | ssh_spark("if [ ! -f /home/ec2-user/onesmallstep ]; then echo \"Run benchmark_setup and try again\"; exit 1; fi", 43 | cwd=WDIR) 44 | 45 | dtf = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S") 46 | appname_format = "{dtf}-{expt_name}-R{rep_num}-{param}-{paramval}" 47 | spark_submit_cmd_format = ("{spark_submit} " 48 | "--jars {lib_jar_path} " 49 | "--class \"{class_name}\" " 50 | "{spark_args} " 51 | "{examples_jar_path} " 52 | "{solver_options_args} " 53 | "--kwargs k1=v1,{app_args}") 54 | 55 | config = ConfigParser.ConfigParser() 56 | config.read(args.expt_config) 57 | 58 | expt_name = config.get("general", "experiment_name") 59 | class_name = config.get("general", "class_name") 60 | num_repetitions = config.getint("general", "repetitions") 61 | 62 | # Pivot values 63 | pivot_param = config.get("pivot", "param") 64 | assert (pivot_param in VALID_PARAMS) 65 | pivot_values_raw = config.get("pivot", "values") 66 | pivot_values = map(lambda x: x.strip(), pivot_values_raw.split(",")) 67 | 68 | # Paths 69 | examples_jar_path = config.get("paths", "examples_jar_path") 70 | spark_dir = config.get("paths", "spark_dir") 71 | spark_submit_path = os.path.join(spark_dir, "bin", "spark-submit") 72 | hdfs_input_path = config.get("paths", "hdfs_input_path") 73 | 74 | local_output_dir = config.get("paths", "local_output_dir") 75 | local_output_expt_dir = os.path.join(local_output_dir, "%s %s" % (expt_name, dtf)) 76 | if not os.path.exists(local_output_expt_dir): 77 | os.makedirs(local_output_expt_dir) 78 | 79 | dissolve_lib_jar_path = config.get("paths", "lib_jar_path") 80 | scopt_jar_path = "/root/.ivy2/cache/com.github.scopt/scopt_2.10/jars/scopt_2.10-3.3.0.jar" 81 | lib_jar_path = ','.join([dissolve_lib_jar_path, scopt_jar_path]) 82 | 83 | for rep_num in range(1, num_repetitions + 1): 84 | print 
"===========================" 85 | print "====== Repetition %d ======" % rep_num 86 | print "===========================" 87 | for pivot_val in pivot_values: 88 | print "=== %s = %s ===" % (pivot_param, pivot_val) 89 | ''' 90 | Construct command to execute on spark cluster 91 | ''' 92 | appname = appname_format.format(dtf=dtf, 93 | expt_name=expt_name, 94 | rep_num=rep_num, 95 | param=pivot_param, 96 | paramval=pivot_val) 97 | 98 | # === Construct Spark arguments === 99 | spark_args = ','.join(["--%s %s" % (k, v) for k, v in config.items("spark_args")]) 100 | 101 | # === Construct Solver Options arguments === 102 | valued_parameter_args = ' '.join( 103 | ["--%s %s" % (k, v) for k, v in config.items("parameters") if k in VAL_PARAMS]) 104 | boolean_parameter_args = ' '.join( 105 | ["--%s" % k for k, v in config.items("parameters") if k in BOOL_PARAMS and str_to_bool(v)]) 106 | valued_dissolve_args = ' '.join( 107 | ["--%s %s" % (k, v) for k, v in config.items("dissolve_args") if k in VAL_PARAMS]) 108 | boolean_dissolve_args = ' '.join( 109 | ["--%s" % k for k, v in config.items("dissolve_args") if k in BOOL_PARAMS and str_to_bool(v)]) 110 | 111 | # === Add the pivotal parameter === 112 | assert (pivot_param not in config.options("parameters")) 113 | pivot_param_arg = "--%s %s" % (pivot_param, pivot_val) 114 | 115 | solver_options_args = ' '.join( 116 | [valued_parameter_args, boolean_parameter_args, valued_dissolve_args, boolean_dissolve_args, 117 | pivot_param_arg]) 118 | 119 | # == Construct App-specific arguments === 120 | debug_filename = "%s.csv" % appname 121 | debug_file_path = os.path.join(WDIR, debug_filename) 122 | default_app_args = ("appname={appname}," 123 | "input_path={input_path}," 124 | "debug_file={debug_file_path}").format(appname=appname, 125 | input_path=hdfs_input_path, 126 | debug_file_path=debug_file_path) 127 | extra_app_args = ','.join(["%s=%s" % (k, v) for k, v in config.items("app_args")]) 128 | 129 | app_args = ','.join([default_app_args, extra_app_args]) 130 | 131 | spark_submit_cmd = spark_submit_cmd_format.format(spark_submit=spark_submit_path, 132 | lib_jar_path=lib_jar_path, 133 | class_name=class_name, 134 | examples_jar_path=examples_jar_path, 135 | spark_args=spark_args, 136 | solver_options_args=solver_options_args, 137 | app_args=app_args) 138 | 139 | ''' 140 | Execute Command 141 | ''' 142 | print "Executing on %s:\n%s" % (master_host, spark_submit_cmd) 143 | ssh_spark(spark_submit_cmd) 144 | 145 | ''' 146 | Obtain required files 147 | ''' 148 | scp_from(master_host, identity_file, "root", debug_file_path, local_output_expt_dir) 149 | 150 | ''' 151 | Perform clean-up 152 | ''' 153 | 154 | 155 | if __name__ == '__main__': 156 | main() -------------------------------------------------------------------------------- /helpers/benchmark_setup.py: -------------------------------------------------------------------------------- 1 | """Prepare the EC2 cluster for dissolve^struct experiments. 
2 | 3 | Given a cluster created using spark-ec2 setup, this script will: 4 | - retrieve the datasets and places them into HDFS 5 | - build the required packages and place them into appropriate folders 6 | - setup the execution environment 7 | 8 | Reuses ssh code from AmpLab's Big Data Benchmarks 9 | """ 10 | import argparse 11 | import os 12 | 13 | from benchmark_utils import * 14 | 15 | WDIR = "/home/ec2-user" # Working directory 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser(description='Setup benchmark cluster') 20 | parser.add_argument("identity_file", help="SSH private key to log into spark nodes") 21 | parser.add_argument("master_uri", help="URI of master node") 22 | args = parser.parse_args() 23 | 24 | master_host = args.master_uri 25 | identity_file = args.identity_file 26 | 27 | dissolve_dir = os.path.join(WDIR, "dissolve-struct") 28 | dissolve_lib_dir = os.path.join(dissolve_dir, "dissolve-struct-lib") 29 | dissolve_examples_dir = os.path.join(dissolve_dir, "dissolve-struct-examples") 30 | 31 | def ssh_spark(command, user="root", cwd=WDIR): 32 | command = "source /root/.bash_profile; cd %s; %s" % (cwd, command) 33 | ssh(master_host, user, command, identity_file) 34 | 35 | # === Install all required dependencies === 36 | # sbt 37 | ssh_spark("curl https://bintray.com/sbt/rpm/rpm | sudo tee /etc/yum.repos.d/bintray-sbt-rpm.repo") 38 | ssh_spark("yum install sbt -y") 39 | # python pip 40 | ssh_spark("yum install python27 -y") 41 | ssh_spark("yum install python-pip -y") 42 | 43 | # === Checkout git repo === 44 | ssh_spark("git clone https://github.com/dalab/dissolve-struct.git %s" % dissolve_dir) 45 | 46 | # === Build packages === 47 | # Build lib 48 | # Jar location: 49 | # /root/.ivy2/local/ch.ethz.dalab/dissolvestruct_2.10/0.1-SNAPSHOT/jars/dissolvestruct_2.10.jar 50 | ssh_spark("sbt publish-local", cwd=dissolve_lib_dir) 51 | # Build examples 52 | # Jar location: 53 | # /home/ec2-user/dissolve-struct/dissolve-struct-examples/target/scala-2.10/dissolvestructexample_2.10-0.1-SNAPSHOT.jar 54 | ssh_spark("sbt package", cwd=dissolve_examples_dir) 55 | 56 | # === Data setup === 57 | # Install pip dependencies 58 | ssh_spark("pip install -r requirements.txt", cwd=dissolve_dir) 59 | 60 | # Execute data retrieval script 61 | ssh_spark("python helpers/retrieve_datasets.py -d", cwd=dissolve_dir) 62 | 63 | # === Setup environment === 64 | conf_dir = os.path.join(dissolve_examples_dir, "conf") 65 | ssh_spark("cp -r %s %s" % (conf_dir, WDIR)) 66 | 67 | # === Move data to HDFS === 68 | data_dir = os.path.join(dissolve_dir, "data") 69 | ssh_spark("/root/ephemeral-hdfs/bin/hadoop fs -put %s data" % data_dir) 70 | 71 | # === Create a file to mark everything is setup === 72 | ssh_spark("touch onesmallstep", cwd=WDIR) 73 | 74 | 75 | if __name__ == '__main__': 76 | main() 77 | -------------------------------------------------------------------------------- /helpers/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | """Common utils by benchmark scripts 2 | """ 3 | import subprocess 4 | 5 | 6 | def ssh(host, username, command, identity_file=None): 7 | if not identity_file: 8 | subprocess.check_call( 9 | "ssh -t -o StrictHostKeyChecking=no %s@%s '%s'" % 10 | (username, host, command), shell=True) 11 | else: 12 | subprocess.check_call( 13 | "ssh -t -o StrictHostKeyChecking=no -i %s %s@%s '%s'" % 14 | (identity_file, username, host, command), shell=True) 15 | 16 | 17 | # Copy a file to a given host through scp, throwing an exception if 
scp fails 18 | def scp_to(host, identity_file, username, local_file, remote_file): 19 | subprocess.check_call( 20 | "scp -q -o StrictHostKeyChecking=no -i %s '%s' '%s@%s:%s'" % 21 | (identity_file, local_file, username, host, remote_file), shell=True) 22 | 23 | 24 | # Copy a file to a given host through scp, throwing an exception if scp fails 25 | def scp_from(host, identity_file, username, remote_file, local_file): 26 | subprocess.check_call( 27 | "scp -q -o StrictHostKeyChecking=no -i %s '%s@%s:%s' '%s'" % 28 | (identity_file, username, host, remote_file, local_file), shell=True) -------------------------------------------------------------------------------- /helpers/brutus_runner.py: -------------------------------------------------------------------------------- 1 | """Given experimental parameters, runs the required experiments and obtains the data. 2 | To be executed on the Hadoop main node. 3 | """ 4 | import argparse 5 | import ConfigParser 6 | import datetime 7 | import re 8 | 9 | from benchmark_utils import * 10 | from paths import * 11 | 12 | VALID_PARAMS = {"lambda", "minpart", "samplefrac", "oraclesize", "num-executors"} 13 | VAL_PARAMS = {"lambda", "minpart", "samplefrac", "oraclesize", "stopcrit", "roundlimit", "gaplimit", "gapcheck", "gapthresh", 14 | "timelimit", "debugmult"} 15 | BOOL_PARAMS = {"sparse", "debug", "linesearch"} 16 | 17 | HOME_DIR = os.getenv("HOME") 18 | PROJ_DIR = os.path.join(HOME_DIR, "dissolve-struct") 19 | 20 | DEFAULT_CORES = 4 21 | 22 | 23 | def execute(command, cwd=PROJ_DIR): 24 | subprocess.check_call(command, cwd=cwd, shell=True) 25 | 26 | 27 | def str_to_bool(s): 28 | if s in ['True', 'true']: 29 | return True 30 | elif s in ['False', 'false']: 31 | return False 32 | else: 33 | raise ValueError("Boolean value in config '%s' unrecognized") 34 | 35 | 36 | def main(): 37 | parser = argparse.ArgumentParser(description='Run benchmark') 38 | parser.add_argument("expt_config", help="Experimental config file") 39 | parser.add_argument("--ds", help="Run with debugging separately. Forces execution of two spark jobs", 40 | action='store_true') 41 | args = parser.parse_args() 42 | 43 | # Check if setup has been executed 44 | touchfile_path = os.path.join(HOME_DIR, 'onesmallstep') 45 | execute("if [ ! 
-f %s ]; then echo \"Run benchmark_setup and try again\"; exit 1; fi" % touchfile_path, 46 | cwd=HOME_DIR) 47 | 48 | dtf = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S") 49 | appname_format = "{dtf}-{expt_name}-{param}-{paramval}" 50 | spark_submit_cmd_format = ("{spark_submit} " 51 | "--jars {lib_jar_path} " 52 | "--class \"{class_name}\" " 53 | "{spark_args} " 54 | "{examples_jar_path} " 55 | "{solver_options_args} " 56 | "--kwargs k1=v1,{app_args}") 57 | 58 | config = ConfigParser.ConfigParser() 59 | config.read(args.expt_config) 60 | 61 | expt_name = config.get("general", "experiment_name") 62 | class_name = config.get("general", "class_name") 63 | 64 | if args.ds: 65 | assert ('debug' in config.options('dissolve_args')) 66 | # Want debug set to true, in case of double-debugging 67 | assert (str_to_bool(config.get('dissolve_args', 'debug'))) 68 | 69 | # Pivot values 70 | pivot_param = config.get("pivot", "param") 71 | assert (pivot_param in VALID_PARAMS) 72 | pivot_values_raw = config.get("pivot", "values") 73 | pivot_values = map(lambda x: x.strip(), pivot_values_raw.split(",")) 74 | 75 | # Paths 76 | examples_jar_path = EXAMPLES_JAR_PATH 77 | spark_submit_path = "spark-submit" 78 | input_path = config.get("paths", "input_path") 79 | 80 | output_dir = EXPT_OUTPUT_DIR 81 | local_output_expt_dir = os.path.join(output_dir, "%s_%s" % (expt_name, dtf)) 82 | if not os.path.exists(local_output_expt_dir): 83 | os.makedirs(local_output_expt_dir) 84 | 85 | dissolve_lib_jar_path = LIB_JAR_PATH 86 | scopt_jar_path = SCOPT_JAR_PATH 87 | lib_jar_path = ','.join([dissolve_lib_jar_path, scopt_jar_path]) 88 | 89 | ''' 90 | Execute experiment 91 | ''' 92 | for pivot_val in pivot_values: 93 | print "=== %s = %s ===" % (pivot_param, pivot_val) 94 | ''' 95 | Construct command to execute on spark cluster 96 | ''' 97 | appname = appname_format.format(dtf=dtf, 98 | expt_name=expt_name, 99 | param=pivot_param, 100 | paramval=pivot_val) 101 | 102 | # === Construct Solver Options arguments === 103 | valued_parameter_args = ' '.join( 104 | ["--%s %s" % (k, v) for k, v in config.items("parameters") if k in VAL_PARAMS and k not in ['minpart']]) 105 | # Treat 'minpart' as a special case. 
If minpart = 'auto', set minpart = num_cores * num_executors 106 | if 'minpart' in config.options('parameters'): 107 | if config.get('parameters', 'minpart') == 'auto': 108 | if pivot_param == 'num-executors': 109 | num_executors = int(pivot_val) 110 | else: 111 | num_executors = config.getint('spark_args', 'num-executors') 112 | minpart = DEFAULT_CORES * num_executors 113 | else: 114 | minpart = config.getint('parameters', 'minpart') 115 | minpart_arg = '--minpart %d' % minpart 116 | valued_parameter_args = ' '.join([valued_parameter_args, minpart_arg]) 117 | boolean_parameter_args = ' '.join( 118 | ["--%s" % k for k, v in config.items("parameters") if k in BOOL_PARAMS and str_to_bool(v)]) 119 | valued_dissolve_args = ' '.join( 120 | ["--%s %s" % (k, v) for k, v in config.items("dissolve_args") if k in VAL_PARAMS]) 121 | boolean_dissolve_args = ' '.join( 122 | ["--%s" % k for k, v in config.items("dissolve_args") if k in BOOL_PARAMS and str_to_bool(v)]) 123 | 124 | solver_options_args = ' '.join( 125 | [valued_parameter_args, boolean_parameter_args, valued_dissolve_args, boolean_dissolve_args]) 126 | 127 | # === Construct Spark arguments === 128 | spark_args = ' '.join(["--%s %s" % (k, v) for k, v in config.items("spark_args")]) 129 | 130 | # === Add the pivotal parameter === 131 | assert (pivot_param not in config.options("parameters")) 132 | assert (pivot_param not in config.options("spark_args")) 133 | pivot_param_arg = "--%s %s" % (pivot_param, pivot_val) 134 | 135 | # Is this pivotal parameters a spark argument or a dissolve argument? 136 | if pivot_param in ['num-executors', ]: 137 | spark_args = ' '.join([spark_args, pivot_param_arg]) 138 | else: 139 | solver_options_args = ' '.join([solver_options_args, pivot_param_arg]) 140 | 141 | # == Construct App-specific arguments === 142 | debug_filename = "%s.csv" % appname 143 | debug_file_path = os.path.join('', debug_filename) 144 | default_app_args = ("appname={appname}," 145 | "input_path={input_path}," 146 | "debug_file={debug_file_path}").format(appname=appname, 147 | input_path=input_path, 148 | debug_file_path=debug_file_path) 149 | extra_app_args = ','.join(["%s=%s" % (k, v) for k, v in config.items("app_args")]) 150 | 151 | app_args = ','.join([default_app_args, extra_app_args]) 152 | 153 | spark_submit_cmd = spark_submit_cmd_format.format(spark_submit=spark_submit_path, 154 | lib_jar_path=lib_jar_path, 155 | class_name=class_name, 156 | examples_jar_path=examples_jar_path, 157 | spark_args=spark_args, 158 | solver_options_args=solver_options_args, 159 | app_args=app_args) 160 | 161 | ''' 162 | Execute Command 163 | ''' 164 | print "Executing:\n%s" % spark_submit_cmd 165 | execute(spark_submit_cmd) 166 | 167 | ''' 168 | If enabled, execute command again, but without the debug flag 169 | ''' 170 | if args.ds: 171 | no_debug_appname = appname + '.no_debug' 172 | debug_filename = "%s.csv" % no_debug_appname 173 | debug_file_path = os.path.join('', debug_filename) 174 | default_app_args = ("appname={appname}," 175 | "input_path={input_path}," 176 | "debug_file={debug_file_path}").format(appname=no_debug_appname, 177 | input_path=input_path, 178 | debug_file_path=debug_file_path) 179 | 180 | extra_app_args = ','.join(["%s=%s" % (k, v) for k, v in config.items("app_args")]) 181 | 182 | app_args = ','.join([default_app_args, extra_app_args]) 183 | 184 | # Get rid of the debugging flag 185 | solver_options_args = re.sub(' --debug$', ' ', solver_options_args) 186 | solver_options_args = re.sub(' --debug ', ' ', solver_options_args) 187 | 
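# solver_options_args now has the --debug flag stripped, so the second
# spark-submit below repeats the identical run without the per-round
# gap/train/test statistics, keeping its timing free of those extra passes.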
188 | no_debug_spark_submit_cmd = spark_submit_cmd_format.format(spark_submit=spark_submit_path, 189 | lib_jar_path=lib_jar_path, 190 | class_name=class_name, 191 | examples_jar_path=examples_jar_path, 192 | spark_args=spark_args, 193 | solver_options_args=solver_options_args, 194 | app_args=app_args) 195 | print "Executing WITHOUT debugging:\n%s" % no_debug_spark_submit_cmd 196 | execute(no_debug_spark_submit_cmd) 197 | 198 | 199 | if __name__ == '__main__': 200 | main() -------------------------------------------------------------------------------- /helpers/brutus_sample.cfg: -------------------------------------------------------------------------------- 1 | [general] 2 | experiment_name: cov_binary_k2 3 | class_name: ch.ethz.dalab.dissolve.examples.binaryclassification.COVBinary 4 | 5 | [paths] 6 | input_path: /user/torekond/data/generated/covtype.libsvm.binary.scale 7 | 8 | [parameters] 9 | ; numeric parameters 10 | lambda: 0.01 11 | minpart: auto 12 | samplefrac: 0.5 13 | oraclesize: 0 14 | ; boolean parameters 15 | sparse: false 16 | 17 | 18 | [pivot] 19 | ; param is one of: lambda, minpart, samplefrac, oraclesize, num-executors 20 | param: num-executors 21 | values: 4, 8, 16, 32 22 | 23 | 24 | [dissolve_args] 25 | stopcrit: round 26 | roundlimit: 25 27 | debug: true 28 | debugmult: 10 29 | 30 | 31 | [spark_args] 32 | driver-memory: 2G 33 | executor-memory: 7G 34 | 35 | [app_args] 36 | ; Any key-value pairs mentioned here are sent as --kwargs k1=v1,k2=v2 37 | foo:bar 38 | -------------------------------------------------------------------------------- /helpers/brutus_setup.py: -------------------------------------------------------------------------------- 1 | '''Setup dissolve^struct environment on Brutus. (To be executed ON Brutus in the home directory.) 2 | ''' 3 | __author__ = 'tribhu' 4 | 5 | import os 6 | import argparse 7 | 8 | from benchmark_utils import * 9 | 10 | from retrieve_datasets import retrieve, download_to_gen_dir 11 | from paths import PROJECT_DIR, JARS_DIR, DATA_DIR 12 | 13 | LIB_JAR_URL = 'https://dl.dropboxusercontent.com/u/12851272/dissolvestruct_2.10.jar' 14 | EXAMPLES_JAR_PATH = 'https://dl.dropboxusercontent.com/u/12851272/dissolvestructexample_2.10-0.1-SNAPSHOT.jar' 15 | SCOPT_JAR_PATH = 'https://dl.dropboxusercontent.com/u/12851272/scopt_2.10-3.3.0.jar' 16 | 17 | HOME_DIR = os.getenv("HOME") 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser( 22 | description='Setup Brutus cluster for dissolve^struct. 
Needs to be executed on Brutus login node.') 23 | parser.add_argument("username", help="Username (which has access to Hadoop Cluster)") 24 | args = parser.parse_args() 25 | 26 | username = args.username 27 | 28 | def ssh_brutus(command): 29 | ssh('hadoop', username, command) 30 | 31 | # === Obtain data === 32 | retrieve(download_all=True) 33 | 34 | # === Obtain the jars === 35 | jars_dir = JARS_DIR 36 | if not os.path.exists(jars_dir): 37 | os.makedirs(jars_dir) 38 | 39 | print "=== Downloading executables to: ", jars_dir, "===" 40 | for jar_url in [LIB_JAR_URL, EXAMPLES_JAR_PATH, SCOPT_JAR_PATH]: 41 | print "== Retrieving ", jar_url.split('/')[-1], '==' 42 | download_to_gen_dir(jar_url, jars_dir) 43 | 44 | # === Move data to HDFS === 45 | print "=== Moving data to HDFS ===" 46 | put_data_cmd = "hadoop fs -put -f %s /user/%s/data" % (DATA_DIR, username) 47 | ssh_brutus(put_data_cmd) 48 | 49 | # === Create a file to mark everything is setup === 50 | touch_file_path = os.path.join(HOME_DIR, 'onesmallstep') 51 | ssh_brutus("touch %s" % touch_file_path) 52 | 53 | 54 | if __name__ == '__main__': 55 | main() -------------------------------------------------------------------------------- /helpers/buildall.py: -------------------------------------------------------------------------------- 1 | '''Builds the lib and examples packages. Then, syncs to Dropbox 2 | ''' 3 | import os 4 | import subprocess 5 | import shutil 6 | 7 | from paths import PROJECT_DIR 8 | 9 | SYNC_JARS = True 10 | 11 | HOME_DIR = os.getenv("HOME") 12 | 13 | OUTPUT_DIR = os.path.join(HOME_DIR, 'Dropbox/Public/') 14 | 15 | LIB_JAR_PATH = os.path.join(HOME_DIR, 16 | '.ivy2/local/ch.ethz.dalab/dissolvestruct_2.10/0.1-SNAPSHOT/jars/dissolvestruct_2.10.jar') 17 | EXAMPLES_JAR_PATH = os.path.join(PROJECT_DIR, 'dissolve-struct-examples', 'target/scala-2.10/', 18 | 'dissolvestructexample_2.10-0.1-SNAPSHOT.jar') 19 | SCOPT_JAR_PATH = os.path.join(HOME_DIR, '.ivy2/cache/com.github.scopt/scopt_2.10/jars/scopt_2.10-3.3.0.jar') 20 | 21 | 22 | def execute(command, cwd='.'): 23 | subprocess.check_call(command, cwd=cwd) 24 | 25 | 26 | def main(): 27 | dissolve_lib_dir = os.path.join(PROJECT_DIR, 'dissolve-struct-lib') 28 | dissolve_examples_dir = os.path.join(PROJECT_DIR, 'dissolve-struct-examples') 29 | 30 | # Build lib package 31 | print "=== Building dissolve-struct-lib ===" 32 | lib_build_cmd = ["sbt", "publish-local"] 33 | execute(lib_build_cmd, cwd=dissolve_lib_dir) 34 | 35 | # Build examples package 36 | print "=== Building dissolve-struct-examples ===" 37 | examples_build_cmd = ["sbt", "package"] 38 | execute(examples_build_cmd, cwd=dissolve_examples_dir) 39 | 40 | 41 | # Sync all packages to specified output directory 42 | if SYNC_JARS: 43 | print "=== Syncing Jars to Dropbox Public Folder ===" 44 | print LIB_JAR_PATH 45 | shutil.copy(LIB_JAR_PATH, OUTPUT_DIR) 46 | print EXAMPLES_JAR_PATH 47 | shutil.copy(EXAMPLES_JAR_PATH, OUTPUT_DIR) 48 | print SCOPT_JAR_PATH 49 | shutil.copy(SCOPT_JAR_PATH, OUTPUT_DIR) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() -------------------------------------------------------------------------------- /helpers/ocr_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io 3 | import os 4 | from paths import * 5 | 6 | FILTER_FOLD = True 7 | TST_FOLD_NUMS = [0,] # Datapoints from these folds will be treated as test dataset 8 | 9 | def convert_ocr_data(): 10 | idx_trn = 0 11 | idx_tst = 0 12 | 13 | ocr_mat_path = 
os.path.join(DATA_DIR, 'ocr.mat') 14 | patterns_train_path = os.path.join(GEN_DATA_DIR, 'patterns_train.csv') 15 | patterns_test_path = os.path.join(GEN_DATA_DIR, 'patterns_test.csv') 16 | labels_train_path = os.path.join(GEN_DATA_DIR, 'labels_train.csv') 17 | labels_test_path = os.path.join(GEN_DATA_DIR, 'labels_test.csv') 18 | folds_train_path = os.path.join(GEN_DATA_DIR, 'folds_train.csv') 19 | folds_test_path = os.path.join(GEN_DATA_DIR, 'folds_test.csv') 20 | 21 | print "Processing features available in %s" % ocr_mat_path 22 | 23 | mat = scipy.io.loadmat(ocr_mat_path, struct_as_record=False, squeeze_me=True) 24 | n = np.shape(mat['dataset'])[0] 25 | with open(patterns_train_path, 'w') as fpat_trn, open(labels_train_path, 'w') as flab_trn, open(folds_train_path, 'w') as ffold_trn, \ 26 | open(patterns_test_path, 'w') as fpat_tst, open(labels_test_path, 'w') as flab_tst, open(folds_test_path, 'w') as ffold_tst: 27 | for i in range(n): 28 | ### Write folds 29 | fold = mat['dataset'][i].__dict__['fold'] 30 | if fold in TST_FOLD_NUMS: 31 | fpat, flab, ffold = fpat_tst, flab_tst, ffold_tst 32 | idx_tst += 1 33 | idx = idx_tst 34 | else: 35 | fpat, flab, ffold = fpat_trn, flab_trn, ffold_trn 36 | idx_trn += 1 37 | idx = idx_trn 38 | # FORMAT: id,fold 39 | ffold.write('%d,%d\n' % (idx, fold)) 40 | 41 | ### Write patterns (x_i's) 42 | pixels = mat['dataset'][i].__dict__['pixels'] 43 | num_letters = np.shape(pixels)[0] 44 | letter_shape = np.shape(pixels[0]) 45 | # Create a matrix of size num_pixels+bias_var x num_letters 46 | xi = np.zeros((letter_shape[0] * letter_shape[1] + 1, num_letters)) 47 | for letter_id in range(num_letters): 48 | letter = pixels[letter_id] # Returns a 16x8 matrix 49 | xi[:, letter_id] = np.append(letter.flatten(order='F'), [1.]) 50 | # Vectorize the above matrix and store it 51 | # After flattening, order is column-major 52 | xi_str = ','.join([`s` for s in xi.flatten('F')]) 53 | # FORMAT: id,#rows,#cols,x_0_0,x_0_1,...x_n_m 54 | fpat.write('%d,%d,%d,%s\n' % (idx, np.shape(xi)[0], np.shape(xi)[1], xi_str)) 55 | 56 | ### Write labels (y_i's) 57 | labels = mat['dataset'][i].__dict__['word'] 58 | labels_str = ','.join([`a` for a in labels]) 59 | # FORMAT: id,#letters,letter_0,letter_1,...letter_n 60 | flab.write('%d,%d,%s\n' % (idx, np.shape(labels)[0], labels_str)) 61 | 62 | 63 | def main(): 64 | convert_ocr_data() 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /helpers/paths.py: -------------------------------------------------------------------------------- 1 | """Keeps track of paths used by other submodules 2 | """ 3 | import os 4 | 5 | current_file_path = os.path.abspath(__file__) 6 | PROJECT_DIR = os.path.join(os.path.dirname(current_file_path), os.pardir) 7 | 8 | DATA_DIR = os.path.join(PROJECT_DIR, 'data') 9 | GEN_DATA_DIR = os.path.join(DATA_DIR, 'generated') 10 | 11 | CHAIN_OCR_FILE = os.path.join(DATA_DIR, "ocr.mat") 12 | 13 | # If an sbt build cannot be performed, jars can be accessed here 14 | JARS_DIR = os.path.join(PROJECT_DIR, 'jars') 15 | LIB_JAR_PATH = os.path.join(JARS_DIR, 'dissolvestruct_2.10.jar') 16 | EXAMPLES_JAR_PATH = os.path.join(JARS_DIR, 'dissolvestructexample_2.10-0.1-SNAPSHOT.jar') 17 | SCOPT_JAR_PATH = os.path.join(JARS_DIR, 'scopt_2.10-3.3.0.jar') 18 | 19 | # Output dir. Any output produced in a subdirectory within this folder. 
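# (brutus_runner creates one "<experiment_name>_<timestamp>" subdirectory per
# experiment inside this folder.)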
20 | EXPT_OUTPUT_DIR = os.path.join(PROJECT_DIR, 'benchmark-data') 21 | 22 | 23 | 24 | # URLS for DATASETS 25 | A1A_URL = "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a1a" 26 | 27 | COV_BIN_URL = "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/covtype.libsvm.binary.scale.bz2" 28 | 29 | COV_MULT_URL = "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/covtype.scale.bz2" 30 | 31 | RCV1_URL = "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/rcv1_train.binary.bz2" 32 | 33 | MSRC_URL = "https://s3-eu-west-1.amazonaws.com/dissolve-struct/msrc/msrc.tar.gz" 34 | 35 | chain_files = ["folds_test.csv", 36 | "folds_train.csv", 37 | "patterns_train.csv", 38 | "patterns_test.csv", 39 | "labels_train.csv", 40 | "labels_test.csv", 41 | ] 42 | s3_chain_base_url = "https://s3-eu-west-1.amazonaws.com/dissolve-struct/chain" 43 | CHAIN_URLS = [os.path.join(s3_chain_base_url, fname) for fname in chain_files] 44 | 45 | -------------------------------------------------------------------------------- /helpers/retrieve_datasets.py: -------------------------------------------------------------------------------- 1 | """Retrieves datasets from various sources 2 | """ 3 | import urllib 4 | import subprocess 5 | from paths import * 6 | import argparse 7 | 8 | 9 | def decompress(filename): 10 | print "Decompressing: ", filename 11 | try: 12 | subprocess.check_call(['bzip2', '-d', filename]) 13 | except subprocess.CalledProcessError as e: 14 | pass 15 | 16 | 17 | def download_to_gen_dir(url, dir=GEN_DATA_DIR): 18 | print "Downloading: ", url 19 | basename = os.path.basename(url) 20 | destname = os.path.join(dir, basename) 21 | urllib.urlretrieve(url, destname) 22 | return destname 23 | 24 | 25 | def download_and_decompress(url): 26 | destname = download_to_gen_dir(url) 27 | decompress(destname) 28 | 29 | 30 | def retrieve(download_all=False): 31 | # Retrieve the files 32 | print "=== A1A ===" 33 | download_to_gen_dir(A1A_URL) 34 | 35 | print "=== COV BINARY ===" 36 | download_and_decompress(COV_BIN_URL) 37 | 38 | print "=== COV MULTICLASS ===" 39 | download_and_decompress(COV_MULT_URL) 40 | 41 | print "=== RCV1 ===" 42 | download_and_decompress(RCV1_URL) 43 | 44 | print "=== CHAIN ===" 45 | if download_all: 46 | for url in CHAIN_URLS: 47 | download_to_gen_dir(url) 48 | else: 49 | import ocr_helpers 50 | ocr_helpers.convert_ocr_data() 51 | 52 | 53 | def main(): 54 | parser = argparse.ArgumentParser(description='Retrieve datasets for dissolve^struct') 55 | parser.add_argument("-d", "--download", action="store_true", 56 | help="Download files instead of processing when possible") 57 | args = parser.parse_args() 58 | 59 | retrieve(args.download) 60 | 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /paper/dissolve-jmlr-software.bbl: -------------------------------------------------------------------------------- 1 | \begin{thebibliography}{6} 2 | \providecommand{\natexlab}[1]{#1} 3 | \providecommand{\url}[1]{\texttt{#1}} 4 | \expandafter\ifx\csname urlstyle\endcsname\relax 5 | \providecommand{\doi}[1]{doi: #1}\else 6 | \providecommand{\doi}{doi: \begingroup \urlstyle{rm}\Url}\fi 7 | 8 | \bibitem[Jaggi et~al.(2014)Jaggi, Smith, Tak{\'a}{\v c}, Terhorst, Krishnan, 9 | Hofmann, and Jordan]{Jaggi:2014vi} 10 | Martin Jaggi, Virginia Smith, Martin Tak{\'a}{\v c}, Jonathan Terhorst, Sanjay 11 | Krishnan, Thomas Hofmann, and Michael~I Jordan. 
12 | \newblock {Communication-Efficient Distributed Dual Coordinate Ascent}. 13 | \newblock In \emph{NIPS 2014 - Advances in Neural Information Processing 14 | Systems 27}, pages 3068--3076, 2014. 15 | 16 | \bibitem[Lacoste-Julien et~al.(2013)Lacoste-Julien, Jaggi, Schmidt, and 17 | Pletscher]{LacosteJulien:2013ue} 18 | Simon Lacoste-Julien, Martin Jaggi, Mark Schmidt, and Patrick Pletscher. 19 | \newblock {Block-Coordinate Frank-Wolfe Optimization for Structural SVMs}. 20 | \newblock In \emph{ICML 2013 - Proceedings of the 30th International Conference 21 | on Machine Learning}, 2013. 22 | 23 | \bibitem[Ratliff et~al.(2007)Ratliff, Bagnell, and Zinkevich]{Ratliff:2007ti} 24 | Nathan~D Ratliff, J~Andrew Bagnell, and Martin~A Zinkevich. 25 | \newblock {(Online) Subgradient Methods for Structured Prediction}. 26 | \newblock In \emph{AISTATS}, 2007. 27 | 28 | \bibitem[Shalev-Shwartz et~al.(2010)Shalev-Shwartz, Singer, Srebro, and 29 | Cotter]{ShalevShwartz:2010cg} 30 | Shai Shalev-Shwartz, Yoram Singer, Nathan Srebro, and Andrew Cotter. 31 | \newblock {Pegasos: Primal Estimated Sub-Gradient Solver for SVM}. 32 | \newblock \emph{Mathematical Programming}, 127\penalty0 (1):\penalty0 3--30, 33 | October 2010. 34 | 35 | \bibitem[Taskar et~al.(2003)Taskar, Guestrin, and Koller]{Taskar:2003tt} 36 | Ben Taskar, Carlos Guestrin, and Daphne Koller. 37 | \newblock {Max-Margin Markov Networks}. 38 | \newblock In \emph{NIPS 2014 - Advances in Neural Information Processing 39 | Systems 27}, 2003. 40 | 41 | \bibitem[Tsochantaridis et~al.(2005)Tsochantaridis, Joachims, Hofmann, and 42 | Altun]{Tsochantaridis:2005ww} 43 | Ioannis Tsochantaridis, Thorsten Joachims, Thomas Hofmann, and Yasemin Altun. 44 | \newblock {Large Margin Methods for Structured and Interdependent Output 45 | Variables}. 46 | \newblock \emph{The Journal of Machine Learning Research}, 6, December 2005. 47 | 48 | \end{thebibliography} 49 | -------------------------------------------------------------------------------- /paper/dissolve-jmlr-software.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dalab/dissolve-struct/67f37377b74c32cf05d8f43a0e3658a10864f9bf/paper/dissolve-jmlr-software.pdf -------------------------------------------------------------------------------- /paper/references.bib: -------------------------------------------------------------------------------- 1 | % 2 | 3 | @inproceedings{LacosteJulien:2013ue, 4 | author = {Lacoste-Julien, Simon and Jaggi, Martin and Schmidt, Mark and Pletscher, Patrick}, 5 | title = {{Block-Coordinate Frank-Wolfe Optimization for Structural SVMs}}, 6 | booktitle = {ICML 2013 - Proceedings of the 30th International Conference on Machine Learning}, 7 | year = {2013} 8 | } 9 | 10 | @article{Lucchi:2015co, 11 | author = {Lucchi, Aurelien and Marquez-Neila, Pablo and Becker, Carlos and Li, Yunpeng and Smith, Kevin and Knott, Graham and Fua, Pascal}, 12 | title = {{Learning Structured Models for Segmentation of 2D and 3D Imagery}}, 13 | journal = {IEEE Transactions on Medical Imaging}, 14 | year = {2015}, 15 | pages = {1096--1110} 16 | } 17 | 18 | @inproceedings{Ma:2015ti, 19 | author = {Ma, Chenxin and Smith, Virginia and Jaggi, Martin and Jordan, Michael I and Richt{\'a}rik, Peter and Tak{\'a}{\v c}, Martin}, 20 | title = {{Adding vs. 
Averaging in Distributed Primal-Dual Optimization}}, 21 | booktitle = {ICML 2015 - Proceedings of the 32th International Conference on Machine Learning}, 22 | year = {2015}, 23 | pages = {1973--1982} 24 | } 25 | 26 | @inproceedings{Jaggi:2014vi, 27 | author = {Jaggi, Martin and Smith, Virginia and Tak{\'a}{\v c}, Martin and Terhorst, Jonathan and Krishnan, Sanjay and Hofmann, Thomas and Jordan, Michael I}, 28 | title = {{Communication-Efficient Distributed Dual Coordinate Ascent}}, 29 | booktitle = {NIPS 2014 - Advances in Neural Information Processing Systems 27}, 30 | year = {2014}, 31 | pages = {3068--3076} 32 | } 33 | 34 | @article{Tsochantaridis:2004tg, 35 | author = {Tsochantaridis, Ioannis and Hofmann, Thomas and Joachims, Thorsten and Altun, Yasemin}, 36 | title = {{Support vector machine learning for interdependent and structured output spaces}}, 37 | journal = {ICML '04: Proceedings of the twenty-first international conference on Machine learning}, 38 | year = {2004}, 39 | month = jul, 40 | annote = {original SVMstruct paper. also says how you get the Crammer/Singer multiclass as a special case, by copying the features} 41 | } 42 | 43 | @article{Joachims:2009ex, 44 | author = {Joachims, Thorsten and Finley, Thomas and Yu, Chun-Nam John}, 45 | title = {{Cutting-Plane Training of Structural SVMs}}, 46 | year = {2009}, 47 | volume = {77}, 48 | number = {1}, 49 | pages = {27--59}, 50 | month = oct 51 | } 52 | 53 | @inproceedings{Taskar:2003tt, 54 | author = {Taskar, Ben and Guestrin, Carlos and Koller, Daphne}, 55 | title = {{Max-Margin Markov Networks}}, 56 | booktitle = {NIPS 2014 - Advances in Neural Information Processing Systems 27}, 57 | year = {2003}, 58 | } 59 | 60 | @article{Tsochantaridis:2005ww, 61 | author = {Tsochantaridis, Ioannis and Joachims, Thorsten and Hofmann, Thomas and Altun, Yasemin}, 62 | title = {{Large Margin Methods for Structured and Interdependent Output Variables}}, 63 | journal = {The Journal of Machine Learning Research}, 64 | year = {2005}, 65 | volume = {6}, 66 | month = dec, 67 | } 68 | 69 | @inproceedings{Ratliff:2007ti, 70 | author = {Ratliff, Nathan D and Bagnell, J Andrew and Zinkevich, Martin A}, 71 | title = {{(Online) Subgradient Methods for Structured Prediction}}, 72 | booktitle = {AISTATS}, 73 | year = {2007} 74 | } 75 | 76 | @article{ShalevShwartz:2010cg, 77 | author = {Shalev-Shwartz, Shai and Singer, Yoram and Srebro, Nathan and Cotter, Andrew}, 78 | title = {{Pegasos: Primal Estimated Sub-Gradient Solver for SVM}}, 79 | journal = {Mathematical Programming}, 80 | year = {2010}, 81 | volume = {127}, 82 | number = {1}, 83 | pages = {3--30}, 84 | month = oct 85 | } 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.9.2 2 | scipy==0.15.1 3 | --------------------------------------------------------------------------------