├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── data
│   ├── debug
│   │   └── debugWeights
│   │       ├── matlab-weights-1.csv
│   │       ├── matlab-weights-2.csv
│   │       ├── matlab-weights-3.csv
│   │       ├── matlab-weights-4.csv
│   │       └── matlab-weights-5.csv
│   ├── generated
│   │   └── .gitignore
│   ├── ocr.mat
│   ├── regression-tests
│   │   ├── bcfw-base-cache.csv
│   │   ├── bcfw-base-wavg.csv
│   │   ├── bcfw-base.csv
│   │   ├── dissolve-base-cache.csv
│   │   ├── dissolve-base-wavg.csv
│   │   ├── dissolve-base.csv
│   │   ├── dissolve-frac-parts-wavg-cache.csv
│   │   ├── dissolve-frac-parts-wavg.csv
│   │   ├── dissolve-frac-parts.csv
│   │   └── dissolve-frac.csv
│   └── retrieve_datasets.sh
├── debug
│   └── .gitignore
├── dissolve-struct-application
│   ├── build.sbt
│   ├── project
│   │   └── plugins.sbt
│   └── src
│       └── main
│           └── scala
│               └── ch
│                   └── ethz
│                       └── dalab
│                           └── dissolve
│                               └── app
│                                   └── DSApp.scala
├── dissolve-struct-examples
│   ├── build.sbt
│   ├── conf
│   │   └── log4j.properties
│   ├── lib
│   │   └── .gitignore
│   ├── project
│   │   └── plugins.sbt
│   └── src
│       ├── main
│       │   ├── resources
│       │   │   ├── adj.txt
│       │   │   ├── chain_test.csv
│       │   │   ├── chain_train.csv
│       │   │   ├── imageseg_cattle_test.txt
│       │   │   ├── imageseg_cattle_train.txt
│       │   │   ├── imageseg_colormap.txt
│       │   │   ├── imageseg_lab_freq.txt
│       │   │   ├── imageseg_label_color_map.txt
│       │   │   ├── imageseg_test.txt
│       │   │   ├── imageseg_train.txt
│       │   │   └── noun.txt
│       │   └── scala
│       │       └── ch
│       │           └── ethz
│       │               └── dalab
│       │                   └── dissolve
│       │                       └── examples
│       │                           ├── binaryclassification
│       │                           │   ├── AdultBinary.scala
│       │                           │   ├── BinaryClassificationDemo.scala
│       │                           │   ├── COVBinary.scala
│       │                           │   └── RCV1Binary.scala
│       │                           ├── chain
│       │                           │   ├── ChainBPDemo.scala
│       │                           │   └── ChainDemo.scala
│       │                           ├── imageseg
│       │                           │   ├── ImageSeg.scala
│       │                           │   ├── ImageSegRunner.scala
│       │                           │   ├── ImageSegTypes.scala
│       │                           │   └── ImageSegUtils.scala
│       │                           ├── multiclass
│       │                           │   └── COVMulticlass.scala
│       │                           └── utils
│       │                               └── ExampleUtils.scala
│       └── test
│           └── scala
│               └── ch
│                   └── ethz
│                       └── dalab
│                           └── dissolve
│                               └── diagnostics
│                                   ├── FeatureFnSpec.scala
│                                   ├── OracleSpec.scala
│                                   ├── StructLossSpec.scala
│                                   └── UnitSpec.scala
├── dissolve-struct-lib
│   ├── .gitignore
│   ├── build.sbt
│   ├── lib
│   │   └── jython.jar
│   ├── project
│   │   ├── assembly.sbt
│   │   └── plugins.sbt
│   └── src
│       └── main
│           └── scala
│               └── ch
│                   └── ethz
│                       └── dalab
│                           └── dissolve
│                               ├── classification
│                               │   ├── BinarySVMWithDBCFW.scala
│                               │   ├── BinarySVMWithSSG.scala
│                               │   ├── ClassificationUtils.scala
│                               │   ├── MultiClassSVMWithDBCFW.scala
│                               │   ├── StructSVMModel.scala
│                               │   ├── StructSVMWithBCFW.scala
│                               │   ├── StructSVMWithDBCFW.scala
│                               │   ├── StructSVMWithMiniBatch.scala
│                               │   ├── StructSVMWithSSG.scala
│                               │   └── Types.scala
│                               ├── optimization
│                               │   ├── BCFWSolver.scala
│                               │   ├── DBCFWSolverTuned.scala
│                               │   ├── DissolveFunctions.scala
│                               │   ├── SSGSolver.scala
│                               │   ├── SolverOptions.scala
│                               │   └── SolverUtils.scala
│                               ├── regression
│                               │   └── LabeledObject.scala
│                               └── utils
│                                   └── cli
│                                       ├── CLAParser.scala
│                                       └── Config.scala
├── helpers
│   ├── __init__.py
│   ├── benchmark_runner.py
│   ├── benchmark_setup.py
│   ├── benchmark_utils.py
│   ├── brutus_runner.py
│   ├── brutus_sample.cfg
│   ├── brutus_setup.py
│   ├── buildall.py
│   ├── ocr_helpers.py
│   ├── paths.py
│   └── retrieve_datasets.py
├── paper
│   ├── dissolve-jmlr-software.bbl
│   ├── dissolve-jmlr-software.pdf
│   ├── dissolve-jmlr-software.tex
│   ├── jmlr2e.sty
│   └── references.bib
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # use glob syntax.
2 | syntax: glob
3 | *.ser
4 | *.class
5 | *~
6 | *.bak
7 | #*.off
8 | *.old
9 |
10 | # eclipse conf file
11 | .settings
12 | .classpath
13 | .project
14 | .manager
15 | .scala_dependencies
16 |
17 | # idea
18 | .idea
19 | *.iml
20 |
21 | # building
22 | target
23 | build
24 | null
25 | tmp*
26 | temp*
27 | dist
28 | test-output
29 | build.log
30 |
31 | # other scm
32 | .svn
33 | .CVS
34 | .hg*
35 |
36 | # Mac stuff
37 | .DS_Store
38 |
39 | .metadata/
40 | _site/
41 | Gemfile.lock
42 |
43 | # switch to regexp syntax.
44 | # syntax: regexp
45 | # ^\.pc/
46 |
47 | #SHITTY output not in target directory
48 | build.log
49 | *.pyc
50 |
51 | # checkpoint files generated by Spark
52 | checkpoint-files/*
53 |
54 | # Custom stuff
55 | *.cache
56 | *.log
57 |
58 | debug/*
59 | data/generated/*
60 |
61 | dissolve-struct-lib/lib_managed/*
62 |
63 | dissolve-struct-examples/spark-1.*
64 |
65 | # Experiment related scripts/data
66 | conf/*
67 | expt-data/*
68 | figures/*
69 | *.ipynb
70 | setup_cluster.sh
71 | run.sh
72 |
73 | # Data for regression testing
74 | data/debug/*
75 |
76 | # ec2 credentials
77 | ec2_config.json
78 |
79 | # New eclipse introduced some temp files
80 | .cache-main
81 | .cache-tests
82 |
83 | # tribhu's current directory
84 | benchmark-data/
85 | jars/
86 |
87 | % latex
88 | *.out
89 | *.synctex.gz
90 | *.blg
91 | *.aux
92 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: scala
2 | scala:
3 | - 2.10.4
4 |
5 | # Build only the lib package, and exclude examples package from build
6 | before_install:
7 | - cd dissolve-struct-lib
8 |
9 | # whitelist
10 | branches:
11 | only:
12 | - master
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://travis-ci.org/dalab/dissolve-struct)
2 | [](https://github.com/dalab/dissolve-struct/releases)
3 |
4 | dissolvestruct
5 | ===========
6 |
7 | Distributed solver library for structured output prediction, based on Spark.
8 |
9 | The library is based on the primal-dual BCFW solver, allows approximate inference oracles, and distributes the algorithm using the recent communication-efficient CoCoA scheme.
10 | The user-facing interface is the same as that of the widely used SVMstruct solver in the single-machine case.
11 |
12 | For more information, check out the [project page](http://dalab.github.io/dissolve-struct/).
13 |
14 |
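As a rough sketch of that SVMstruct-style interface (a minimal, hypothetical example; it assumes the four user functions shown in the bundled `DSApp.scala` template are the only pieces a user must supply, and `X`/`Y` stand in for application-defined pattern and label types):

```scala
import org.apache.spark.rdd.RDD
import breeze.linalg.Vector
import ch.ethz.dalab.dissolve.classification.{ StructSVMModel, StructSVMWithDBCFW }
import ch.ethz.dalab.dissolve.optimization.{ DissolveFunctions, SolverOptions }
import ch.ethz.dalab.dissolve.regression.LabeledObject

case class X() // application-defined input, e.g. a (d x n) pixel matrix for OCR
case class Y() // application-defined structured output, e.g. a label sequence

object MyFunctions extends DissolveFunctions[X, Y] {
  def featureFn(x: X, y: Y): Vector[Double] = ???               // joint feature map \phi(x, y)
  def lossFn(yPredicted: Y, yTruth: Y): Double = ???             // structured loss \Delta(y, y')
  def oracleFn(model: StructSVMModel[X, Y], x: X, y: Y): Y = ??? // loss-augmented decoding
  def predictFn(model: StructSVMModel[X, Y], x: X): Y =
    oracleFn(model, x, null)                                     // plain decoding reuses the oracle
}

object QuickStart {
  // Training: wrap the data as an RDD of LabeledObjects and run the distributed DBCFW solver
  def train(trainData: RDD[LabeledObject[X, Y]]): StructSVMModel[X, Y] = {
    val options = new SolverOptions[X, Y]()
    options.lambda = 0.01 // regularization parameter
    new StructSVMWithDBCFW[X, Y](trainData, MyFunctions, options).trainModel()
  }
}
```

A fully commented version of this skeleton, including Spark setup and solver options, is provided in `dissolve-struct-application/src/main/scala/ch/ethz/dalab/dissolve/app/DSApp.scala`.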
--------------------------------------------------------------------------------
/data/generated/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
3 | !.gitignore
--------------------------------------------------------------------------------
/data/ocr.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dalab/dissolve-struct/67f37377b74c32cf05d8f43a0e3658a10864f9bf/data/ocr.mat
--------------------------------------------------------------------------------
/data/regression-tests/bcfw-base-cache.csv:
--------------------------------------------------------------------------------
1 | # BCFW
2 | # numPasses=5
3 | # doWeightedAveraging=false
4 | # randSeed=42
5 | # sample=perm
6 | # lambda=0.010000
7 | # doLineSearch=true
8 | # enableManualPartitionSize=false
9 | # NUM_PART=1
10 | # enableOracleCache=true
11 | # oracleCacheSize=10
12 | # H=5
13 | # sampleFrac=0.500000
14 | # sampleWithReplacement=false
15 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-bcfw-1416414185362.csv
16 |
17 | round,time,iter,primal,dual,gap,train_error,test_error
18 | 1,0.761000,312,0.740438,0.004261,0.736177,0.289614,0.382340
19 | 2,1.756000,624,0.619814,0.005358,0.614455,0.220483,0.339086
20 | 3,2.816000,936,0.547373,0.006040,0.541333,0.186888,0.315781
21 | 4,3.741000,1248,0.502321,0.006678,0.495643,0.167291,0.294808
22 | 5,4.692000,1560,0.461032,0.007274,0.453758,0.158045,0.284424
23 |
--------------------------------------------------------------------------------
/data/regression-tests/bcfw-base-wavg.csv:
--------------------------------------------------------------------------------
1 | # BCFW
2 | # numPasses=5
3 | # doWeightedAveraging=true
4 | # randSeed=42
5 | # sample=perm
6 | # lambda=0.010000
7 | # doLineSearch=true
8 | # enableManualPartitionSize=false
9 | # NUM_PART=1
10 | # enableOracleCache=false
11 | # oracleCacheSize=10
12 | # H=5
13 | # sampleFrac=0.500000
14 | # sampleWithReplacement=false
15 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-bcfw-1416413767561.csv
16 |
17 | round,time,iter,primal,dual,gap,train_error,test_error
18 | 1,0.757000,312,0.649247,0.003156,0.646091,0.232588,0.329987
19 | 2,1.908000,624,0.511866,0.005451,0.506414,0.170380,0.286416
20 | 3,2.888000,936,0.445621,0.007233,0.438388,0.143429,0.267820
21 | 4,3.850000,1248,0.402138,0.008722,0.393416,0.124673,0.258407
22 | 5,4.823000,1560,0.369806,0.010022,0.359784,0.114160,0.253438
23 |
--------------------------------------------------------------------------------
/data/regression-tests/bcfw-base.csv:
--------------------------------------------------------------------------------
1 | # BCFW
2 | # numPasses=5
3 | # doWeightedAveraging=false
4 | # randSeed=42
5 | # sample=perm
6 | # lambda=0.010000
7 | # doLineSearch=true
8 | # enableManualPartitionSize=false
9 | # NUM_PART=1
10 | # enableOracleCache=false
11 | # oracleCacheSize=10
12 | # H=5
13 | # sampleFrac=0.500000
14 | # sampleWithReplacement=false
15 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-bcfw-1416401133519.csv
16 |
17 | round,time,iter,primal,dual,gap,train_error,test_error
18 | 1,0.807000,312,0.740438,0.004261,0.736177,0.289614,0.382340
19 | 2,2.497000,624,0.587597,0.007060,0.580537,0.207465,0.336020
20 | 3,3.623000,936,0.482679,0.009151,0.473528,0.151328,0.292496
21 | 4,4.749000,1248,0.493073,0.010945,0.482128,0.153231,0.305247
22 | 5,5.732000,1560,0.411257,0.012449,0.398809,0.122657,0.288465
23 |
--------------------------------------------------------------------------------
/data/regression-tests/dissolve-base-cache.csv:
--------------------------------------------------------------------------------
1 | # numPasses=5
2 | # doWeightedAveraging=false
3 | # randSeed=42
4 | # sample=frac
5 | # lambda=0.010000
6 | # doLineSearch=true
7 | # enableManualPartitionSize=true
8 | # NUM_PART=1
9 | # enableOracleCache=true
10 | # oracleCacheSize=10
11 | # H=5
12 | # sampleFrac=1.000000
13 | # sampleWithReplacement=false
14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416414162557.csv
15 |
16 | # spark.app.name=Chain-DBCFW
17 | # spark.executor.memory=NA
18 | # spark.task.cpus=NA
19 | # spark.local.dir=NA
20 | # spark.default.parallelism=NA
21 |
22 | # indexedTrainDataRDD.partitions.size=1
23 | # indexedPrimalsRDD.partitions.size=1
24 | round,time,primal,dual,gap,train_error,test_error
25 | 1,5,0.740438,0.004261,0.736177,0.289614,0.382340
26 | 2,8,0.619814,0.005358,0.614455,0.220483,0.339086
27 | 3,11,0.547373,0.006040,0.541333,0.186888,0.315781
28 | 4,13,0.502321,0.006678,0.495643,0.167291,0.294808
29 | 5,16,0.461032,0.007274,0.453758,0.158045,0.284424
30 |
--------------------------------------------------------------------------------
/data/regression-tests/dissolve-base-wavg.csv:
--------------------------------------------------------------------------------
1 | # numPasses=5
2 | # doWeightedAveraging=true
3 | # randSeed=42
4 | # sample=frac
5 | # lambda=0.010000
6 | # doLineSearch=true
7 | # enableManualPartitionSize=true
8 | # NUM_PART=1
9 | # enableOracleCache=false
10 | # oracleCacheSize=10
11 | # H=5
12 | # sampleFrac=1.000000
13 | # sampleWithReplacement=false
14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416413743225.csv
15 |
16 | # spark.app.name=Chain-DBCFW
17 | # spark.executor.memory=NA
18 | # spark.task.cpus=NA
19 | # spark.local.dir=NA
20 | # spark.default.parallelism=NA
21 |
22 | # indexedTrainDataRDD.partitions.size=1
23 | # indexedPrimalsRDD.partitions.size=1
24 | round,time,primal,dual,gap,train_error,test_error
25 | 1,5,0.649247,0.003156,0.646091,0.232588,0.329987
26 | 2,9,0.511866,0.005451,0.506414,0.170380,0.286416
27 | 3,12,0.445621,0.007233,0.438388,0.143429,0.267820
28 | 4,15,0.402138,0.008722,0.393416,0.124673,0.258407
29 | 5,18,0.369806,0.010022,0.359784,0.114160,0.253438
30 |
--------------------------------------------------------------------------------
/data/regression-tests/dissolve-base.csv:
--------------------------------------------------------------------------------
1 | # numPasses=5
2 | # doWeightedAveraging=false
3 | # randSeed=42
4 | # sample=frac
5 | # lambda=0.010000
6 | # doLineSearch=true
7 | # enableManualPartitionSize=true
8 | # NUM_PART=1
9 | # enableOracleCache=false
10 | # oracleCacheSize=10
11 | # H=5
12 | # sampleFrac=1.000000
13 | # sampleWithReplacement=false
14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416401107250.csv
15 |
16 | # spark.app.name=Chain-DBCFW
17 | # spark.executor.memory=NA
18 | # spark.task.cpus=NA
19 | # spark.local.dir=NA
20 | # spark.default.parallelism=NA
21 |
22 | # indexedTrainDataRDD.partitions.size=1
23 | # indexedPrimalsRDD.partitions.size=1
24 | round,time,primal,dual,gap,train_error,test_error
25 | 1,6,0.740438,0.004261,0.736177,0.289614,0.382340
26 | 2,8,0.587597,0.007060,0.580537,0.207465,0.336020
27 | 3,11,0.482679,0.009151,0.473528,0.151328,0.292496
28 | 4,14,0.493073,0.010945,0.482128,0.153231,0.305247
29 | 5,17,0.411257,0.012449,0.398809,0.122657,0.288465
30 |
--------------------------------------------------------------------------------
/data/regression-tests/dissolve-frac-parts-wavg-cache.csv:
--------------------------------------------------------------------------------
1 | # numPasses=5
2 | # doWeightedAveraging=true
3 | # randSeed=42
4 | # sample=frac
5 | # lambda=0.010000
6 | # doLineSearch=true
7 | # enableManualPartitionSize=true
8 | # NUM_PART=4
9 | # enableOracleCache=true
10 | # oracleCacheSize=10
11 | # H=5
12 | # sampleFrac=0.500000
13 | # sampleWithReplacement=false
14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416475800888.csv
15 |
16 | # spark.app.name=Chain-DBCFW
17 | # spark.executor.memory=NA
18 | # spark.task.cpus=NA
19 | # spark.local.dir=NA
20 | # spark.default.parallelism=NA
21 |
22 | # indexedTrainDataRDD.partitions.size=4
23 | # indexedPrimalsRDD.partitions.size=4
24 | round,time,primal,dual,gap,train_error,test_error
25 | 1,6,0.949089,0.000614,0.948474,0.423528,0.483519
26 | 2,9,0.904272,0.000989,0.903284,0.380705,0.450028
27 | 3,12,0.879400,0.001221,0.878179,0.360667,0.436719
28 | 4,15,0.852875,0.001393,0.851483,0.342916,0.422211
29 | 5,19,0.826390,0.001539,0.824851,0.329641,0.407550
30 |
--------------------------------------------------------------------------------
/data/regression-tests/dissolve-frac-parts-wavg.csv:
--------------------------------------------------------------------------------
1 | # numPasses=5
2 | # doWeightedAveraging=true
3 | # randSeed=42
4 | # sample=frac
5 | # lambda=0.010000
6 | # doLineSearch=true
7 | # enableManualPartitionSize=true
8 | # NUM_PART=4
9 | # enableOracleCache=false
10 | # oracleCacheSize=10
11 | # H=5
12 | # sampleFrac=0.500000
13 | # sampleWithReplacement=false
14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416475749034.csv
15 |
16 | # spark.app.name=Chain-DBCFW
17 | # spark.executor.memory=NA
18 | # spark.task.cpus=NA
19 | # spark.local.dir=NA
20 | # spark.default.parallelism=NA
21 |
22 | # indexedTrainDataRDD.partitions.size=4
23 | # indexedPrimalsRDD.partitions.size=4
24 | round,time,primal,dual,gap,train_error,test_error
25 | 1,5,0.949089,0.000614,0.948474,0.423528,0.483519
26 | 2,8,0.860342,0.001137,0.859205,0.334412,0.406605
27 | 3,11,0.810682,0.001574,0.809108,0.306459,0.385472
28 | 4,13,0.774036,0.001963,0.772073,0.283006,0.367418
29 | 5,16,0.744513,0.002314,0.742199,0.264717,0.350701
30 |
--------------------------------------------------------------------------------
/data/regression-tests/dissolve-frac-parts.csv:
--------------------------------------------------------------------------------
1 | # numPasses=5
2 | # doWeightedAveraging=false
3 | # randSeed=42
4 | # sample=frac
5 | # lambda=0.010000
6 | # doLineSearch=true
7 | # enableManualPartitionSize=true
8 | # NUM_PART=4
9 | # enableOracleCache=false
10 | # oracleCacheSize=10
11 | # H=5
12 | # sampleFrac=0.500000
13 | # sampleWithReplacement=false
14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416475677521.csv
15 |
16 | # spark.app.name=Chain-DBCFW
17 | # spark.executor.memory=NA
18 | # spark.task.cpus=NA
19 | # spark.local.dir=NA
20 | # spark.default.parallelism=NA
21 |
22 | # indexedTrainDataRDD.partitions.size=4
23 | # indexedPrimalsRDD.partitions.size=4
24 | round,time,primal,dual,gap,train_error,test_error
25 | 1,6,0.942347,0.000898,0.941449,0.421662,0.473316
26 | 2,9,0.842540,0.001582,0.840958,0.337215,0.412921
27 | 3,12,0.785794,0.002142,0.783652,0.302160,0.378707
28 | 4,14,0.749114,0.002636,0.746478,0.282378,0.359638
29 | 5,17,0.712107,0.003064,0.709043,0.255069,0.345087
30 |
--------------------------------------------------------------------------------
/data/regression-tests/dissolve-frac.csv:
--------------------------------------------------------------------------------
1 | # numPasses=5
2 | # doWeightedAveraging=false
3 | # randSeed=42
4 | # sample=frac
5 | # lambda=0.010000
6 | # doLineSearch=true
7 | # enableManualPartitionSize=true
8 | # NUM_PART=1
9 | # enableOracleCache=false
10 | # oracleCacheSize=10
11 | # H=5
12 | # sampleFrac=0.500000
13 | # sampleWithReplacement=false
14 | # debugInfoPath=/Users/tribhu/git/DBCFWstruct/debug/debug-dissolve-1416475480026.csv
15 |
16 | # spark.app.name=Chain-DBCFW
17 | # spark.executor.memory=NA
18 | # spark.task.cpus=NA
19 | # spark.local.dir=NA
20 | # spark.default.parallelism=NA
21 |
22 | # indexedTrainDataRDD.partitions.size=1
23 | # indexedPrimalsRDD.partitions.size=1
24 | round,time,primal,dual,gap,train_error,test_error
25 | 1,4,0.832474,0.002271,0.830203,0.347001,0.416671
26 | 2,7,0.795129,0.003703,0.791426,0.316188,0.402175
27 | 3,9,0.665157,0.004785,0.660371,0.258996,0.334717
28 | 4,11,0.610020,0.005626,0.604394,0.220944,0.333295
29 | 5,13,0.565354,0.006300,0.559054,0.210634,0.317841
30 |
--------------------------------------------------------------------------------
/data/retrieve_datasets.sh:
--------------------------------------------------------------------------------
1 | GEN_DIR="generated"
2 |
3 | if [ ! -d "$GEN_DIR" ]; then
4 | mkdir $GEN_DIR
5 | fi
6 |
7 | cd $GEN_DIR
8 |
9 | # Adult
10 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a1a
11 |
12 | # Forest Cover (Binary)
13 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/covtype.libsvm.binary.scale.bz2
14 | bzip2 -d covtype.libsvm.binary.scale.bz2
15 |
16 | # Forest Cover (Multiclass)
17 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/covtype.scale.bz2
18 | bzip2 -d covtype.scale.bz2
19 |
20 | # RCV1
21 | wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/rcv1_train.binary.bz2
22 | bzip2 -d rcv1_train.binary.bz2
23 |
24 | # Factorie jar
25 | wget https://github.com/factorie/factorie/releases/download/factorie-1.0/factorie-1.0.jar
26 | mv factorie-1.0.jar ../../dissolve-struct-examples/lib/
27 |
28 | cd ..
29 |
--------------------------------------------------------------------------------
/debug/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
3 | !.gitignore
--------------------------------------------------------------------------------
/dissolve-struct-application/build.sbt:
--------------------------------------------------------------------------------
1 | // ---- < 1 > -----------------------------------------------------------------
2 | // Enter your application name, organization and version below
3 | // This information will be used when the binary jar is packaged
4 | name := "DissolveStructApplication"
5 |
6 | organization := "ch.ethz.dalab"
7 |
8 | version := "0.1-SNAPSHOT" // Keep this unchanged for development releases
9 | // ---- </ 1 > ----------------------------------------------------------------
10 |
11 | scalaVersion := "2.10.4"
12 |
13 | libraryDependencies += "ch.ethz.dalab" %% "dissolvestruct" % "0.1-SNAPSHOT"
14 |
15 | libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.0" % "test"
16 |
17 | libraryDependencies += "org.apache.spark" %% "spark-core" % "1.4.1"
18 |
19 | libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.4.1"
20 |
21 | resolvers += "IESL Release" at "http://dev-iesl.cs.umass.edu/nexus/content/groups/public"
22 |
23 | libraryDependencies += "cc.factorie" % "factorie" % "1.0"
24 |
25 | libraryDependencies += "com.github.scopt" %% "scopt" % "3.3.0"
26 |
27 | // ---- < 2 > -----------------------------------------------------------------
28 | // Add additional dependencies in the space provided below, like above.
29 | // Libraries often provide the exact line that needs to be added here on their
30 | // webpage.
31 | // PS: Keep your eyes peeled -- there is a difference between "%%" and "%"
32 |
33 | // libraryDependencies += "organization" %% "application_name" % "version"
34 |
35 | // ---- </ 2 > ----------------------------------------------------------------
36 |
37 | resolvers += Resolver.sonatypeRepo("public")
38 |
39 | EclipseKeys.createSrc := EclipseCreateSrc.Default + EclipseCreateSrc.Resource
40 |
41 | mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) =>
42 | {
43 | case PathList("javax", "servlet", xs @ _*) => MergeStrategy.first
44 | case PathList(ps @ _*) if ps.last endsWith ".html" => MergeStrategy.first
45 | case "application.conf" => MergeStrategy.concat
46 | case "reference.conf" => MergeStrategy.concat
47 | case "log4j.properties" => MergeStrategy.discard
48 | case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
49 | case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
50 | case _ => MergeStrategy.first
51 | }
52 | }
53 |
54 | test in assembly := {}
55 |
--------------------------------------------------------------------------------
/dissolve-struct-application/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | logLevel := Level.Warn
2 |
3 | addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.5.0")
4 |
5 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0")
6 |
--------------------------------------------------------------------------------
/dissolve-struct-application/src/main/scala/ch/ethz/dalab/dissolve/app/DSApp.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.app
2 |
3 | import org.apache.spark.SparkConf
4 | import org.apache.spark.SparkContext
5 | import org.apache.spark.rdd.RDD
6 |
7 | import breeze.linalg.Vector
8 | import ch.ethz.dalab.dissolve.classification.StructSVMModel
9 | import ch.ethz.dalab.dissolve.classification.StructSVMWithDBCFW
10 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
11 | import ch.ethz.dalab.dissolve.optimization.GapThresholdCriterion
12 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
13 | import ch.ethz.dalab.dissolve.regression.LabeledObject
14 |
15 | /**
16 | * This defines the x-part of the training example
17 | * For example, in the case of sequence OCR this would be a (d x n) matrix, with
18 | * each column containing the pixel representation of a character
19 | */
20 | case class Pattern() {
21 |
22 | }
23 |
24 | /**
25 | * This defines the y-part of the training example.
26 | * Once again, in the case of OCR this would be an n-dimensional vector, with the
27 | * i-th element containing the label for the i-th character of x.
28 | */
29 | case class Label() {
30 |
31 | }
32 |
33 | /**
34 | * This is the core of your Structured SVM application.
35 | * In here, you'll find three functions and a driver program that you'll need
36 | * to fill in to get your application running.
37 | *
38 | * The interface is inspired by SVM^struct by Joachims et al.
39 | * (http://www.cs.cornell.edu/people/tj/svm_light/svm_struct.html)
40 | */
41 | object DSApp extends DissolveFunctions[Pattern, Label] {
42 |
43 | /**
44 | * ============== Joint feature Map: \phi(x, y) ==============
45 | *
46 | * This encodes the complex input-output (x, y) pair in a Vector space (done
47 | * using vectors from the Breeze library)
48 | */
49 | def featureFn(x: Pattern, y: Label): Vector[Double] = {
50 |
51 | // Insert code here for Joint feature map here
52 |
53 | ???
54 | }
55 |
56 | /**
57 | * ============== Structured Loss Function: \Delta(y, y^m) ==============
58 | *
59 | * Loss for predicting yPredicted instead of yTruth.
60 | * This needs to be 0 if yPredicted == yTruth
61 | */
62 | def lossFn(yPredicted: Label, yTruth: Label): Double = {
63 |
64 | // Insert code for Loss function here
65 |
66 | ???
67 | }
68 |
69 | /**
70 | * ============== Maximization Oracle: H^m(w) ==============
71 | *
72 | * Finds the most violating constraint by solving the loss-augmented decoding
73 | * subproblem.
74 | * This is equivalent to predicting
75 | * y* = argmax_{y} \Delta(y, y^m) + < w, \phi(x^m, y) >
76 | * for some training example (x^m, y^m) and parameters w
77 | *
78 | * Make sure the loss-augmentation is consistent with the \Delta defined above.
79 | *
80 | * By default, the prediction function calls this oracle with y^m = null.
81 | * In which case, the loss-augmentation can be skipped using a simple check
82 | * on y^m.
83 | *
84 | * For examples of common oracle/decoding functions (like loopy BP, Viterbi,
85 | * or BP on chain CRFs), refer to the examples package.
86 | */
87 | def oracleFn(model: StructSVMModel[Pattern, Label], x: Pattern, y: Label): Label = {
88 |
89 | val weightVec = model.weights
90 |
91 | // Insert code for maximization Oracle here
92 |
93 | ???
94 | }
95 |
96 | /**
97 | * ============== Prediction Function ==============
98 | *
99 | * Finds the best output candidate for x, given parameters w.
100 | * This is equivalent to solving:
101 | * y* = argmax_{y} < w, \phi(x^m, y) >
102 | *
103 | * Note that this is very similar to the maximization oracle, but without
104 | * the loss-augmentation. So, by default, we call the oracle function by
105 | * setting y as null.
106 | */
107 | def predictFn(model: StructSVMModel[Pattern, Label], x: Pattern): Label =
108 | oracleFn(model, x, null)
109 |
110 | /**
111 | * ============== Driver ==============
112 | *
113 | * This is the entry point into the program.
114 | * In here, we initialize the SparkContext, set the parameters and call the
115 | * optimization routine.
116 | *
117 | * To begin with the training, we'll need three things:
118 | * a. A SparkContext instance (Defaults provided)
119 | * b. Solver Parameters (Defaults provided)
120 | * c. Data
121 | *
122 | * To execute, you should package this into a jar and provide it using
123 | * spark-submit (http://spark.apache.org/docs/latest/submitting-applications.html).
124 | *
125 | * Alternately, you can right-click and Run As -> Scala Application to run
126 | * within Eclipse.
127 | */
128 | def main(args: Array[String]): Unit = {
129 |
130 | val appname = "DSApp"
131 |
132 | /**
133 | * ============== Initialize Spark ==============
134 | *
135 | * Alternately, use:
136 | * val conf = new SparkConf().setAppName(appname).setMaster("local[4]")
137 | * if you're planning to execute within Eclipse using 4 cores
138 | */
139 | val conf = new SparkConf().setAppName(appname)
140 | val sc = new SparkContext(conf)
141 | sc.setCheckpointDir("checkpoint-files")
142 |
143 | /**
144 | * ============== Set Solver parameters ==============
145 | */
146 | val solverOptions = new SolverOptions[Pattern, Label]()
147 | // Regularization parameter
148 | solverOptions.lambda = 0.01
149 |
150 | // Stopping criterion
151 | solverOptions.stoppingCriterion = GapThresholdCriterion
152 | solverOptions.gapThreshold = 1e-3
153 | solverOptions.gapCheck = 25 // Checks for gap every gapCheck rounds
154 |
155 | // Set the fraction of data to be used in training during each round
156 | // In this case, 50% of the data is uniformly sampled for training at the
157 | // beginning of each round
158 | solverOptions.sampleFrac = 0.5
159 |
160 | // Set how many partitions you want to split the data into.
161 | // These partitions will be local to each machine and the respective dual
162 | // variables associated with these partitions will reside locally.
163 | // Ideally, you want to set this to: #cores x #workers x 2.
164 | // If this is disabled, Spark decides on the partitioning, which may
165 | // be suboptimal.
166 | solverOptions.enableManualPartitionSize = true
167 | solverOptions.NUM_PART = 8
168 |
169 | // Optionally, you can enable obtaining additional statistics like the
170 | // training and test errors w.r.t. rounds, along with the gap.
171 | // This is expensive as it involves a complete pass through the data.
172 | solverOptions.debug = false
173 | // This computes the statistics every debugMultiplier^i rounds.
174 | // So, in this case, it does so at rounds 1, 2, 4, 8, ...
175 | // Beyond the 50th round, statistics are collected every 10 rounds.
176 | solverOptions.debugMultiplier = 2
177 | // Writes the statistics in CSV format in the provided path
178 | solverOptions.debugInfoPath = "path/to/statistics.csv"
179 |
180 | /**
181 | * ============== Provide Data ==============
182 | */
183 | val trainDataRDD: RDD[LabeledObject[Pattern, Label]] = {
184 |
185 | // Insert code to load TRAIN data here
186 |
187 | ???
188 | }
189 | val testDataRDD: RDD[LabeledObject[Pattern, Label]] = {
190 |
191 | // Insert code to load TEST data here
192 |
193 | ???
194 | }
195 | // Optionally, set to None in case you don't want statistics on test data
196 | solverOptions.testDataRDD = Some(testDataRDD)
197 |
198 | /**
199 | * ============== Training ==============
200 | */
201 | val trainer: StructSVMWithDBCFW[Pattern, Label] =
202 | new StructSVMWithDBCFW[Pattern, Label](
203 | trainDataRDD,
204 | DSApp,
205 | solverOptions)
206 |
207 | val model: StructSVMModel[Pattern, Label] = trainer.trainModel()
208 |
209 | /**
210 | * ============== Store Model ==============
211 | *
212 | * Optionally, you can store the model's weight parameters.
213 | *
214 | * To load a model, you can use
215 | * val weights = breeze.linalg.csvread(new java.io.File(weightOutPath))
216 | * val model = new StructSVMModel[Pattern, Label](weights, 0.0, null, DSApp)
217 | */
218 | val weightOutPath = "path/to/weights.csv"
219 | val weights = model.weights.toDenseVector.toDenseMatrix
220 | breeze.linalg.csvwrite(new java.io.File(weightOutPath), weights)
221 |
222 | }
223 | }
--------------------------------------------------------------------------------
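To make the DSApp.scala template above more tangible, here is a hypothetical, minimal instantiation for plain K-class classification. Every name in it (MCInput, MCLabel, ToyMulticlass, numClasses, numDims) is invented for illustration and is not part of this repository, and it assumes the four functions shown in DSApp.scala are the only abstract members of DissolveFunctions:

```scala
import breeze.linalg.{ DenseVector, Vector }
import ch.ethz.dalab.dissolve.classification.StructSVMModel
import ch.ethz.dalab.dissolve.optimization.DissolveFunctions

case class MCInput(x: DenseVector[Double]) // a d-dimensional feature vector
case class MCLabel(k: Int)                 // a class index in 0 until numClasses

class ToyMulticlass(numClasses: Int, numDims: Int)
    extends DissolveFunctions[MCInput, MCLabel] {

  // Class-stacked joint feature map: \phi(x, y) places x in the block of class y.k
  def featureFn(x: MCInput, y: MCLabel): Vector[Double] = {
    val phi = DenseVector.zeros[Double](numClasses * numDims)
    phi(y.k * numDims until (y.k + 1) * numDims) := x.x
    phi
  }

  // 0/1 loss; 0 whenever the prediction matches the ground truth
  def lossFn(yPredicted: MCLabel, yTruth: MCLabel): Double =
    if (yPredicted.k == yTruth.k) 0.0 else 1.0

  // Maximization oracle: argmax over classes of \Delta(y, yTruth) + <w, \phi(x, y)>.
  // With yTruth == null the loss term is dropped, so the same code doubles as the predictor.
  def oracleFn(model: StructSVMModel[MCInput, MCLabel], x: MCInput, yTruth: MCLabel): MCLabel = {
    val w = model.weights.toDenseVector
    val scores = (0 until numClasses).map { k =>
      val loss = if (yTruth == null) 0.0 else lossFn(MCLabel(k), yTruth)
      val phi = featureFn(x, MCLabel(k)).toDenseVector
      loss + (w dot phi)
    }
    MCLabel(scores.indexOf(scores.max))
  }

  def predictFn(model: StructSVMModel[MCInput, MCLabel], x: MCInput): MCLabel =
    oracleFn(model, x, null)
}
```

The sketch only illustrates how the pieces fit together; the repository's own multiclass example lives in dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/multiclass/COVMulticlass.scala.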
/dissolve-struct-examples/build.sbt:
--------------------------------------------------------------------------------
1 | name := "DissolveStructExample"
2 |
3 | organization := "ch.ethz.dalab"
4 |
5 | version := "0.1-SNAPSHOT"
6 |
7 | scalaVersion := "2.10.4"
8 |
9 | libraryDependencies += "ch.ethz.dalab" %% "dissolvestruct" % "0.1-SNAPSHOT"
10 |
11 | libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.2.4" % "test"
12 |
13 | libraryDependencies += "org.apache.spark" %% "spark-core" % "1.4.1"
14 |
15 | libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.4.1"
16 |
17 | resolvers += "IESL Release" at "http://dev-iesl.cs.umass.edu/nexus/content/groups/public"
18 |
19 | libraryDependencies += "cc.factorie" % "factorie" % "1.0"
20 |
21 | libraryDependencies += "com.github.scopt" %% "scopt" % "3.3.0"
22 |
23 | resolvers += Resolver.sonatypeRepo("public")
24 |
25 | EclipseKeys.createSrc := EclipseCreateSrc.Default + EclipseCreateSrc.Resource
26 |
27 | mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) =>
28 | {
29 | case PathList("javax", "servlet", xs @ _*) => MergeStrategy.first
30 | case PathList(ps @ _*) if ps.last endsWith ".html" => MergeStrategy.first
31 | case "application.conf" => MergeStrategy.concat
32 | case "reference.conf" => MergeStrategy.concat
33 | case "log4j.properties" => MergeStrategy.discard
34 | case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
35 | case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
36 | case _ => MergeStrategy.first
37 | }
38 | }
39 |
40 | test in assembly := {}
41 |
--------------------------------------------------------------------------------
/dissolve-struct-examples/conf/log4j.properties:
--------------------------------------------------------------------------------
1 | # Set everything to be logged to the console
2 | log4j.rootCategory=INFO, console, file
3 |
4 | log4j.appender.console=org.apache.log4j.ConsoleAppender
5 | log4j.appender.console.target=System.err
6 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
7 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
8 | log4j.appender.console.Threshold=WARN
9 |
10 | # Direct log messages to a log file
11 | log4j.appender.file=org.apache.log4j.RollingFileAppender
12 | log4j.appender.file.File=logging.log
13 | log4j.appender.file.MaxFileSize=10MB
14 | log4j.appender.file.MaxBackupIndex=10
15 | log4j.appender.file.layout=org.apache.log4j.PatternLayout
16 | log4j.appender.file.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n
17 | log4j.appender.file.Threshold=WARN
--------------------------------------------------------------------------------
/dissolve-struct-examples/lib/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
3 | !.gitignore
--------------------------------------------------------------------------------
/dissolve-struct-examples/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | logLevel := Level.Warn
2 |
3 | addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.5.0")
4 |
5 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0")
6 |
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/resources/adj.txt:
--------------------------------------------------------------------------------
1 | able
2 | abnormal
3 | absent
4 | absolute
5 | abstract
6 | abundant
7 | academic
8 | acceptable
9 | accessible
10 | accurate
11 | active
12 | acute
13 | addicted
14 | adequate
15 | aesthetic
16 | afraid
17 | aggressive
18 | agile
19 | agricultural
20 | alert
21 | alive
22 | aloof
23 | amber
24 | ambiguous
25 | ambitious
26 | ample
27 | angry
28 | annual
29 | anonymous
30 | applied
31 | appropriate
32 | arbitrary
33 | archaeological
34 | arrogant
35 | artificial
36 | artistic
37 | ashamed
38 | asleep
39 | assertive
40 | astonishing
41 | attractive
42 | automatic
43 | available
44 | awake
45 | aware
46 | awful
47 | awkward
48 | bad
49 | balanced
50 | bald
51 | bare
52 | basic
53 | beautiful
54 | bitter
55 | black
56 | bland
57 | blank
58 | blind
59 | blonde
60 | bloody
61 | bold
62 | brave
63 | broken
64 | brown
65 | bureaucratic
66 | busy
67 | capable
68 | careful
69 | cautious
70 | central
71 | certain
72 | characteristic
73 | charismatic
74 | cheap
75 | cheerful
76 | childish
77 | chronic
78 | civic
79 | civilian
80 | classical
81 | clean
82 | clear
83 | close
84 | closed
85 | cold
86 | color-blind
87 | colourful
88 | comfortable
89 | commercial
90 | common
91 | comparable
92 | compatible
93 | competent
94 | competitive
95 | complete
96 | complex
97 | comprehensive
98 | concrete
99 | confident
100 | conscious
101 | conservative
102 | considerable
103 | consistent
104 | constant
105 | constitutional
106 | constructive
107 | content
108 | continental
109 | continuous
110 | controversial
111 | convenient
112 | conventional
113 | cool
114 | cooperative
115 | corporate
116 | critical
117 | crude
118 | cruel
119 | cultural
120 | curious
121 | current
122 | cute
123 | daily
124 | dangerous
125 | dark
126 | dead
127 | deadly
128 | deaf
129 | decisive
130 | decorative
131 | deep
132 | definite
133 | delicate
134 | democratic
135 | dependent
136 | desirable
137 | different
138 | difficult
139 | digital
140 | diplomatic
141 | direct
142 | dirty
143 | discreet
144 | distant
145 | distinct
146 | domestic
147 | dominant
148 | dramatic
149 | dry
150 | due
151 | dull
152 | dynamic
153 | eager
154 | early
155 | easy
156 | economic
157 | educational
158 | effective
159 | efficient
160 | electronic
161 | elegant
162 | eligible
163 | eloquent
164 | emotional
165 | empirical
166 | empty
167 | encouraging
168 | enjoyable
169 | enthusiastic
170 | environmental
171 | equal
172 | essential
173 | established
174 | eternal
175 | ethical
176 | ethnic
177 | even
178 | exact
179 | excited
180 | exciting
181 | exclusive
182 | exotic
183 | expected
184 | expensive
185 | experienced
186 | experimental
187 | explicit
188 | express
189 | external
190 | extinct
191 | extraordinary
192 | fair
193 | faithful
194 | false
195 | familiar
196 | far
197 | fashionable
198 | fast
199 | fastidious
200 | fat
201 | favorable
202 | federal
203 | feminine
204 | financial
205 | fine
206 | finished
207 | first
208 | firsthand
209 | flat
210 | flawed
211 | flexible
212 | foolish
213 | formal
214 | forward
215 | fragrant
216 | frank
217 | free
218 | frequent
219 | fresh
220 | friendly
221 | frozen
222 | full
223 | full-time
224 | functional
225 | funny
226 | general
227 | generous
228 | genetic
229 | genuine
230 | geological
231 | glad
232 | glorious
233 | good
234 | gradual
235 | grand
236 | graphic
237 | grateful
238 | great
239 | green
240 | gregarious
241 | handy
242 | happy
243 | hard
244 | harmful
245 | harsh
246 | healthy
247 | heavy
248 | helpful
249 | helpless
250 | high
251 | hilarious
252 | historical
253 | holy
254 | homosexual
255 | honest
256 | honorable
257 | horizontal
258 | hostile
259 | hot
260 | huge
261 | human
262 | hungry
263 | ignorant
264 | illegal
265 | immune
266 | imperial
267 | implicit
268 | important
269 | impossible
270 | impressive
271 | inadequate
272 | inappropriate
273 | incapable
274 | incongruous
275 | incredible
276 | independent
277 | indigenous
278 | indirect
279 | indoor
280 | industrial
281 | inevitable
282 | infinite
283 | influential
284 | informal
285 | inner
286 | innocent
287 | insufficient
288 | integrated
289 | intellectual
290 | intense
291 | interactive
292 | interesting
293 | intermediate
294 | internal
295 | international
296 | invisible
297 | irrelevant
298 | jealous
299 | joint
300 | judicial
301 | junior
302 | just
303 | kind
304 | large
305 | last
306 | late
307 | latest
308 | lazy
309 | left
310 | legal
311 | legislative
312 | liberal
313 | light
314 | likely
315 | limited
316 | linear
317 | liquid
318 | literary
319 | live
320 | lively
321 | logical
322 | lonely
323 | long
324 | loose
325 | lost
326 | loud
327 | low
328 | loyal
329 | lucky
330 | magnetic
331 | main
332 | major
333 | manual
334 | marine
335 | married
336 | mathematical
337 | mature
338 | maximum
339 | meaningful
340 | mechanical
341 | medieval
342 | memorable
343 | mental
344 | middle-class
345 | mild
346 | military
347 | minimum
348 | minor
349 | miserable
350 | mobile
351 | modern
352 | modest
353 | molecular
354 | monstrous
355 | monthly
356 | moral
357 | moving
358 | multiple
359 | municipal
360 | musical
361 | mutual
362 | narrow
363 | national
364 | native
365 | necessary
366 | negative
367 | nervous
368 | neutral
369 | new
370 | nice
371 | noble
372 | noisy
373 | normal
374 | notorious
375 | nuclear
376 | obese
377 | objective
378 | obscure
379 | obvious
380 | occupational
381 | odd
382 | offensive
383 | official
384 | old
385 | open
386 | operational
387 | opposed
388 | optimistic
389 | optional
390 | oral
391 | ordinary
392 | organic
393 | original
394 | orthodox
395 | other
396 | outer
397 | outside
398 | painful
399 | parallel
400 | paralyzed
401 | parental
402 | particular
403 | part-time
404 | passionate
405 | passive
406 | past
407 | patient
408 | peaceful
409 | perfect
410 | permanent
411 | persistent
412 | personal
413 | petty
414 | philosophical
415 | physical
416 | plain
417 | pleasant
418 | polite
419 | political
420 | poor
421 | popular
422 | portable
423 | positive
424 | possible
425 | powerful
426 | practical
427 | precise
428 | predictable
429 | pregnant
430 | premature
431 | present
432 | presidential
433 | primary
434 | private
435 | privileged
436 | productive
437 | professional
438 | profound
439 | progressive
440 | prolonged
441 | proper
442 | proportional
443 | proud
444 | provincial
445 | public
446 | pure
447 | qualified
448 | quantitative
449 | quiet
450 | racial
451 | random
452 | rare
453 | rational
454 | raw
455 | ready
456 | real
457 | realistic
458 | reasonable
459 | reckless
460 | regional
461 | regular
462 | related
463 | relative
464 | relevant
465 | reliable
466 | religious
467 | representative
468 | resident
469 | residential
470 | respectable
471 | responsible
472 | restless
473 | restricted
474 | retired
475 | revolutionary
476 | rich
477 | right
478 | romantic
479 | rotten
480 | rough
481 | round
482 | rural
483 | sacred
484 | sad
485 | safe
486 | satisfactory
487 | satisfied
488 | scientific
489 | seasonal
490 | secondary
491 | secular
492 | secure
493 | senior
494 | sensitive
495 | separate
496 | serious
497 | sexual
498 | shallow
499 | sharp
500 | short
501 | shy
502 | sick
503 | similar
504 | single
505 | skilled
506 | slippery
507 | slow
508 | small
509 | smart
510 | smooth
511 | social
512 | socialist
513 | soft
514 | solar
515 | solid
516 | sophisticated
517 | sound
518 | sour
519 | spatial
520 | specified
521 | spontaneous
522 | square
523 | stable
524 | standard
525 | statistical
526 | steady
527 | steep
528 | sticky
529 | still
530 | straight
531 | strange
532 | strategic
533 | strict
534 | strong
535 | structural
536 | stubborn
537 | stunning
538 | stupid
539 | subjective
540 | subsequent
541 | successful
542 | sudden
543 | sufficient
544 | superior
545 | supplementary
546 | surprised
547 | surprising
548 | sweet
549 | sympathetic
550 | systematic
551 | talented
552 | talkative
553 | tall
554 | tasty
555 | technical
556 | temporary
557 | tender
558 | tense
559 | terminal
560 | thick
561 | thin
562 | thirsty
563 | thoughtful
564 | tidy
565 | tight
566 | tired
567 | tolerant
568 | tough
569 | toxic
570 | traditional
571 | transparent
572 | trivial
573 | tropical
574 | true
575 | typical
576 | ugly
577 | ultimate
578 | unanimous
579 | unaware
580 | uncomfortable
581 | uneasy
582 | unemployed
583 | unexpected
584 | unfair
585 | unfortunate
586 | uniform
587 | unique
588 | universal
589 | unlawful
590 | unlike
591 | unlikely
592 | unpleasant
593 | urban
594 | useful
595 | useless
596 | usual
597 | vacant
598 | vague
599 | vain
600 | valid
601 | valuable
602 | varied
603 | verbal
604 | vertical
605 | viable
606 | vicious
607 | vigorous
608 | violent
609 | visible
610 | visual
611 | vocational
612 | voluntary
613 | vulnerable
614 | warm
615 | weak
616 | weekly
617 | welcome
618 | well
619 | wet
620 | white
621 | whole
622 | wild
623 | wise
624 | written
625 | wrong
626 | young
627 |
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/resources/chain_test.csv:
--------------------------------------------------------------------------------
1 | 214,150,602,408,267,440,253,191,67,338,111,364,118,114,10,611,415,342,8,82,411,141,117,102,281,386,544,204,409,42,622,27,594,47,153,272,596,325,606,53,570,344,162,203,341,13,477,502,474,107,146,569,113,311,275,260,14,470,308,432,236,583,232,259,464,418,32,567,453,462,181,592,459,609,626,244,575,241,531,194,263,103,530,410,94,334,268,18,549,139,65,37,235,105,525,358,615,62,161,589,333,121,86,254,424,206,288,230,74,4,492,603,258,573,387,414,202,447,448,185,557,207,97,585,31,55,506,100,476,367,584,234,158,439,467,540,261,483,518,610,271,189,226,331,50,522,197,147,362,149,283,222,345,72,382,12,78,108,501,310,302,402,122,473,23,552,558,237,217,363,163,328,46,346,51,61,239,520,126,19,266,306,240,297,179,106,223,95,457,625,20,400,265,535,456,417,543,209,136,145,24,144,485,542,252,129,227,6,322,93,152,335,243,554,425,290,96,182,154,351,279,172,286,138,354,379,478,127,555,532,164,56,577,284,3,348,388,228,39,463,416,200,446,454,605,143,578,495,488,368,216,445,356,365,250,620,422,572,427,372,187,16,498,273,77,564,623,312,79,291,330,539,370,595,49,270,287,134,303,229,429,514,505,378,87,212,238,534,196,249,529,443,195,507,504,561,30,9,517,607,225,170,7,112,437,513,304,546,99,489,233,277,289,198,70,374,384,460,282,52,71,523,475,391,83,503,190,455,579,321,465,208,135,593,452,43,140,123,614,496,451,441,300,116,280,40,213,180,436,493,177,101,508,524,376,500,481,586,15,613,420,617,242,90,480,22,66,419,218,115,307,34,317,479,245,199,91,574,377,119,98,73,357,430,339,324,110,619,45,571,527,278,148,537,399,315,92,175,512,515,178,431,151,44,393,309,392,255,582,547,167,587,215,526,406,381,327,482,128,450,433,60,556,423,89,76,41,326,292,171,336,600,188,299,25,320,486,497,545,131,156,26,251,624,169,28,359,124,295,301,519,159,125,490,5,219,201,510,298,405,142,347,565,494,63,469,458,360,566,487,269,337,395,349,516,421,444,548,396,428,68,413,247,581,580,468,88,536,183,604,394,366,407,257,314,355,521,616,438,401,155,601,296,21,11,350,404,211,369,559,80,57,84,176,560,2,205,48,157,133,383,426,294,551,33,316,186,120,568,285,332,130,35,192,305,81,385,193,319,329,590,598,618,221,373,313,621,434,412,472,1,293,165,397,361,389,36,390,264,442,375,343,173,323,85,541,109,248,435,449,380,256,553,599,276,491,59,64,563,484,184,168,246,403,38,398,274,220,528,75,340,262,132,318,104,591,461,588,160,576,471,466,29,612,17,511,597,210,166,353,54,608,58,538,509,69,499,174,231,352,224,533,371,550,137,562
2 |
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/resources/imageseg_cattle_test.txt:
--------------------------------------------------------------------------------
1 | 1_1_s.bmp
2 | 1_3_s.bmp
3 | 1_13_s.bmp
4 | 1_30_s.bmp
5 | 1_26_s.bmp
6 | 1_14_s.bmp
7 | 1_16_s.bmp
8 | 1_12_s.bmp
9 | 1_18_s.bmp
10 | 1_9_s.bmp
11 | 1_27_s.bmp
12 | 1_23_s.bmp
13 | 1_24_s.bmp
14 |
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/resources/imageseg_cattle_train.txt:
--------------------------------------------------------------------------------
1 | 1_6_s.bmp
2 | 1_10_s.bmp
3 | 1_4_s.bmp
4 | 1_22_s.bmp
5 | 1_11_s.bmp
6 | 1_21_s.bmp
7 | 1_19_s.bmp
8 | 1_8_s.bmp
9 | 1_7_s.bmp
10 | 1_15_s.bmp
11 | 1_5_s.bmp
12 | 1_20_s.bmp
13 | 1_29_s.bmp
14 | 1_17_s.bmp
15 | 1_2_s.bmp
16 | 1_25_s.bmp
17 | 1_28_s.bmp
18 |
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/resources/imageseg_colormap.txt:
--------------------------------------------------------------------------------
1 | 0 0 0 0 0 void
2 | 1 128 0 0 128 cow
3 | 2 16320 0 64 0 bird
4 | 3 32640 0 128 0 grass
5 | 4 32768 0 128 128 sheep
6 | 5 48960 0 192 0 chair
7 | 6 49088 0 192 128 cat
8 | 7 4161600 64 0 0 mountain
9 | 8 4161728 64 0 128 car
10 | 9 4177920 64 64 0 body
11 | 10 4194240 64 128 0 water
12 | 11 4194368 64 128 128 flower
13 | 12 8323200 128 0 0 building
14 | 13 8323328 128 0 128 horse
15 | 14 8339520 128 64 0 book
16 | 15 8339648 128 64 128 road
17 | 16 8355840 128 128 0 tree
18 | 17 8355968 128 128 128 sky
19 | 18 8372288 128 192 128 dog
20 | 19 12484800 192 0 0 aeroplane
21 | 20 12484928 192 0 128 bicycle
22 | 21 12501120 192 64 0 boat
23 | 22 12517440 192 128 0 face
24 | 23 12517568 192 128 128 sign
25 |
--------------------------------------------------------------------------------
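Note on imageseg_colormap.txt above: each row lists the original MSRC label index, a packed colour value, the R, G and B components, and the class name. For the rows in this file the packed value is consistent with R*255^2 + G*255 + B (e.g. tree: 128*65025 + 128*255 + 0 = 8355840); that packing formula is an inference from the data, not something taken from the repository code. A minimal parsing sketch under that assumption:

    // Hypothetical parser for one colormap row; the packing check below encodes
    // an assumption inferred from the file, not a formula from the library.
    case class ColorMapEntry(label: Int, packed: Int, r: Int, g: Int, b: Int, name: String)

    def parseColorMapLine(line: String): ColorMapEntry = {
      val Array(lab, packed, r, g, b, name) = line.trim.split("\\s+")
      val entry = ColorMapEntry(lab.toInt, packed.toInt, r.toInt, g.toInt, b.toInt, name)
      assert(entry.packed == entry.r * 255 * 255 + entry.g * 255 + entry.b) // holds for every row above
      entry
    }

    // e.g. parseColorMapLine("16 8355840 128 128 0 tree")
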
/dissolve-struct-examples/src/main/resources/imageseg_lab_freq.txt:
--------------------------------------------------------------------------------
1 | 0,0.269829
2 | 1,0.021016
3 | 2,0.009487
4 | 3,0.142275
5 | 4,0.016398
6 | 5,0.013469
7 | 6,0.011448
8 | 7,0.007311
9 | 8,0.023944
10 | 9,0.015565
11 | 10,0.055292
12 | 11,0.022215
13 | 12,0.076981
14 | 13,0.000735
15 | 14,0.038273
16 | 15,0.061869
17 | 16,0.064517
18 | 17,0.069084
19 | 18,0.012804
20 | 19,0.010377
21 | 20,0.020019
22 | 21,0.006815
23 | 22,0.012681
24 | 23,0.017599
25 |
--------------------------------------------------------------------------------
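The imageseg_lab_freq.txt file above appears to give, for each of the 24 MSRC label indices, its empirical frequency; the values sum to roughly 1.0 (1.000003 for the rows above). A quick sanity check, assuming the file sits in the current directory (only the file name is taken from the listing, the rest is a sketch):

    // Sketch: verify the label frequencies form a distribution (sum close to 1).
    import scala.io.Source

    object LabelFreqCheck {
      def main(args: Array[String]): Unit = {
        val freqs = Source.fromFile("imageseg_lab_freq.txt") // local path is an assumption
          .getLines()
          .filter(_.nonEmpty)
          .map(_.split(",")(1).toDouble)
          .toSeq
        println(f"sum of label frequencies = ${freqs.sum}%.6f") // ~1.000003 for the rows above
      }
    }
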
/dissolve-struct-examples/src/main/resources/imageseg_label_color_map.txt:
--------------------------------------------------------------------------------
1 | 12 8323200 128 0 0 building 0
2 | 3 32640 0 128 0 grass 1
3 | 16 8355840 128 128 0 tree 2
4 | 1 128 0 0 128 cow 3
5 | 4 32768 0 128 128 sheep 4
6 | 17 8355968 128 128 128 sky 5
7 | 19 12484800 192 0 0 aeroplane 6
8 | 10 4194240 64 128 0 water 7
9 | 22 12517440 192 128 0 face 8
10 | 8 4161728 64 0 128 car 9
11 | 20 12484928 192 0 128 bicycle 10
12 | 11 4194368 64 128 128 flower 11
13 | 23 12517568 192 128 128 sign 12
14 | 2 16320 0 64 0 bird 13
15 | 14 8339520 128 64 0 book 14
16 | 5 48960 0 192 0 chair 15
17 | 15 8339648 128 64 128 road 16
18 | 6 49088 0 192 128 cat 17
19 | 18 8372288 128 192 128 dog 18
20 | 9 4177920 64 64 0 body 19
21 | 21 12501120 192 64 0 boat 20
22 | 0 0 0 0 0 void 21
23 |
--------------------------------------------------------------------------------
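The last column of imageseg_label_color_map.txt above appears to remap the original MSRC label indices (first column) to the compact 0-21 class indices used by ImageSeg.scala later in this listing (NUM_CLASSES = 22, with void mapped to 21, matching BACKGROUND_CLASS). A small sketch of reading that mapping, assuming the seven-column layout shown:

    // Sketch: original MSRC label -> compact class index, inferred from the file layout above.
    def origToCompact(lines: Seq[String]): Map[Int, Int] =
      lines.filter(_.trim.nonEmpty).map { line =>
        val cols = line.trim.split("\\s+")
        cols(0).toInt -> cols.last.toInt // e.g. "17 ... sky 5" gives 17 -> 5
      }.toMap
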
/dissolve-struct-examples/src/main/resources/imageseg_test.txt:
--------------------------------------------------------------------------------
1 | 18_5_s.bmp
2 | 18_21_s.bmp
3 | 15_22_s.bmp
4 | 15_1_s.bmp
5 | 1_1_s.bmp
6 | 20_9_s.bmp
7 | 2_8_s.bmp
8 | 15_19_s.bmp
9 | 2_10_s.bmp
10 | 18_26_s.bmp
11 | 15_5_s.bmp
12 | 5_11_s.bmp
13 | 18_30_s.bmp
14 | 1_3_s.bmp
15 | 18_29_s.bmp
16 | 1_13_s.bmp
17 | 15_13_s.bmp
18 | 13_11_s.bmp
19 | 15_2_s.bmp
20 | 20_17_s.bmp
21 | 11_9_s.bmp
22 | 1_30_s.bmp
23 | 13_7_s.bmp
24 | 16_12_s.bmp
25 | 9_3_s.bmp
26 | 4_28_s.bmp
27 | 12_19_s.bmp
28 | 13_8_s.bmp
29 | 3_22_s.bmp
30 | 11_5_s.bmp
31 | 16_24_s.bmp
32 | 3_20_s.bmp
33 | 4_27_s.bmp
34 | 17_15_s.bmp
35 | 13_22_s.bmp
36 | 12_26_s.bmp
37 | 5_3_s.bmp
38 | 18_23_s.bmp
39 | 16_4_s.bmp
40 | 18_20_s.bmp
41 | 13_25_s.bmp
42 | 11_15_s.bmp
43 | 5_13_s.bmp
44 | 19_7_s.bmp
45 | 12_7_s.bmp
46 | 2_25_s.bmp
47 | 12_8_s.bmp
48 | 20_2_s.bmp
49 | 15_3_s.bmp
50 | 7_26_s.bmp
51 | 9_20_s.bmp
52 | 1_26_s.bmp
53 | 10_2_s.bmp
54 | 13_20_s.bmp
55 | 14_21_s.bmp
56 | 10_18_s.bmp
57 | 5_2_s.bmp
58 | 14_13_s.bmp
59 | 4_26_s.bmp
60 | 15_7_s.bmp
61 | 17_19_s.bmp
62 | 18_24_s.bmp
63 | 9_24_s.bmp
64 | 10_1_s.bmp
65 | 1_14_s.bmp
66 | 10_22_s.bmp
67 | 10_25_s.bmp
68 | 12_33_s.bmp
69 | 3_16_s.bmp
70 | 4_16_s.bmp
71 | 9_7_s.bmp
72 | 4_1_s.bmp
73 | 12_9_s.bmp
74 | 4_13_s.bmp
75 | 19_3_s.bmp
76 | 16_16_s.bmp
77 | 17_5_s.bmp
78 | 17_30_s.bmp
79 | 20_1_s.bmp
80 | 9_1_s.bmp
81 | 20_16_s.bmp
82 | 17_8_s.bmp
83 | 8_19_s.bmp
84 | 9_9_s.bmp
85 | 8_6_s.bmp
86 | 19_12_s.bmp
87 | 3_5_s.bmp
88 | 8_18_s.bmp
89 | 18_7_s.bmp
90 | 16_5_s.bmp
91 | 10_29_s.bmp
92 | 7_29_s.bmp
93 | 20_20_s.bmp
94 | 8_22_s.bmp
95 | 2_7_s.bmp
96 | 13_9_s.bmp
97 | 16_10_s.bmp
98 | 2_14_s.bmp
99 | 9_13_s.bmp
100 | 14_6_s.bmp
101 | 5_14_s.bmp
102 | 18_12_s.bmp
103 | 13_1_s.bmp
104 | 6_16_s.bmp
105 | 7_5_s.bmp
106 | 3_9_s.bmp
107 | 10_12_s.bmp
108 | 3_12_s.bmp
109 | 3_8_s.bmp
110 | 4_24_s.bmp
111 | 2_29_s.bmp
112 | 13_4_s.bmp
113 | 11_13_s.bmp
114 | 20_8_s.bmp
115 | 10_4_s.bmp
116 | 10_31_s.bmp
117 | 3_26_s.bmp
118 | 3_7_s.bmp
119 | 19_23_s.bmp
120 | 8_27_s.bmp
121 | 6_22_s.bmp
122 | 14_15_s.bmp
123 | 7_12_s.bmp
124 | 7_21_s.bmp
125 | 2_18_s.bmp
126 | 11_4_s.bmp
127 | 7_14_s.bmp
128 | 8_29_s.bmp
129 | 4_30_s.bmp
130 | 14_17_s.bmp
131 | 8_10_s.bmp
132 | 1_16_s.bmp
133 | 2_26_s.bmp
134 | 14_9_s.bmp
135 | 18_28_s.bmp
136 | 5_23_s.bmp
137 | 14_10_s.bmp
138 | 19_30_s.bmp
139 | 12_31_s.bmp
140 | 16_14_s.bmp
141 | 12_14_s.bmp
142 | 14_18_s.bmp
143 | 11_14_s.bmp
144 | 18_3_s.bmp
145 | 16_13_s.bmp
146 | 3_30_s.bmp
147 | 6_5_s.bmp
148 | 19_24_s.bmp
149 | 17_17_s.bmp
150 | 19_15_s.bmp
151 | 11_25_s.bmp
152 | 2_22_s.bmp
153 | 2_30_s.bmp
154 | 17_28_s.bmp
155 | 6_13_s.bmp
156 | 9_2_s.bmp
157 | 11_18_s.bmp
158 | 19_16_s.bmp
159 | 8_17_s.bmp
160 | 2_28_s.bmp
161 | 16_23_s.bmp
162 | 1_12_s.bmp
163 | 4_20_s.bmp
164 | 7_15_s.bmp
165 | 11_8_s.bmp
166 | 11_6_s.bmp
167 | 3_18_s.bmp
168 | 1_18_s.bmp
169 | 1_9_s.bmp
170 | 19_29_s.bmp
171 | 12_2_s.bmp
172 | 17_26_s.bmp
173 | 16_1_s.bmp
174 | 19_8_s.bmp
175 | 13_2_s.bmp
176 | 16_6_s.bmp
177 | 8_24_s.bmp
178 | 5_28_s.bmp
179 | 15_6_s.bmp
180 | 13_17_s.bmp
181 | 17_23_s.bmp
182 | 9_19_s.bmp
183 | 19_13_s.bmp
184 | 9_18_s.bmp
185 | 14_20_s.bmp
186 | 17_4_s.bmp
187 | 6_11_s.bmp
188 | 7_25_s.bmp
189 | 2_16_s.bmp
190 | 12_16_s.bmp
191 | 12_23_s.bmp
192 | 6_24_s.bmp
193 | 12_27_s.bmp
194 | 16_8_s.bmp
195 | 10_17_s.bmp
196 | 5_17_s.bmp
197 | 12_12_s.bmp
198 | 20_14_s.bmp
199 | 12_30_s.bmp
200 | 9_17_s.bmp
201 | 17_25_s.bmp
202 | 14_16_s.bmp
203 | 4_15_s.bmp
204 | 14_28_s.bmp
205 | 7_10_s.bmp
206 | 5_19_s.bmp
207 | 10_3_s.bmp
208 | 5_27_s.bmp
209 | 8_11_s.bmp
210 | 5_4_s.bmp
211 | 5_10_s.bmp
212 | 10_28_s.bmp
213 | 6_7_s.bmp
214 | 4_23_s.bmp
215 | 11_3_s.bmp
216 | 10_7_s.bmp
217 | 13_21_s.bmp
218 | 7_7_s.bmp
219 | 16_27_s.bmp
220 | 13_27_s.bmp
221 | 6_25_s.bmp
222 | 14_14_s.bmp
223 | 17_6_s.bmp
224 | 12_13_s.bmp
225 | 18_6_s.bmp
226 | 19_28_s.bmp
227 | 14_5_s.bmp
228 | 8_2_s.bmp
229 | 3_29_s.bmp
230 | 6_12_s.bmp
231 | 4_3_s.bmp
232 | 1_27_s.bmp
233 | 19_20_s.bmp
234 | 8_4_s.bmp
235 | 7_22_s.bmp
236 | 4_10_s.bmp
237 | 6_2_s.bmp
238 | 6_15_s.bmp
239 | 8_8_s.bmp
240 | 3_13_s.bmp
241 | 9_5_s.bmp
242 | 9_26_s.bmp
243 | 5_12_s.bmp
244 | 1_23_s.bmp
245 | 15_23_s.bmp
246 | 1_24_s.bmp
247 | 20_6_s.bmp
248 | 17_12_s.bmp
249 | 6_10_s.bmp
250 | 2_17_s.bmp
251 | 11_7_s.bmp
252 | 11_19_s.bmp
253 | 7_23_s.bmp
254 | 10_16_s.bmp
255 | 7_3_s.bmp
256 | 6_3_s.bmp
257 |
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/resources/imageseg_train.txt:
--------------------------------------------------------------------------------
1 | 11_16_s.bmp
2 | 11_27_s.bmp
3 | 11_12_s.bmp
4 | 12_34_s.bmp
5 | 11_2_s.bmp
6 | 12_28_s.bmp
7 | 16_11_s.bmp
8 | 15_24_s.bmp
9 | 8_25_s.bmp
10 | 2_15_s.bmp
11 | 11_22_s.bmp
12 | 18_9_s.bmp
13 | 9_23_s.bmp
14 | 19_2_s.bmp
15 | 2_12_s.bmp
16 | 3_11_s.bmp
17 | 20_11_s.bmp
18 | 5_9_s.bmp
19 | 14_22_s.bmp
20 | 3_24_s.bmp
21 | 4_6_s.bmp
22 | 9_10_s.bmp
23 | 11_11_s.bmp
24 | 2_20_s.bmp
25 | 1_6_s.bmp
26 | 16_30_s.bmp
27 | 12_1_s.bmp
28 | 9_22_s.bmp
29 | 15_18_s.bmp
30 | 18_27_s.bmp
31 | 17_16_s.bmp
32 | 10_8_s.bmp
33 | 1_10_s.bmp
34 | 16_17_s.bmp
35 | 17_13_s.bmp
36 | 13_3_s.bmp
37 | 1_4_s.bmp
38 | 5_29_s.bmp
39 | 1_22_s.bmp
40 | 20_21_s.bmp
41 | 2_19_s.bmp
42 | 1_11_s.bmp
43 | 18_13_s.bmp
44 | 3_15_s.bmp
45 | 5_24_s.bmp
46 | 4_18_s.bmp
47 | 18_11_s.bmp
48 | 13_26_s.bmp
49 | 9_14_s.bmp
50 | 19_27_s.bmp
51 | 12_10_s.bmp
52 | 18_8_s.bmp
53 | 5_6_s.bmp
54 | 12_11_s.bmp
55 | 18_17_s.bmp
56 | 4_8_s.bmp
57 | 10_32_s.bmp
58 | 18_2_s.bmp
59 | 17_21_s.bmp
60 | 14_11_s.bmp
61 | 3_19_s.bmp
62 | 3_4_s.bmp
63 | 8_14_s.bmp
64 | 11_30_s.bmp
65 | 11_29_s.bmp
66 | 9_27_s.bmp
67 | 17_11_s.bmp
68 | 13_13_s.bmp
69 | 4_21_s.bmp
70 | 19_14_s.bmp
71 | 11_10_s.bmp
72 | 18_15_s.bmp
73 | 10_21_s.bmp
74 | 14_3_s.bmp
75 | 1_21_s.bmp
76 | 7_9_s.bmp
77 | 16_28_s.bmp
78 | 15_10_s.bmp
79 | 15_21_s.bmp
80 | 13_16_s.bmp
81 | 2_9_s.bmp
82 | 13_24_s.bmp
83 | 14_19_s.bmp
84 | 15_17_s.bmp
85 | 19_26_s.bmp
86 | 15_20_s.bmp
87 | 20_18_s.bmp
88 | 18_22_s.bmp
89 | 10_15_s.bmp
90 | 2_1_s.bmp
91 | 9_30_s.bmp
92 | 6_14_s.bmp
93 | 12_15_s.bmp
94 | 14_2_s.bmp
95 | 15_4_s.bmp
96 | 12_6_s.bmp
97 | 15_14_s.bmp
98 | 16_19_s.bmp
99 | 8_23_s.bmp
100 | 11_24_s.bmp
101 | 14_1_s.bmp
102 | 9_21_s.bmp
103 | 8_13_s.bmp
104 | 19_10_s.bmp
105 | 11_20_s.bmp
106 | 8_30_s.bmp
107 | 13_29_s.bmp
108 | 15_12_s.bmp
109 | 7_6_s.bmp
110 | 14_8_s.bmp
111 | 13_14_s.bmp
112 | 12_29_s.bmp
113 | 19_18_s.bmp
114 | 4_25_s.bmp
115 | 15_11_s.bmp
116 | 17_9_s.bmp
117 | 7_11_s.bmp
118 | 18_16_s.bmp
119 | 20_12_s.bmp
120 | 3_25_s.bmp
121 | 14_12_s.bmp
122 | 2_5_s.bmp
123 | 11_26_s.bmp
124 | 20_3_s.bmp
125 | 9_16_s.bmp
126 | 6_27_s.bmp
127 | 10_6_s.bmp
128 | 15_16_s.bmp
129 | 17_2_s.bmp
130 | 8_5_s.bmp
131 | 16_25_s.bmp
132 | 18_18_s.bmp
133 | 18_4_s.bmp
134 | 10_5_s.bmp
135 | 12_32_s.bmp
136 | 17_10_s.bmp
137 | 18_25_s.bmp
138 | 19_5_s.bmp
139 | 8_12_s.bmp
140 | 19_25_s.bmp
141 | 3_3_s.bmp
142 | 13_12_s.bmp
143 | 3_10_s.bmp
144 | 18_19_s.bmp
145 | 6_17_s.bmp
146 | 9_29_s.bmp
147 | 12_24_s.bmp
148 | 19_11_s.bmp
149 | 12_21_s.bmp
150 | 17_29_s.bmp
151 | 6_9_s.bmp
152 | 10_27_s.bmp
153 | 4_29_s.bmp
154 | 17_1_s.bmp
155 | 14_25_s.bmp
156 | 6_6_s.bmp
157 | 1_19_s.bmp
158 | 10_10_s.bmp
159 | 19_21_s.bmp
160 | 7_4_s.bmp
161 | 3_21_s.bmp
162 | 7_20_s.bmp
163 | 1_8_s.bmp
164 | 11_1_s.bmp
165 | 5_26_s.bmp
166 | 12_18_s.bmp
167 | 1_7_s.bmp
168 | 12_4_s.bmp
169 | 7_16_s.bmp
170 | 2_13_s.bmp
171 | 20_10_s.bmp
172 | 14_30_s.bmp
173 | 19_19_s.bmp
174 | 5_15_s.bmp
175 | 17_14_s.bmp
176 | 17_18_s.bmp
177 | 5_16_s.bmp
178 | 20_5_s.bmp
179 | 14_23_s.bmp
180 | 11_23_s.bmp
181 | 10_20_s.bmp
182 | 13_5_s.bmp
183 | 5_1_s.bmp
184 | 5_25_s.bmp
185 | 17_20_s.bmp
186 | 17_7_s.bmp
187 | 5_8_s.bmp
188 | 9_28_s.bmp
189 | 6_18_s.bmp
190 | 5_21_s.bmp
191 | 2_21_s.bmp
192 | 7_1_s.bmp
193 | 3_27_s.bmp
194 | 20_7_s.bmp
195 | 12_17_s.bmp
196 | 6_29_s.bmp
197 | 1_15_s.bmp
198 | 5_20_s.bmp
199 | 13_15_s.bmp
200 | 16_21_s.bmp
201 | 1_5_s.bmp
202 | 6_23_s.bmp
203 | 10_24_s.bmp
204 | 7_19_s.bmp
205 | 8_28_s.bmp
206 | 2_24_s.bmp
207 | 4_4_s.bmp
208 | 2_4_s.bmp
209 | 16_26_s.bmp
210 | 1_20_s.bmp
211 | 5_18_s.bmp
212 | 20_13_s.bmp
213 | 5_22_s.bmp
214 | 12_3_s.bmp
215 | 13_18_s.bmp
216 | 16_7_s.bmp
217 | 7_8_s.bmp
218 | 2_6_s.bmp
219 | 2_23_s.bmp
220 | 1_29_s.bmp
221 | 6_1_s.bmp
222 | 1_17_s.bmp
223 | 20_4_s.bmp
224 | 16_22_s.bmp
225 | 10_11_s.bmp
226 | 8_26_s.bmp
227 | 7_27_s.bmp
228 | 6_8_s.bmp
229 | 12_22_s.bmp
230 | 3_17_s.bmp
231 | 2_3_s.bmp
232 | 4_5_s.bmp
233 | 4_14_s.bmp
234 | 4_12_s.bmp
235 | 16_18_s.bmp
236 | 4_7_s.bmp
237 | 9_4_s.bmp
238 | 19_17_s.bmp
239 | 3_2_s.bmp
240 | 8_21_s.bmp
241 | 7_28_s.bmp
242 | 14_27_s.bmp
243 | 9_12_s.bmp
244 | 17_22_s.bmp
245 | 13_19_s.bmp
246 | 7_18_s.bmp
247 | 14_29_s.bmp
248 | 19_6_s.bmp
249 | 4_9_s.bmp
250 | 4_2_s.bmp
251 | 10_9_s.bmp
252 | 14_26_s.bmp
253 | 10_13_s.bmp
254 | 13_28_s.bmp
255 | 8_3_s.bmp
256 | 4_17_s.bmp
257 | 16_3_s.bmp
258 | 16_15_s.bmp
259 | 3_1_s.bmp
260 | 13_10_s.bmp
261 | 16_9_s.bmp
262 | 10_30_s.bmp
263 | 19_4_s.bmp
264 | 7_17_s.bmp
265 | 6_26_s.bmp
266 | 9_11_s.bmp
267 | 9_6_s.bmp
268 | 6_4_s.bmp
269 | 8_1_s.bmp
270 | 7_2_s.bmp
271 | 8_9_s.bmp
272 | 10_14_s.bmp
273 | 6_28_s.bmp
274 | 3_14_s.bmp
275 | 6_20_s.bmp
276 | 8_16_s.bmp
277 | 18_14_s.bmp
278 | 15_8_s.bmp
279 | 18_1_s.bmp
280 | 18_10_s.bmp
281 | 15_9_s.bmp
282 | 11_21_s.bmp
283 | 15_15_s.bmp
284 | 1_2_s.bmp
285 | 5_5_s.bmp
286 | 11_28_s.bmp
287 | 1_25_s.bmp
288 | 1_28_s.bmp
289 | 5_30_s.bmp
290 | 2_11_s.bmp
291 | 5_7_s.bmp
292 | 2_27_s.bmp
293 | 20_19_s.bmp
294 | 20_15_s.bmp
295 | 17_24_s.bmp
296 | 2_2_s.bmp
297 | 12_20_s.bmp
298 | 17_27_s.bmp
299 | 13_30_s.bmp
300 | 13_6_s.bmp
301 | 4_22_s.bmp
302 | 14_7_s.bmp
303 | 12_25_s.bmp
304 | 16_2_s.bmp
305 | 9_15_s.bmp
306 | 10_23_s.bmp
307 | 11_17_s.bmp
308 | 13_23_s.bmp
309 | 17_3_s.bmp
310 | 16_20_s.bmp
311 | 7_24_s.bmp
312 | 3_23_s.bmp
313 | 3_28_s.bmp
314 | 10_19_s.bmp
315 | 16_29_s.bmp
316 | 14_24_s.bmp
317 | 9_8_s.bmp
318 | 12_5_s.bmp
319 | 9_25_s.bmp
320 | 4_19_s.bmp
321 | 19_9_s.bmp
322 | 4_11_s.bmp
323 | 19_1_s.bmp
324 | 7_30_s.bmp
325 | 3_6_s.bmp
326 | 10_26_s.bmp
327 | 19_22_s.bmp
328 | 14_4_s.bmp
329 | 7_13_s.bmp
330 | 8_20_s.bmp
331 | 8_7_s.bmp
332 | 6_19_s.bmp
333 | 8_15_s.bmp
334 | 6_30_s.bmp
335 | 6_21_s.bmp
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/binaryclassification/AdultBinary.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.examples.binaryclassification
2 |
3 | import ch.ethz.dalab.dissolve.regression.LabeledObject
4 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
5 | import ch.ethz.dalab.dissolve.classification.StructSVMWithDBCFW
6 | import ch.ethz.dalab.dissolve.classification.BinarySVMWithDBCFW
7 | import ch.ethz.dalab.dissolve.classification.StructSVMModel
8 | import ch.ethz.dalab.dissolve.optimization.SolverUtils
9 | import org.apache.spark.mllib.util.MLUtils
10 | import org.apache.spark.SparkConf
11 | import org.apache.spark.SparkContext
12 | import org.apache.spark.rdd.RDD
13 | import org.apache.spark.mllib.regression.LabeledPoint
14 | import org.apache.spark.mllib.classification.SVMWithSGD
15 | import breeze.linalg._
16 | import breeze.numerics.abs
17 |
18 | /**
19 | *
20 | * Dataset: Adult (http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#a1a)
21 | * Type: Binary
22 | *
23 | * Created by tribhu on 12/11/14.
24 | */
25 | object AdultBinary {
26 |
27 | /**
28 | * DBCFW Implementation
29 | */
30 | def dbcfwAdult() {
31 | val a1aPath = "../data/generated/a1a"
32 |
33 | // Fix seed for reproducibility
34 | util.Random.setSeed(1)
35 |
36 | val conf = new SparkConf().setAppName("Adult-example").setMaster("local")
37 | val sc = new SparkContext(conf)
38 | sc.setCheckpointDir("checkpoint-files")
39 |
40 | val solverOptions: SolverOptions[Vector[Double], Double] = new SolverOptions()
41 |
42 | solverOptions.roundLimit = 20 // After this many passes, each slice of the RDD returns a trained model
43 | solverOptions.debug = true
44 | solverOptions.lambda = 0.01
45 | solverOptions.doWeightedAveraging = false
46 | solverOptions.doLineSearch = true
47 | solverOptions.debug = false
48 |
49 | solverOptions.sampleWithReplacement = false
50 |
51 | solverOptions.enableManualPartitionSize = true
52 | solverOptions.NUM_PART = 1
53 |
54 | solverOptions.sample = "frac"
55 | solverOptions.sampleFrac = 0.5
56 |
57 | solverOptions.enableOracleCache = false
58 |
59 | solverOptions.debugInfoPath = "../debug/debugInfo-a1a-%d.csv".format(System.currentTimeMillis())
60 |
61 |
62 | val data = MLUtils.loadLibSVMFile(sc, a1aPath)
63 |
64 | // Split data into training and test set
65 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L)
66 | val training = splits(0)
67 | val test = splits(1)
68 |
69 | val objectifiedTraining: RDD[LabeledObject[Vector[Double], Double]] =
70 | training.map {
71 | case x: LabeledPoint =>
72 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray)) // Is the asInstanceOf required?
73 | }
74 |
75 | val objectifiedTest: RDD[LabeledObject[Vector[Double], Double]] =
76 | test.map {
77 | case x: LabeledPoint =>
78 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray)) // Is the asInstanceOf required?
79 | }
80 |
81 | solverOptions.testDataRDD = Some(objectifiedTest)
82 | val model = BinarySVMWithDBCFW.train(training, solverOptions)
83 |
84 | // Training Errors
85 | val trueTrainingPredictions =
86 | objectifiedTraining.map {
87 | case x: LabeledObject[Vector[Double], Double] =>
88 | val prediction = model.predict(x.pattern)
89 | if (prediction == x.label)
90 | 1
91 | else
92 | 0
93 | }.fold(0)((acc, ele) => acc + ele)
94 |
95 | println("Accuracy on Training set = %d/%d = %.4f".format(trueTrainingPredictions,
96 | objectifiedTraining.count(),
97 | (trueTrainingPredictions.toDouble / objectifiedTraining.count().toDouble) * 100))
98 |
99 | // Test Errors
100 | val trueTestPredictions =
101 | objectifiedTest.map {
102 | case x: LabeledObject[Vector[Double], Double] =>
103 | val prediction = model.predict(x.pattern)
104 | if (prediction == x.label)
105 | 1
106 | else
107 | 0
108 | }.fold(0)((acc, ele) => acc + ele)
109 |
110 | println("Accuracy on Test set = %d/%d = %.4f".format(trueTestPredictions,
111 | objectifiedTest.count(),
112 | (trueTestPredictions.toDouble / objectifiedTest.count().toDouble) * 100))
113 | }
114 |
115 | /**
116 | * MLLib's SVMWithSGD implementation
117 | */
118 | def mllibAdult() {
119 |
120 | val conf = new SparkConf().setAppName("Adult-example").setMaster("local")
121 | val sc = new SparkContext(conf)
122 | sc.setCheckpointDir("checkpoint-files")
123 |
124 | val data = MLUtils.loadLibSVMFile(sc, "../data/a1a_mllib.txt")
125 |
126 | // Split data into training and test set
127 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L)
128 | val training = splits(0)
129 | val test = splits(1)
130 |
131 | // Run training algorithm to build the model
132 | val numIterations = 100
133 | val model = SVMWithSGD.train(training, numIterations)
134 |
135 | val trainAcc = training.map { point =>
136 | val score = model.predict(point.features)
137 | score == point.label
138 | }.collect().toList.count(_ == true).toDouble / training.count().toDouble
139 |
140 | val testAcc = test.map { point =>
141 | val score = model.predict(point.features)
142 | score == point.label
143 | }.collect().toList.count(_ == true).toDouble / test.count().toDouble
144 |
145 | println("Training accuracy = " + trainError)
146 | println("Test accuracy = " + testError)
147 |
148 | }
149 |
150 | def main(args: Array[String]): Unit = {
151 | // mllibAdult()
152 |
153 | dbcfwAdult()
154 | }
155 |
156 | }
157 |
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/binaryclassification/BinaryClassificationDemo.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.examples.binaryclassification
2 |
3 | import ch.ethz.dalab.dissolve.regression.LabeledObject
4 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
5 | import ch.ethz.dalab.dissolve.classification.BinarySVMWithDBCFW
6 | import ch.ethz.dalab.dissolve.classification.StructSVMModel
7 | import ch.ethz.dalab.dissolve.optimization.SolverUtils
8 | import org.apache.spark.mllib.util.MLUtils
9 | import org.apache.spark.SparkConf
10 | import org.apache.spark.SparkContext
11 | import org.apache.spark.rdd.RDD
12 | import org.apache.spark.mllib.regression.LabeledPoint
13 | import org.apache.spark.mllib.classification.SVMWithSGD
14 | import org.apache.log4j.PropertyConfigurator
15 | import org.apache.spark.mllib.linalg.Vectors
16 | import breeze.linalg._
17 |
18 | /**
19 | * Runs the binary classification demo.
20 | *
21 | * Created by tribhu on 12/11/14.
22 | */
23 | object BinaryClassificationDemo {
24 |
25 | /**
26 | * Training Implementation
27 | */
28 | def dissolveTrain(trainData: RDD[LabeledPoint], testData: RDD[LabeledPoint]) {
29 |
30 | val solverOptions: SolverOptions[Vector[Double], Double] = new SolverOptions()
31 |
32 | solverOptions.roundLimit = 20 // After this many passes, each slice of the RDD returns a trained model
33 | solverOptions.debug = true
34 | solverOptions.lambda = 0.01
35 | solverOptions.doWeightedAveraging = false
36 | solverOptions.doLineSearch = true
37 | solverOptions.debug = false
38 |
39 | solverOptions.sampleWithReplacement = false
40 |
41 | solverOptions.enableManualPartitionSize = true
42 | solverOptions.NUM_PART = 2
43 |
44 | solverOptions.sample = "frac"
45 | solverOptions.sampleFrac = 0.5
46 |
47 | solverOptions.enableOracleCache = false
48 |
49 | solverOptions.debugInfoPath = "../debug/debugInfo-%d.csv".format(System.currentTimeMillis())
50 |
51 | val trainDataConverted: RDD[LabeledObject[Vector[Double], Double]] =
52 | trainData.map {
53 | case x: LabeledPoint =>
54 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray))
55 | }
56 |
57 | val testDataConverted: RDD[LabeledObject[Vector[Double], Double]] =
58 | testData.map {
59 | case x: LabeledPoint =>
60 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray))
61 | }
62 |
63 | solverOptions.testDataRDD = Some(testDataConverted)
64 | val model = BinarySVMWithDBCFW.train(trainData, solverOptions)
65 |
66 | // Training Errors
67 | val trueTrainingPredictions =
68 | trainDataConverted.map {
69 | case x: LabeledObject[Vector[Double], Double] =>
70 | val prediction = model.predict(x.pattern)
71 | if (prediction == x.label)
72 | 1
73 | else
74 | 0
75 | }.fold(0)((acc, ele) => acc + ele)
76 |
77 | println("Accuracy on training set = %d/%d = %.4f".format(trueTrainingPredictions,
78 | trainDataConverted.count(),
79 | (trueTrainingPredictions.toDouble / trainDataConverted.count().toDouble) * 100))
80 |
81 | // Test Errors
82 | val trueTestPredictions =
83 | testDataConverted.map {
84 | case x: LabeledObject[Vector[Double], Double] =>
85 | val prediction = model.predict(x.pattern)
86 | if (prediction == x.label)
87 | 1
88 | else
89 | 0
90 | }.fold(0)((acc, ele) => acc + ele)
91 |
92 | println("Accuracy on test set = %d/%d = %.4f".format(trueTestPredictions,
93 | testDataConverted.count(),
94 | (trueTestPredictions.toDouble / testDataConverted.count().toDouble) * 100))
95 | }
96 |
97 | /**
98 | * MLLib's SVMWithSGD implementation
99 | */
100 | def mllibTrain(trainData: RDD[LabeledPoint], testData: RDD[LabeledPoint]) {
101 | println("running MLlib's standard gradient descent solver")
102 |
103 | // labels are assumed to be 0,1 for MLlib
104 | val trainDataConverted: RDD[LabeledPoint] =
105 | trainData.map {
106 | case x: LabeledPoint =>
107 | new LabeledPoint(if (x.label > 0) 1 else 0, x.features)
108 | }
109 |
110 | val testDataConverted: RDD[LabeledPoint] =
111 | testData.map {
112 | case x: LabeledPoint =>
113 | new LabeledPoint(if (x.label > 0) 1 else 0, x.features)
114 | }
115 |
116 | // Run training algorithm to build the model
117 | val model = SVMWithSGD.train(trainDataConverted, numIterations = 200, stepSize = 1.0, regParam = 0.01)
118 |
119 | // report accuracy on train and test
120 | val trainAcc = trainDataConverted.map { point =>
121 | val score = model.predict(point.features)
122 | score == point.label
123 | }.collect().toList.count(_ == true).toDouble / trainDataConverted.count().toDouble
124 |
125 | val testAcc = testDataConverted.map { point =>
126 | val score = model.predict(point.features)
127 | score == point.label
128 | }.collect().toList.count(_ == true).toDouble / testDataConverted.count().toDouble
129 |
130 | println("MLlib Training accuracy = " + trainAcc)
131 | println("MLlib Test accuracy = " + testAcc)
132 |
133 | }
134 |
135 | def main(args: Array[String]): Unit = {
136 | PropertyConfigurator.configure("conf/log4j.properties")
137 |
138 | val dataDir = "../data/generated"
139 |
140 | // Dataset: Adult (http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#a1a)
141 | val datasetFilename = "a1a"
142 |
143 | val conf = new SparkConf().setAppName("BinaryClassificationDemo").setMaster("local")
144 | val sc = new SparkContext(conf)
145 | sc.setCheckpointDir("checkpoint-files")
146 |
147 | // load dataset
148 | val data = MLUtils.loadLibSVMFile(sc, dataDir+"/"+datasetFilename)
149 |
150 | // Split data into training and test set
151 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L)
152 | val trainingData = splits(0)
153 | val testData = splits(1)
154 |
155 | // Fix seed for reproducibility
156 | util.Random.setSeed(1)
157 |
158 |
159 | mllibTrain(trainingData, testData)
160 |
161 | dissolveTrain(trainingData, testData)
162 | }
163 |
164 | }
165 |
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/binaryclassification/COVBinary.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.examples.binaryclassification
2 |
3 | import org.apache.log4j.PropertyConfigurator
4 | import org.apache.spark.SparkConf
5 | import org.apache.spark.SparkContext
6 | import org.apache.spark.mllib.classification.SVMWithSGD
7 | import org.apache.spark.mllib.regression.LabeledPoint
8 | import org.apache.spark.mllib.util.MLUtils
9 | import org.apache.spark.rdd.RDD
10 | import breeze.linalg.Vector
11 | import ch.ethz.dalab.dissolve.classification.BinarySVMWithDBCFW
12 | import ch.ethz.dalab.dissolve.examples.utils.ExampleUtils
13 | import ch.ethz.dalab.dissolve.optimization.GapThresholdCriterion
14 | import ch.ethz.dalab.dissolve.optimization.RoundLimitCriterion
15 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
16 | import ch.ethz.dalab.dissolve.optimization.TimeLimitCriterion
17 | import ch.ethz.dalab.dissolve.regression.LabeledObject
18 | import ch.ethz.dalab.dissolve.utils.cli.CLAParser
19 | import java.io.File
20 |
21 | object COVBinary {
22 |
23 | /**
24 | * MLLib's classifier
25 | */
26 | def mllibCov() {
27 | val conf = new SparkConf().setAppName("Adult-example").setMaster("local")
28 | val sc = new SparkContext(conf)
29 | sc.setCheckpointDir("checkpoint-files")
30 |
31 | val data = MLUtils.loadLibSVMFile(sc, "../data/generated/covtype.libsvm.binary.scale.head.mllib")
32 |
33 | // Split data into training and test set
34 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L)
35 | val training = splits(0)
36 | val test = splits(1)
37 |
38 | // Run training algorithm to build the model
39 | val numIterations = 1000
40 | val model = SVMWithSGD.train(training, numIterations)
41 |
42 | val trainAcc = training.map { point =>
43 | val score = model.predict(point.features)
44 | score == point.label
45 | }.collect().toList.count(_ == true).toDouble / training.count().toDouble
46 |
47 | val testAcc = test.map { point =>
48 | val score = model.predict(point.features)
49 | score == point.label
50 | }.collect().toList.count(_ == true).toDouble / test.count().toDouble
51 |
52 | println("Training accuracy = " + trainError)
53 | println("Test accuracy = " + testError)
54 | }
55 |
56 | /**
57 | * DBCFW classifier
58 | */
59 | def dbcfwCov(args: Array[String]) {
60 | /**
61 | * Load all options
62 | */
63 | val (solverOptions, kwargs) = CLAParser.argsToOptions[Vector[Double], Double](args)
64 | val covPath = kwargs.getOrElse("input_path", "../data/generated/covtype.libsvm.binary.scale")
65 | val appname = kwargs.getOrElse("appname", "cov_binary")
66 | val debugPath = kwargs.getOrElse("debug_file", "cov_binary-%d.csv".format(System.currentTimeMillis() / 1000))
67 | solverOptions.debugInfoPath = debugPath
68 |
69 | println(covPath)
70 | println(kwargs)
71 |
72 | println("Current directory:" + new File(".").getAbsolutePath)
73 |
74 | // Fix seed for reproducibility
75 | util.Random.setSeed(1)
76 |
77 | val conf = new SparkConf().setAppName(appname)
78 | val sc = new SparkContext(conf)
79 | sc.setCheckpointDir("checkpoint-files")
80 |
81 | // Labels need to be in a +1/-1 format
82 | val data = MLUtils
83 | .loadLibSVMFile(sc, covPath)
84 | .map {
85 | case x: LabeledPoint =>
86 | val label =
87 | if (x.label == 1)
88 | +1.00
89 | else
90 | -1.00
91 | LabeledPoint(label, x.features)
92 | }
93 |
94 | // Split data into training and test set
95 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L)
96 | val training = splits(0)
97 | val test = splits(1)
98 |
99 | val objectifiedTest: RDD[LabeledObject[Vector[Double], Double]] =
100 | test.map {
101 | case x: LabeledPoint =>
102 | new LabeledObject[Vector[Double], Double](x.label, Vector(x.features.toArray)) // Is the asInstanceOf required?
103 | }
104 |
105 | solverOptions.testDataRDD = Some(objectifiedTest)
106 | val model = BinarySVMWithDBCFW.train(training, solverOptions)
107 |
108 | // Test Errors
109 | val trueTestPredictions =
110 | objectifiedTest.map {
111 | case x: LabeledObject[Vector[Double], Double] =>
112 | val prediction = model.predict(x.pattern)
113 | if (prediction == x.label)
114 | 1
115 | else
116 | 0
117 | }.fold(0)((acc, ele) => acc + ele)
118 |
119 | println("Accuracy on Test set = %d/%d = %.4f".format(trueTestPredictions,
120 | objectifiedTest.count(),
121 | (trueTestPredictions.toDouble / objectifiedTest.count().toDouble) * 100))
122 | }
123 |
124 | def main(args: Array[String]): Unit = {
125 |
126 | PropertyConfigurator.configure("conf/log4j.properties")
127 |
128 | System.setProperty("spark.akka.frameSize", "512")
129 |
130 | dbcfwCov(args)
131 | }
132 |
133 | }
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/binaryclassification/RCV1Binary.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.examples.binaryclassification
2 |
3 | import org.apache.log4j.PropertyConfigurator
4 | import org.apache.spark.SparkConf
5 | import org.apache.spark.SparkContext
6 | import org.apache.spark.mllib.regression.LabeledPoint
7 | import org.apache.spark.mllib.util.MLUtils
8 | import org.apache.spark.rdd.RDD
9 | import breeze.linalg.SparseVector
10 | import breeze.linalg.Vector
11 | import ch.ethz.dalab.dissolve.classification.BinarySVMWithDBCFW
12 | import ch.ethz.dalab.dissolve.examples.utils.ExampleUtils
13 | import ch.ethz.dalab.dissolve.optimization.GapThresholdCriterion
14 | import ch.ethz.dalab.dissolve.optimization.RoundLimitCriterion
15 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
16 | import ch.ethz.dalab.dissolve.optimization.TimeLimitCriterion
17 | import ch.ethz.dalab.dissolve.regression.LabeledObject
18 | import ch.ethz.dalab.dissolve.utils.cli.CLAParser
19 |
20 | object RCV1Binary {
21 |
22 | def dbcfwRcv1(args: Array[String]) {
23 | /**
24 | * Load all options
25 | */
26 | val (solverOptions, kwargs) = CLAParser.argsToOptions[Vector[Double], Double](args)
27 | val rcv1Path = kwargs.getOrElse("input_path", "../data/generated/rcv1_train.binary")
28 | val appname = kwargs.getOrElse("appname", "rcv1_binary")
29 | val debugPath = kwargs.getOrElse("debug_file", "rcv1_binary-%d.csv".format(System.currentTimeMillis() / 1000))
30 | solverOptions.debugInfoPath = debugPath
31 |
32 | println(rcv1Path)
33 | println(kwargs)
34 |
35 | // Fix seed for reproducibility
36 | util.Random.setSeed(1)
37 |
38 | println(solverOptions.toString())
39 |
40 | val conf = new SparkConf().setAppName(appname)
41 | val sc = new SparkContext(conf)
42 | sc.setCheckpointDir("checkpoint-files")
43 |
44 | // Labels need to be in a +1/-1 format
45 | val data = MLUtils.loadLibSVMFile(sc, rcv1Path)
46 |
47 | // Split data into training and test set
48 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L)
49 | val training = splits(0)
50 | val test = splits(1)
51 |
52 | val objectifiedTest: RDD[LabeledObject[Vector[Double], Double]] =
53 | test.map {
54 | case x: LabeledPoint =>
55 | new LabeledObject[Vector[Double], Double](x.label, SparseVector(x.features.toArray)) // Is the asInstanceOf required?
56 | }
57 |
58 | solverOptions.testDataRDD = Some(objectifiedTest)
59 | val model = BinarySVMWithDBCFW.train(training, solverOptions)
60 |
61 | // Test Errors
62 | val trueTestPredictions =
63 | objectifiedTest.map {
64 | case x: LabeledObject[Vector[Double], Double] =>
65 | val prediction = model.predict(x.pattern)
66 | if (prediction == x.label)
67 | 1
68 | else
69 | 0
70 | }.fold(0)((acc, ele) => acc + ele)
71 |
72 | println("Accuracy on Test set = %d/%d = %.4f".format(trueTestPredictions,
73 | objectifiedTest.count(),
74 | (trueTestPredictions.toDouble / objectifiedTest.count().toDouble) * 100))
75 | }
76 |
77 | def main(args: Array[String]): Unit = {
78 |
79 | PropertyConfigurator.configure("conf/log4j.properties")
80 |
81 | System.setProperty("spark.akka.frameSize", "512")
82 |
83 | dbcfwRcv1(args)
84 |
85 | }
86 |
87 | }
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/imageseg/ImageSeg.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.examples.imageseg
2 |
3 | import scala.collection.mutable.ArrayBuffer
4 |
5 | import breeze.linalg.DenseMatrix
6 | import breeze.linalg.DenseVector
7 | import breeze.linalg.Vector
8 | import breeze.linalg.normalize
9 | import cc.factorie.infer.MaximizeByBPLoopy
10 | import cc.factorie.la.DenseTensor1
11 | import cc.factorie.la.Tensor
12 | import cc.factorie.model._
13 | import cc.factorie.singleFactorIterable
14 | import cc.factorie.variable.DiscreteDomain
15 | import cc.factorie.variable.DiscreteVariable
16 | import ch.ethz.dalab.dissolve.classification.StructSVMModel
17 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSegTypes.AdjacencyList
18 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSegTypes.Label
19 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSegTypes.RGB_INT
20 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
21 |
22 | /**
23 | * `unaries` is a `d` x `N` matrix; each column contains the features of one super-pixel
24 | * `pairwise` is an adjacency list; pairwise(i) contains the neighbours of the `i+1`'th super-pixel
25 | * `pixelMapping` maps each super-pixel to its corresponding pixel in the original image (column-major)
26 | */
27 | case class QuantizedImage(unaries: DenseMatrix[Double],
28 | pairwise: Array[Array[Int]],
29 | pixelMapping: Array[Int],
30 | width: Int,
31 | height: Int,
32 | filename: String = "NA",
33 | unaryFeatures: DenseMatrix[Double] = null,
34 | rgbArray: Array[RGB_INT] = null,
35 | globalFeatures: Vector[Double] = null)
36 |
37 | /**
38 | * labels(i) contains label for `i`th super-pixel
39 | */
40 | case class QuantizedLabel(labels: Array[Int],
41 | filename: String = "NA")
42 |
43 | /**
44 | * Functions for dissolve^struct
45 | *
46 | * Designed for the MSRC-21 dataset
47 | */
48 | object ImageSeg
49 | extends DissolveFunctions[QuantizedImage, QuantizedLabel] {
50 |
51 | val NUM_CLASSES: Int = 22 // # Classes (0-indexed)
52 | val BACKGROUND_CLASS: Int = 21 // The last label
53 |
54 | val INTENSITY_LEVELS: Int = 8
55 | val NUM_BINS = INTENSITY_LEVELS * INTENSITY_LEVELS * INTENSITY_LEVELS // Size of feature vector x_i
56 |
57 | var DISABLE_PAIRWISE = false
58 |
59 | /**
60 | * ======= Joint Feature Map =======
61 | */
62 | def featureFn(x: QuantizedImage, y: QuantizedLabel): Vector[Double] = {
63 |
64 | val numSuperpixels = x.unaries.cols // # Super-pixels
65 | val classifierScore = x.unaries.rows // Score of SP per label
66 |
67 | assert(numSuperpixels == x.pairwise.length,
68 | "numSuperpixels == x.pairwise.length")
69 |
70 | /**
71 | * Unary Features
72 | */
73 | val d = x.unaryFeatures.rows
74 | assert(d == NUM_BINS, "d == NUM_BINS")
75 |
76 | val dCombined =
77 | if (x.globalFeatures != null)
78 | d + x.globalFeatures.size
79 | else
80 | d
81 |
82 | val unaryFeatures = DenseMatrix.zeros[Double](dCombined, NUM_CLASSES)
83 | for (superIdx <- 0 until numSuperpixels) {
84 | val x_i = x.unaryFeatures(::, superIdx)
85 | val x_global = x.globalFeatures
86 |
87 | val x_comb =
88 | if (x_global == null)
89 | x_i
90 | else
91 | Vector(Array.concat(x_i.toArray, x_global.toArray))
92 | val label = y.labels(superIdx)
93 | unaryFeatures(::, label) += x_comb
94 | }
95 |
96 | if (DISABLE_PAIRWISE)
97 | unaryFeatures.toDenseVector
98 | else {
99 | /**
100 | * Pairwise features
101 | */
102 | val transitions = DenseMatrix.zeros[Double](NUM_CLASSES, NUM_CLASSES)
103 | for (superIdx <- 0 until numSuperpixels) {
104 | val thisLabel = y.labels(superIdx)
105 |
106 | x.pairwise(superIdx).foreach {
107 | case adjacentSuperIdx =>
108 | val nextLabel = y.labels(adjacentSuperIdx)
109 |
110 | transitions(thisLabel, nextLabel) += 1.0
111 | transitions(nextLabel, thisLabel) += 1.0
112 | }
113 | }
114 | DenseVector.vertcat(unaryFeatures.toDenseVector,
115 | normalize(transitions.toDenseVector, 2))
116 | }
117 | }
118 |
119 | /**
120 | * Per-label Hamming loss
121 | */
122 | def perLabelLoss(labTruth: Label, labPredict: Label): Double =
123 | if (labTruth == labPredict)
124 | 0.0
125 | else
126 | 1.0
127 |
128 | /**
129 | * ======= Structured Error Function =======
130 | */
131 | def lossFn(yTruth: QuantizedLabel, yPredict: QuantizedLabel): Double = {
132 |
133 | assert(yTruth.labels.size == yPredict.labels.size,
134 | "Failed: yTruth.labels.size == yPredict.labels.size")
135 |
136 | val structHammingLoss = yTruth.labels
137 | .zip(yPredict.labels)
138 | .map {
139 | case (labTruth, labPredict) =>
140 | perLabelLoss(labTruth, labPredict)
141 | }
142 |
143 | // Return normalized hamming loss
144 | structHammingLoss.sum / structHammingLoss.length
145 | }
146 |
147 | /**
148 | * Construct Factor graph and run MAP
149 | * (Max-product using Loopy Belief Propagation)
150 | */
151 | def decode(unaryPot: DenseMatrix[Double],
152 | pairwisePot: DenseMatrix[Double],
153 | adj: AdjacencyList): Array[Label] = {
154 |
155 | val nSuperpixels = unaryPot.cols
156 | val nClasses = unaryPot.rows
157 |
158 | assert(nClasses == NUM_CLASSES)
159 | if (!DISABLE_PAIRWISE)
160 | assert(pairwisePot.rows == NUM_CLASSES)
161 |
162 | object PixelDomain extends DiscreteDomain(nClasses)
163 |
164 | class Pixel(i: Int) extends DiscreteVariable(i) {
165 | def domain = PixelDomain
166 | }
167 |
168 | def getUnaryFactor(yi: Pixel, idx: Int): Factor = {
169 | new Factor1(yi) {
170 | val weights: DenseTensor1 = new DenseTensor1(unaryPot(::, idx).toArray)
171 | def score(k: Pixel#Value) = unaryPot(k.intValue, idx)
172 | override def valuesScore(tensor: Tensor): Double = {
173 | weights dot tensor
174 | }
175 | }
176 | }
177 |
178 | def getPairwiseFactor(yi: Pixel, yj: Pixel): Factor = {
179 | new Factor2(yi, yj) {
180 | val weights: DenseTensor1 = new DenseTensor1(pairwisePot.toArray)
181 | def score(i: Pixel#Value, j: Pixel#Value) = pairwisePot(i.intValue, j.intValue)
182 | override def valuesScore(tensor: Tensor): Double = {
183 | weights dot tensor
184 | }
185 | }
186 | }
187 |
188 | val pixelSeq: IndexedSeq[Pixel] =
189 | (0 until nSuperpixels).map(x => new Pixel(12))
190 |
191 | val unaryFactors: IndexedSeq[Factor] =
192 | (0 until nSuperpixels).map {
193 | case idx =>
194 | getUnaryFactor(pixelSeq(idx), idx)
195 | }
196 |
197 | val model = new ItemizedModel
198 | model ++= unaryFactors
199 |
200 | if (!DISABLE_PAIRWISE) {
201 | val pairwiseFactors =
202 | (0 until nSuperpixels).flatMap {
203 | case thisIdx =>
204 | val thisFactors = new ArrayBuffer[Factor]
205 |
206 | adj(thisIdx).foreach {
207 | case nextIdx =>
208 | thisFactors ++=
209 | getPairwiseFactor(pixelSeq(thisIdx), pixelSeq(nextIdx))
210 | }
211 | thisFactors
212 | }
213 | model ++= pairwiseFactors
214 | }
215 |
216 | MaximizeByBPLoopy.maximize(pixelSeq, model)
217 |
218 | val mapLabels: Array[Label] = (0 until nSuperpixels).map {
219 | idx =>
220 | pixelSeq(idx).intValue
221 | }.toArray
222 |
223 | mapLabels
224 | }
225 |
226 | /**
227 | * Unpack weight vector to Unary and Pairwise weights
228 | */
229 |
230 | def unpackWeightVec(weightv: DenseVector[Double], d: Int): (DenseMatrix[Double], DenseMatrix[Double]) = {
231 |
232 | assert(weightv.size >= (NUM_CLASSES * d))
233 |
234 | val unaryWeights = weightv(0 until NUM_CLASSES * d)
235 | val unaryWeightMat = unaryWeights.toDenseMatrix.reshape(d, NUM_CLASSES)
236 |
237 | val pairwisePot =
238 | if (!DISABLE_PAIRWISE) {
239 | assert(weightv.size == (NUM_CLASSES * d) + (NUM_CLASSES * NUM_CLASSES))
240 | val pairwiseWeights = weightv((NUM_CLASSES * d) until weightv.size)
241 | pairwiseWeights.toDenseMatrix.reshape(NUM_CLASSES, NUM_CLASSES)
242 | } else null
243 |
244 | (unaryWeightMat, pairwisePot)
245 | }
246 |
247 | /**
248 | * ======= Maximization Oracle =======
249 | */
250 | override def oracleFn(model: StructSVMModel[QuantizedImage, QuantizedLabel],
251 | xi: QuantizedImage,
252 | yi: QuantizedLabel): QuantizedLabel = {
253 |
254 | val nSuperpixels = xi.unaryFeatures.cols
255 | val d = xi.unaryFeatures.rows
256 | val dComb =
257 | if (xi.globalFeatures == null)
258 | d
259 | else
260 | d + xi.globalFeatures.size
261 |
262 | assert(xi.pairwise.length == nSuperpixels,
263 | "xi.pairwise.length == nSuperpixels")
264 | assert(xi.unaryFeatures.cols == nSuperpixels,
265 | "xi.unaryFeatures.cols == nSuperpixels")
266 |
267 | val (unaryWeights, pairwisePot) = unpackWeightVec(model.weights.toDenseVector, dComb)
268 | val localFeatures =
269 | if (xi.globalFeatures == null)
270 | xi.unaryFeatures
271 | else {
272 | // Concatenate global features to local features
273 | // The order is : local || global
274 | val dGlob = xi.globalFeatures.size
275 | val glob = xi.globalFeatures.toDenseVector
276 | val globalFeatures = DenseMatrix.zeros[Double](dGlob, nSuperpixels)
277 | for (superIdx <- 0 until nSuperpixels) {
278 | globalFeatures(::, superIdx) := glob
279 | }
280 | DenseMatrix.vertcat(xi.unaryFeatures, globalFeatures)
281 | }
282 | val unaryPot = unaryWeights.t * localFeatures
283 |
284 | if (yi != null) {
285 | assert(yi.labels.length == xi.pairwise.length,
286 | "yi.labels.length == xi.pairwise.length")
287 |
288 | // Loss augment the scores
289 | for (superIdx <- 0 until nSuperpixels) {
290 | val trueLabel = yi.labels(superIdx)
291 | // FIXME Use \delta here
292 | unaryPot(::, superIdx) += (1.0 / nSuperpixels)
293 | unaryPot(trueLabel, superIdx) -= (1.0 / nSuperpixels)
294 | }
295 | }
296 |
297 | val t0 = System.currentTimeMillis()
298 | val decodedLabels = decode(unaryPot, pairwisePot, xi.pairwise)
299 | val oracleSolution = QuantizedLabel(decodedLabels, xi.filename)
300 | val t1 = System.currentTimeMillis()
301 |
302 | oracleSolution
303 | }
304 |
305 | def predictFn(model: StructSVMModel[QuantizedImage, QuantizedLabel],
306 | xi: QuantizedImage): QuantizedLabel = {
307 | oracleFn(model, xi, null)
308 | }
309 |
310 | }
--------------------------------------------------------------------------------
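To make the weight-vector sizes in ImageSeg.featureFn and unpackWeightVec above concrete: with INTENSITY_LEVELS = 8, each super-pixel feature has NUM_BINS = 8^3 = 512 entries, so without global features the unary block contributes 512 * 22 = 11264 weights, and with pairwise features enabled another 22 * 22 = 484 transition weights are appended, for 11748 in total (global features, when present, add their own rows per class). A short sketch of the same arithmetic, assuming no global features:

    // Worked dimension check for the joint feature map, assuming x.globalFeatures == null.
    val intensityLevels = 8
    val numBins = intensityLevels * intensityLevels * intensityLevels // 512 = per-super-pixel feature size
    val numClasses = 22

    val unaryDim = numBins * numClasses       // 11264 unary weights
    val pairwiseDim = numClasses * numClasses // 484 transition weights
    val totalDim = unaryDim + pairwiseDim     // 11748 = size of w expected by unpackWeightVec

    println(s"unary = $unaryDim, pairwise = $pairwiseDim, total = $totalDim")
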
/dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/imageseg/ImageSegRunner.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.examples.imageseg
2 |
3 | import java.nio.file.Paths
4 | import org.apache.log4j.PropertyConfigurator
5 | import org.apache.spark.SparkConf
6 | import org.apache.spark.SparkContext
7 | import ch.ethz.dalab.dissolve.classification.StructSVMModel
8 | import ch.ethz.dalab.dissolve.classification.StructSVMWithDBCFW
9 | import ch.ethz.dalab.dissolve.utils.cli.CLAParser
10 | import javax.imageio.ImageIO
11 |
12 | /**
13 | * @author torekond
14 | */
15 | object ImageSegRunner {
16 |
17 | def main(args: Array[String]): Unit = {
18 |
19 | PropertyConfigurator.configure("conf/log4j.properties")
20 | System.setProperty("spark.akka.frameSize", "512")
21 |
22 | val startTime = System.currentTimeMillis() / 1000
23 |
24 | /**
25 | * Load all options
26 | */
27 | val (solverOptions, kwargs) = CLAParser.argsToOptions[QuantizedImage, QuantizedLabel](args)
28 | val dataDir = kwargs.getOrElse("input_path", "../data/generated/msrc")
29 | val appname = kwargs.getOrElse("appname", "imageseg-%d".format(startTime))
30 | val debugPath = kwargs.getOrElse("debug_file", "imageseg-%d.csv".format(startTime))
31 |
32 | val unariesOnly = kwargs.getOrElse("unaries", "true").toBoolean
33 |
34 | val trainFile = kwargs.getOrElse("train", "Train.txt")
35 | val validationFile = kwargs.getOrElse("validation", "Validation.txt")
36 | solverOptions.debugInfoPath = debugPath
37 |
38 | println(dataDir)
39 | println(kwargs)
40 |
41 | solverOptions.doLineSearch = true
42 |
43 | if (unariesOnly)
44 | ImageSeg.DISABLE_PAIRWISE = true
45 |
46 | /**
47 | * Setup Spark
48 | */
49 | val conf = new SparkConf().setAppName(appname).setMaster("local")
50 | val sc = new SparkContext(conf)
51 | sc.setCheckpointDir("checkpoint-files")
52 |
53 | val trainFilePath = Paths.get(dataDir, trainFile)
54 | val valFilePath = Paths.get(dataDir, validationFile)
55 |
56 | val trainDataSeq = ImageSegUtils.loadData(dataDir, trainFilePath)
57 | val valDataSeq = ImageSegUtils.loadData(dataDir, valFilePath)
58 |
59 | val trainData = sc.parallelize(trainDataSeq, 1).cache
60 | val valData = sc.parallelize(valDataSeq, 1).cache
61 |
62 | solverOptions.testDataRDD = Some(valData)
63 |
64 | println(solverOptions)
65 |
66 | val trainer: StructSVMWithDBCFW[QuantizedImage, QuantizedLabel] =
67 | new StructSVMWithDBCFW[QuantizedImage, QuantizedLabel](
68 | trainData,
69 | ImageSeg,
70 | solverOptions)
71 |
72 | val model: StructSVMModel[QuantizedImage, QuantizedLabel] = trainer.trainModel()
73 |
74 | // Create directories for image out, if it doesn't exist
75 | val imageOutDir = Paths.get(dataDir, "debug", appname)
76 | if (!imageOutDir.toFile().exists())
77 | imageOutDir.toFile().mkdirs()
78 |
79 | println("Test time!")
80 | for (lo <- trainDataSeq) {
81 | val t0 = System.currentTimeMillis()
82 | val prediction = model.predict(lo.pattern)
83 | val t1 = System.currentTimeMillis()
84 |
85 | val filename = lo.pattern.filename
86 | val format = "bmp"
87 | val outPath = Paths.get(imageOutDir.toString(), "train-%s.%s".format(filename, format))
88 |
89 | // Image
90 | val imgPath = Paths.get(dataDir.toString(), "All", "%s.bmp".format(filename))
91 | val img = ImageIO.read(imgPath.toFile())
92 |
93 | val width = img.getWidth()
94 | val height = img.getHeight()
95 |
96 | // Write loss info
97 | val predictTime: Long = t1 - t0
98 | val loss: Double = ImageSeg.lossFn(prediction, lo.label)
99 | val text = "filename = %s\nprediction time = %d ms\nerror = %f\n#spx = %d".format(filename, predictTime, loss, lo.pattern.unaries.cols)
100 | val textInfoImg = ImageSegUtils.getImageWithText(width, height, text)
101 |
102 | // GT
103 | val gtImage = ImageSegUtils.getQuantizedLabelImage(lo.label,
104 | lo.pattern.pixelMapping,
105 | lo.pattern.width,
106 | lo.pattern.height)
107 |
108 | // Prediction
109 | val predImage = ImageSegUtils.getQuantizedLabelImage(prediction,
110 | lo.pattern.pixelMapping,
111 | lo.pattern.width,
112 | lo.pattern.height)
113 |
114 | val prettyOut = ImageSegUtils.printImageTile(img, gtImage, textInfoImg, predImage)
115 | ImageSegUtils.writeImage(prettyOut, outPath.toString())
116 |
117 | }
118 |
119 | for (lo <- valDataSeq) {
120 | val t0 = System.currentTimeMillis()
121 | val prediction = model.predict(lo.pattern)
122 | val t1 = System.currentTimeMillis()
123 |
124 | val filename = lo.pattern.filename
125 | val format = "bmp"
126 | val outPath = Paths.get(imageOutDir.toString(), "val-%s.%s".format(filename, format))
127 |
128 | // Image
129 | val imgPath = Paths.get(dataDir.toString(), "All", "%s.bmp".format(filename))
130 | val img = ImageIO.read(imgPath.toFile())
131 |
132 | val width = img.getWidth()
133 | val height = img.getHeight()
134 |
135 | // Write loss info
136 | val predictTime: Long = t1 - t0
137 | val loss: Double = ImageSeg.lossFn(prediction, lo.label)
138 | val text = "filename = %s\nprediction time = %d ms\nerror = %f\n#spx = %d".format(filename, predictTime, loss, lo.pattern.unaries.cols)
139 | val textInfoImg = ImageSegUtils.getImageWithText(width, height, text)
140 |
141 | // GT
142 | val gtImage = ImageSegUtils.getQuantizedLabelImage(lo.label,
143 | lo.pattern.pixelMapping,
144 | lo.pattern.width,
145 | lo.pattern.height)
146 |
147 | // Prediction
148 | val predImage = ImageSegUtils.getQuantizedLabelImage(prediction,
149 | lo.pattern.pixelMapping,
150 | lo.pattern.width,
151 | lo.pattern.height)
152 |
153 | val prettyOut = ImageSegUtils.printImageTile(img, gtImage, textInfoImg, predImage)
154 | ImageSegUtils.writeImage(prettyOut, outPath.toString())
155 |
156 | }
157 |
158 | val unaryDebugPath =
159 | Paths.get("/home/torekond/dev-local/dissolve-struct/data/generated/msrc/debug",
160 | "%s-unary.csv".format(appname))
161 | val transDebugPath =
162 | Paths.get("/home/torekond/dev-local/dissolve-struct/data/generated/msrc/debug",
163 | "%s-trans.csv".format(appname))
164 | val weights = model.getWeights().toDenseVector
165 |
166 | val xi = trainDataSeq(0).pattern
167 | val d =
168 | if (xi.globalFeatures == null)
169 | xi.unaryFeatures.rows
170 | else
171 | xi.unaryFeatures.rows + xi.globalFeatures.size
172 | val (unaryMat, transMat) = ImageSeg.unpackWeightVec(weights, d)
173 |
174 | breeze.linalg.csvwrite(unaryDebugPath.toFile(), unaryMat)
175 | if (transMat != null)
176 | breeze.linalg.csvwrite(transDebugPath.toFile(), transMat)
177 |
178 | }
179 |
180 | }
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/imageseg/ImageSegTypes.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.examples.imageseg
2 |
3 | /**
4 | * @author torekond
5 | */
6 | object ImageSegTypes {
7 | type Label = Int
8 | type Index = Int
9 | type SuperIndex = Int
10 | type RGB_INT = Int
11 | type AdjacencyList = Array[Array[SuperIndex]]
12 | }
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/multiclass/COVMulticlass.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.examples.multiclass
2 |
3 | import org.apache.log4j.PropertyConfigurator
4 | import org.apache.spark.SparkConf
5 | import org.apache.spark.SparkContext
6 | import org.apache.spark.mllib.regression.LabeledPoint
7 | import org.apache.spark.mllib.util.MLUtils
8 | import org.apache.spark.rdd.RDD
9 |
10 | import breeze.linalg.Vector
11 | import ch.ethz.dalab.dissolve.classification.MultiClassLabel
12 | import ch.ethz.dalab.dissolve.classification.MultiClassSVMWithDBCFW
13 | import ch.ethz.dalab.dissolve.regression.LabeledObject
14 | import ch.ethz.dalab.dissolve.utils.cli.CLAParser
15 |
16 | object COVMulticlass {
17 |
18 | def dissolveCovMulti(args: Array[String]) {
19 |
20 | /**
21 | * Load all options
22 | */
23 | val (solverOptions, kwargs) = CLAParser.argsToOptions[Vector[Double], MultiClassLabel](args)
24 | val covPath = kwargs.getOrElse("input_path", "../data/generated/covtype.scale")
25 | val appname = kwargs.getOrElse("appname", "cov_multi")
26 | val debugPath = kwargs.getOrElse("debug_file", "cov_multi-%d.csv".format(System.currentTimeMillis() / 1000))
27 | solverOptions.debugInfoPath = debugPath
28 |
29 | println(covPath)
30 | println(kwargs)
31 |
32 | // Fix seed for reproducibility
33 | util.Random.setSeed(1)
34 |
35 | val conf = new SparkConf().setAppName(appname)
36 |
37 | val sc = new SparkContext(conf)
38 | sc.setCheckpointDir("checkpoint-files")
39 |
40 | // Needs labels \in [0, numClasses)
41 | val data: RDD[LabeledPoint] = MLUtils
42 | .loadLibSVMFile(sc, covPath)
43 | .map {
44 | case x: LabeledPoint =>
45 | val label = x.label - 1
46 | LabeledPoint(label, x.features)
47 | }
48 |
49 | val minlabel = data.map(_.label).min()
50 | val maxlabel = data.map(_.label).max()
51 | println("min = %f, max = %f".format(minlabel, maxlabel))
52 |
53 | // Split data into training and test set
54 | val splits = data.randomSplit(Array(0.8, 0.2), seed = 1L)
55 | val training = splits(0)
56 | val test = splits(1)
57 |
58 | val numClasses = 7
59 |
60 | val objectifiedTest: RDD[LabeledObject[Vector[Double], MultiClassLabel]] =
61 | test.map {
62 | case x: LabeledPoint =>
63 | new LabeledObject[Vector[Double], MultiClassLabel](MultiClassLabel(x.label, numClasses),
64 | Vector(x.features.toArray))
65 | }
66 |
67 | solverOptions.testDataRDD = Some(objectifiedTest)
68 | val model = MultiClassSVMWithDBCFW.train(data, numClasses, solverOptions)
69 |
70 | // Test Errors
71 | val trueTestPredictions =
72 | objectifiedTest.map {
73 | case x: LabeledObject[Vector[Double], MultiClassLabel] =>
74 | val prediction = model.predict(x.pattern)
75 | if (prediction == x.label)
76 | 1
77 | else
78 | 0
79 | }.fold(0)((acc, ele) => acc + ele)
80 |
81 | println("Accuracy on Test set = %d/%d = %.4f".format(trueTestPredictions,
82 | objectifiedTest.count(),
83 | (trueTestPredictions.toDouble / objectifiedTest.count().toDouble) * 100))
84 |
85 | }
86 |
87 | def main(args: Array[String]): Unit = {
88 |
89 | PropertyConfigurator.configure("conf/log4j.properties")
90 |
91 | System.setProperty("spark.akka.frameSize", "512")
92 |
93 | dissolveCovMulti(args)
94 | }
95 |
96 | }
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/main/scala/ch/ethz/dalab/dissolve/examples/utils/ExampleUtils.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.examples.utils
2 |
3 | import scala.io.Source
4 | import scala.util.Random
5 |
6 | object ExampleUtils {
7 |
8 | /**
9 | * Obtained adj and noun from: https://github.com/kohsuke/wordnet-random-name
10 | */
11 | val adjFilename = "adj.txt"
12 | val nounFilename = "noun.txt"
13 |
14 | val adjList = Source.fromURL(getClass.getResource("/adj.txt")).getLines().toArray
15 | val nounList = Source.fromURL(getClass.getResource("/noun.txt")).getLines().toArray
16 |
17 | def getRandomElement[T](lst: Seq[T]): T = lst(Random.nextInt(lst.size))
18 |
19 | def generateExperimentName(prefix: Seq[String] = List.empty, suffix: Seq[String] = List.empty, separator: String = "-"): String = {
20 |
21 | val nameList: Seq[String] = prefix ++ List(getRandomElement(adjList), getRandomElement(nounList)) ++ suffix
22 |
23 | val separatedNameList = nameList
24 | .flatMap { x => x :: separator :: Nil } // Juxtapose with separators
25 | .dropRight(1) // Drop the last separator
26 |
27 | separatedNameList.reduce(_ + _)
28 | }
29 |
30 | def main(args: Array[String]): Unit = {
31 | println(generateExperimentName())
32 | }
33 |
34 | }
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/test/scala/ch/ethz/dalab/dissolve/diagnostics/FeatureFnSpec.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.diagnostics
2 |
3 | /**
4 | * @author torekond
5 | */
6 | class FeatureFnSpec extends UnitSpec {
7 |
8 | val NUM_ATTEMPTS = 100 // = # times each test case is attempted
9 |
10 | // A sample datapoint
11 | val lo = data(0)
12 | // Size of joint feature map
13 | val d = phi(lo.pattern, lo.label).size
14 | // No. of data examples
15 | val M = data.length
16 |
17 | "dim( ϕ(x_m, y_m) )" should "be fixed for all GIVEN (x_m, y_m)" in {
18 |
19 | val dimDiffSeq: Seq[Int] = for (k <- 0 until NUM_ATTEMPTS) yield {
20 |
21 | // Choose a random example
22 | val m = scala.util.Random.nextInt(M)
23 | val lo = data(m)
24 | val x_m = lo.pattern
25 | val y_m = lo.label
26 |
27 | phi(x_m, y_m).size - d
28 |
29 | }
30 |
31 | // This should be empty
32 | val uneqDimDiffSeq: Seq[Int] = dimDiffSeq.filter(_ != 0)
33 |
34 | assert(uneqDimDiffSeq.length == 0,
35 | "%d / %d cases failed".format(uneqDimDiffSeq.length, dimDiffSeq.length))
36 |
37 | }
38 |
39 | it should "be fixed for all PERTURBED (x_m, y_m)" in {
40 |
41 | val dimDiffSeq: Seq[Int] = for (k <- 0 until NUM_ATTEMPTS) yield {
42 |
43 | // Choose a random example
44 | val m = scala.util.Random.nextInt(M)
45 | val lo = data(m)
46 | val x_m = lo.pattern
47 | val y_m = perturb(lo.label, 0.5)
48 |
49 | phi(x_m, y_m).size - d
50 |
51 | }
52 |
53 | // This should be empty
54 | val uneqDimDiffSeq: Seq[Int] = dimDiffSeq.filter(_ != 0)
55 |
56 | assert(uneqDimDiffSeq.length == 0,
57 | "%d / %d cases failed".format(uneqDimDiffSeq.length, dimDiffSeq.length))
58 |
59 | }
60 | }
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/test/scala/ch/ethz/dalab/dissolve/diagnostics/OracleSpec.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.diagnostics
2 |
3 | import org.scalatest.FlatSpec
4 | import breeze.linalg._
5 | import scala.collection.mutable.ArrayBuffer
6 | import ch.ethz.dalab.dissolve.regression.LabeledObject
7 |
8 | /**
9 | * @author torekond
10 | */
11 | class OracleSpec extends UnitSpec {
12 |
13 | val NUM_WEIGHT_VECS = 500 // = # times each test case is attempted
14 | val EPSILON = 2.2204E-16
15 |
16 | // A sample datapoint
17 | val lo = data(0)
18 | // Size of joint feature map
19 | val d = phi(lo.pattern, lo.label).size
20 | // No. of data examples
21 | val M = data.length
22 |
23 | /**
24 | * Initialize a bunch of weight vectors
25 | */
26 | type WeightVector = Vector[Double]
27 | val weightVectors: Array[WeightVector] = {
28 | val seq = new ArrayBuffer[WeightVector]
29 |
30 | // All 1's
31 | seq += DenseVector.ones[Double](d)
32 |
33 | // All 0's
34 | seq += DenseVector.zeros[Double](d)
35 |
36 | // Few random vectors
37 | for (k <- 0 until NUM_WEIGHT_VECS - seq.size)
38 | seq += DenseVector.rand(d)
39 |
40 | seq.toArray
41 | }
42 |
43 | /**
44 | * Tests - Structured Hinge Loss
45 | */
46 |
47 | "Structured Hinge Loss [Δ(y_m, y*) - < w, ψ(x_m, y*) >]" should "be >= 0 for GIVEN (x_m, y_m) pairs" in {
48 |
49 | val shlSeq: Seq[Double] = for (k <- 0 until NUM_WEIGHT_VECS) yield {
50 |
51 | // Set weight vector
52 | val w: WeightVector = weightVectors(k)
53 | model.setWeights(w)
54 |
55 | // Choose a random example
56 | val m = scala.util.Random.nextInt(M)
57 | val lo = data(m)
58 | val x_m = lo.pattern
59 | val y_m = lo.label
60 |
61 | // Get loss-augmented argmax prediction
62 | val ystar = maxoracle(model, x_m, y_m)
63 | val shl = delta(y_m, ystar) - deltaF(lo, ystar, w)
64 |
65 | shl
66 |
67 | }
68 |
69 | // This should be empty
70 | val negShlSeq: Seq[Double] = shlSeq.filter(_ < -EPSILON)
71 |
72 | assert(negShlSeq.length == 0,
73 | "%d / %d cases failed".format(negShlSeq.length, shlSeq.length))
74 |
75 | }
76 |
77 | it should "be >= 0 for PERTURBED (x_m, y_m) pairs" in {
78 |
79 | val shlSeq: Seq[Double] = for (k <- 0 until NUM_WEIGHT_VECS) yield {
80 |
81 | // Set weight vector
82 | val w: WeightVector = weightVectors(k)
83 | model.setWeights(w)
84 |
85 | // Sample a random (x, y) pair
86 | val m = scala.util.Random.nextInt(M)
87 | val x_m = data(m).pattern
88 | val y_m = perturb(data(m).label, 0.1) // Perturbed copy
89 | val lo = LabeledObject(y_m, x_m)
90 |
91 | // Get loss-augmented argmax prediction
92 | val ystar = maxoracle(model, x_m, y_m)
93 | val shl = delta(y_m, ystar) - deltaF(lo, ystar, w)
94 |
95 | shl
96 |
97 | }
98 |
99 | // This should be empty
100 | val negShlSeq: Seq[Double] = shlSeq.filter(_ < -EPSILON)
101 |
102 | assert(negShlSeq.length == 0,
103 | "%d / %d cases failed".format(negShlSeq.length, shlSeq.length))
104 |
105 | }
106 |
107 | /**
108 | * Tests - Discriminant function
109 | */
110 | "F(x_m, y*)" should "be >= F(x_m, y_m)" in {
111 |
112 | val diffSeq = for (k <- 0 until NUM_WEIGHT_VECS) yield {
113 | // Set weight vector
114 | val w: WeightVector = weightVectors(k)
115 | model.setWeights(w)
116 |
117 | // Choose a random example
118 | val m = scala.util.Random.nextInt(M)
119 | val lo = data(m)
120 | val x_m = lo.pattern
121 | val y_m = lo.label
122 |
123 | // Get argmax prediction
124 | val ystar = predict(model, x_m)
125 |
126 | val F_ystar = F(x_m, ystar, w)
127 | val F_gt = F(x_m, y_m, w)
128 |
129 | val diff = F_ystar - F_gt
130 |
131 | diff
132 | }
133 |
134 | /*println(diffSeq)*/
135 |
136 | // This should be empty
137 | val negDiffSeq: Seq[Double] = diffSeq.filter(_ < -EPSILON)
138 |
139 | assert(negDiffSeq.length == 0,
140 | "%d / %d cases failed".format(negDiffSeq.length, diffSeq.length))
141 |
142 | }
143 |
144 | it should "be >= PERTURBED F(x_m, y_m)" in {
145 |
146 | val diffSeq = for (k <- 0 until NUM_WEIGHT_VECS) yield {
147 | // Set weight vector
148 | val w: WeightVector = weightVectors(k)
149 | model.setWeights(w)
150 |
151 | // Choose a random example
152 | val m = scala.util.Random.nextInt(M)
153 | val lo = data(m)
154 | val x_m = lo.pattern
155 | val y_m = perturb(lo.label, 0.1)
156 |
157 | // Get argmax prediction
158 | val ystar = predict(model, x_m)
159 |
160 | val F_ystar = F(x_m, ystar, w)
161 | val F_gt = F(x_m, y_m, w)
162 |
163 | F_ystar - F_gt
164 | }
165 |
166 | /*println(diffSeq)*/
167 |
168 | // This should be empty
169 | val negDiffSeq: Seq[Double] = diffSeq.filter(_ < -EPSILON)
170 |
171 | assert(negDiffSeq.length == 0,
172 | "%d / %d cases failed".format(negDiffSeq.length, diffSeq.length))
173 |
174 | }
175 |
176 | "H(w; x_m, y_m)" should "be >= Δ(y_m, y_m) + F(x_m, y_m)" in {
177 |
178 | val diffSeq = for (k <- 0 until NUM_WEIGHT_VECS) yield {
179 | // Set weight vector
180 | val w: WeightVector = weightVectors(k)
181 | model.setWeights(w)
182 |
183 | // Choose a random example
184 | val m = scala.util.Random.nextInt(M)
185 | val lo = data(m)
186 | val x_m = lo.pattern
187 | val y_m = lo.label
188 |
189 | // Get loss-augmented argmax prediction
190 | val ystar = maxoracle(model, x_m, y_m)
191 |
192 | val H = delta(y_m, ystar) - deltaF(lo, ystar, w)
193 | val F_loss_aug = delta(y_m, y_m) - deltaF(lo, y_m, w)
194 |
195 | H - F_loss_aug
196 | }
197 |
198 | /*println(diffSeq)*/
199 |
200 | // This should be empty
201 | val negDiffSeq: Seq[Double] = diffSeq.filter(_ < -EPSILON)
202 |
203 | assert(negDiffSeq.length == 0,
204 | "%d / %d cases failed".format(negDiffSeq.length, diffSeq.length))
205 | }
206 |
207 | it should "be >= PERTURBED Δ(y_m, y_m) + F(x_m, y_m)" in {
208 |
209 | val diffSeq = for (k <- 0 until NUM_WEIGHT_VECS) yield {
210 | // Set weight vector
211 | val w: WeightVector = weightVectors(k)
212 | model.setWeights(w)
213 |
214 | // Choose a random example
215 | val m = scala.util.Random.nextInt(M)
216 | val lo = data(m)
217 | val x_m = lo.pattern
218 | val y_m = perturb(lo.label, 0.1)
219 |
220 | // Get loss-augmented argmax prediction
221 | val ystar = maxoracle(model, x_m, y_m)
222 |
223 | val H = delta(y_m, ystar) - deltaF(lo, ystar, w)
224 | val F_loss_aug = delta(y_m, y_m) - deltaF(lo, y_m, w)
225 |
226 | H - F_loss_aug
227 | }
228 |
229 | /*println(diffSeq)*/
230 |
231 | // This should be empty
232 | val negDiffSeq: Seq[Double] = diffSeq.filter(_ < -EPSILON)
233 |
234 | assert(negDiffSeq.length == 0,
235 | "%d / %d cases failed".format(negDiffSeq.length, diffSeq.length))
236 | }
237 |
238 | }
--------------------------------------------------------------------------------
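Note on OracleSpec above: the non-negativity assertions follow directly from the definition of the structured hinge loss. With ψ(x_m, y) = φ(x_m, y_m) − φ(x_m, y), the quantity computed in the tests is

    H(w; x_m, y_m) = max_y [ Δ(y_m, y) − ⟨w, ψ(x_m, y)⟩ ]
                   ≥ Δ(y_m, y_m) − ⟨w, ψ(x_m, y_m)⟩ = 0,

since y = y_m is always a feasible candidate, Δ(y_m, y_m) = 0 and ψ(x_m, y_m) = 0. A negative value therefore indicates that the oracle is not solving the loss-augmented argmax exactly (or that lossFn and featureFn are inconsistent), which is precisely what these tests are designed to surface.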
/dissolve-struct-examples/src/test/scala/ch/ethz/dalab/dissolve/diagnostics/StructLossSpec.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.diagnostics
2 |
3 | /**
4 | * @author torekond
5 | */
6 | class StructLossSpec extends UnitSpec {
7 |
8 | val NUM_ATTEMPTS = 100 // = # times each test case is attempted
9 |
10 | // A sample datapoint
11 | val lo = data(0)
12 | // Size of joint feature map
13 | val d = phi(lo.pattern, lo.label).size
14 | // No. of data examples
15 | val M = data.length
16 |
17 | "Δ(y, y)" should "= 0" in {
18 |
19 | val lossSeq: Seq[Double] = for (k <- 0 until NUM_ATTEMPTS) yield {
20 |
21 | // Choose a random example
22 | val m = scala.util.Random.nextInt(M)
23 | val lo = data(m)
24 | val x_m = lo.pattern
25 | val y_m = lo.label
26 |
27 | delta(y_m, y_m)
28 |
29 | }
30 |
31 | // This should be empty
32 | val uneqLossSeq: Seq[Double] = lossSeq.filter(_ != 0.0)
33 |
34 | assert(uneqLossSeq.length == 0,
35 | "%d / %d cases failed".format(uneqLossSeq.length, lossSeq.length))
36 | }
37 |
38 | "Δ(y, y')" should ">= 0" in {
39 |
40 | val lossSeq: Seq[Double] = for (k <- 0 until NUM_ATTEMPTS) yield {
41 |
42 | // Choose a random example
43 | val m = scala.util.Random.nextInt(M)
44 | val lo = data(m)
45 | val x_m = lo.pattern
46 | val degree = scala.util.Random.nextDouble()
47 | val y_m = perturb(lo.label, degree)
48 |
49 | delta(lo.label, y_m)
50 |
51 | }
52 |
53 | // This should be empty
54 | val uneqLossSeq: Seq[Double] = lossSeq.filter(_ < 0.0)
55 |
56 | assert(uneqLossSeq.length == 0,
57 | "%d / %d cases failed".format(uneqLossSeq.length, lossSeq.length))
58 | }
59 |
60 | }
--------------------------------------------------------------------------------
/dissolve-struct-examples/src/test/scala/ch/ethz/dalab/dissolve/diagnostics/UnitSpec.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.diagnostics
2 |
3 | import java.nio.file.Paths
4 | import org.scalatest.FlatSpec
5 | import org.scalatest.Inside
6 | import org.scalatest.Inspectors
7 | import org.scalatest.Matchers
8 | import org.scalatest.OptionValues
9 | import breeze.linalg.DenseVector
10 | import breeze.linalg.Matrix
11 | import breeze.linalg.Vector
12 | import ch.ethz.dalab.dissolve.classification.StructSVMModel
13 | import ch.ethz.dalab.dissolve.examples.chain.ChainDemo
14 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSeg
15 | import ch.ethz.dalab.dissolve.examples.imageseg.ImageSegUtils
16 | import ch.ethz.dalab.dissolve.examples.imageseg.QuantizedImage
17 | import ch.ethz.dalab.dissolve.examples.imageseg.QuantizedLabel
18 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
19 | import ch.ethz.dalab.dissolve.regression.LabeledObject
20 | import breeze.linalg.max
21 |
22 | object ImageTestAdapter {
23 | type X = QuantizedImage
24 | type Y = QuantizedLabel
25 |
26 | /**
27 | * Dissolve Functions
28 | */
29 | val dissolveFunctions: DissolveFunctions[X, Y] = ImageSeg
30 | /**
31 | * Some Data
32 | */
33 | val data = {
34 | val dataDir = "../data/generated/msrc"
35 | val trainFilePath = Paths.get(dataDir, "Train.txt")
36 | val trainDataSeq = ImageSegUtils.loadData(dataDir, trainFilePath, limit = 50)
37 |
38 | trainDataSeq
39 | }
40 | /**
41 | * A dummy model
42 | */
43 | val lo = data(0)
44 | val numd = ImageSeg.featureFn(lo.pattern, lo.label).size
45 | val model: StructSVMModel[X, Y] =
46 | new StructSVMModel[X, Y](DenseVector.zeros(numd), 0.0,
47 | DenseVector.zeros(numd), dissolveFunctions, 1)
48 |
49 | def perturb(y: Y, degree: Double = 0.1): Y = {
50 | val d = y.labels.size
51 | val numSwaps = max(1, (degree * d).toInt)
52 |
53 | for (swapNo <- 0 until numSwaps) {
54 | // Swap two random values in y
55 | val (i, j) = (scala.util.Random.nextInt(d), scala.util.Random.nextInt(d))
56 | val temp = y.labels(i)
57 | y.labels(i) = y.labels(j)
58 | y.labels(j) = temp
59 | }
60 |
61 | y
62 |
63 | }
64 | }
65 |
66 | object ChainTestAdapter {
67 | type X = Matrix[Double]
68 | type Y = Vector[Double]
69 |
70 | /**
71 | * Dissolve Functions
72 | */
73 | val dissolveFunctions: DissolveFunctions[X, Y] = ChainDemo
74 | /**
75 | * Some Data
76 | */
77 | val data = {
78 | val dataDir = "../data/generated"
79 | val trainDataSeq: Vector[LabeledObject[Matrix[Double], Vector[Double]]] =
80 | ChainDemo.loadData(dataDir + "/patterns_train.csv",
81 | dataDir + "/labels_train.csv",
82 | dataDir + "/folds_train.csv")
83 |
84 | trainDataSeq.toArray
85 | }
86 | /**
87 | * A dummy model
88 | */
89 | val lo = data(0)
90 | val numd = ChainDemo.featureFn(lo.pattern, lo.label).size
91 | val model: StructSVMModel[X, Y] =
92 | new StructSVMModel[X, Y](DenseVector.zeros(numd), 0.0,
93 | DenseVector.zeros(numd), dissolveFunctions, 1)
94 |
95 | /**
96 | * Perturb
97 | * Return a compatible perturbed Y
98 | * The higher the degree, the more perturbed y is
99 | *
100 | * This function perturbs roughly a `degree` fraction of the values by swapping them in place
101 | */
102 | def perturb(y: Y, degree: Double = 0.1): Y = {
103 | val d = y.size
104 | val numSwaps = max(1, (degree * d).toInt)
105 |
106 | for (swapNo <- 0 until numSwaps) {
107 | // Swap two random values in y
108 | val (i, j) = (scala.util.Random.nextInt(d), scala.util.Random.nextInt(d))
109 | val temp = y(i)
110 | y(i) = y(j)
111 | y(j) = temp
112 | }
113 |
114 | y
115 |
116 | }
117 | }
118 |
119 | /**
120 | * @author torekond
121 | */
122 | abstract class UnitSpec extends FlatSpec with Matchers with OptionValues with Inside with Inspectors {
123 |
124 | val DissolveAdapter = ImageTestAdapter
125 |
126 | type X = DissolveAdapter.X
127 | type Y = DissolveAdapter.Y
128 |
129 | val dissolveFunctions = DissolveAdapter.dissolveFunctions
130 | val data = DissolveAdapter.data
131 | val model = DissolveAdapter.model
132 |
133 | /**
134 | * Helper functions
135 | */
136 | def perturb = DissolveAdapter.perturb _
137 |
138 | // Joint Feature Map
139 | def phi = dissolveFunctions.featureFn _
140 | def delta = dissolveFunctions.lossFn _
141 | def maxoracle = dissolveFunctions.oracleFn _
142 | def predict = dissolveFunctions.predictFn _
143 |
144 | def psi(lo: LabeledObject[X, Y], ymap: Y) =
145 | phi(lo.pattern, lo.label) - phi(lo.pattern, ymap)
146 |
147 | def F(x: X, y: Y, w: Vector[Double]) =
148 | w dot phi(x, y)
149 | def deltaF(lo: LabeledObject[X, Y], ystar: Y, w: Vector[Double]) =
150 | w dot psi(lo, ystar)
151 |
152 | }
--------------------------------------------------------------------------------
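Note on UnitSpec above: the diagnostic suite is bound to one problem at a time through the adapter object. To run the same FeatureFn/Oracle/StructLoss specs against the chain model instead of image segmentation, only the binding needs to change:

    // In UnitSpec
    val DissolveAdapter = ChainTestAdapter // instead of ImageTestAdapter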
/dissolve-struct-lib/.gitignore:
--------------------------------------------------------------------------------
1 | /bin/
2 |
--------------------------------------------------------------------------------
/dissolve-struct-lib/build.sbt:
--------------------------------------------------------------------------------
1 | name := "DissolveStruct"
2 |
3 | organization := "ch.ethz.dalab"
4 |
5 | version := "0.1-SNAPSHOT"
6 |
7 | scalaVersion := "2.10.4"
8 |
9 | libraryDependencies += "org.apache.spark" %% "spark-core" % "1.4.1"
10 |
11 | libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.4.1"
12 |
13 | libraryDependencies += "org.scalanlp" %% "breeze" % "0.11.1"
14 |
15 | libraryDependencies += "org.scalanlp" %% "breeze-natives" % "0.11.1"
16 |
17 | libraryDependencies += "org.scalatest" % "scalatest_2.10" % "2.0" % "test"
18 |
19 | libraryDependencies += "com.github.scopt" %% "scopt" % "3.3.0"
20 |
21 | resolvers += Resolver.sonatypeRepo("public")
22 |
23 | mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) =>
24 | {
25 | case PathList("javax", "servlet", xs @ _*) => MergeStrategy.first
26 | case PathList(ps @ _*) if ps.last endsWith ".html" => MergeStrategy.first
27 | case "application.conf" => MergeStrategy.concat
28 | case "reference.conf" => MergeStrategy.concat
29 | case "log4j.properties" => MergeStrategy.discard
30 | case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
31 | case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
32 | case _ => MergeStrategy.first
33 | }
34 | }
35 |
36 | test in assembly := {}
37 |
--------------------------------------------------------------------------------
/dissolve-struct-lib/lib/jython.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dalab/dissolve-struct/67f37377b74c32cf05d8f43a0e3658a10864f9bf/dissolve-struct-lib/lib/jython.jar
--------------------------------------------------------------------------------
/dissolve-struct-lib/project/assembly.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0")
2 |
--------------------------------------------------------------------------------
/dissolve-struct-lib/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "3.0.0")
2 |
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/BinarySVMWithDBCFW.scala:
--------------------------------------------------------------------------------
1 |
2 | /**
3 | *
4 | */
5 | package ch.ethz.dalab.dissolve.classification
6 |
7 | import java.io.FileWriter
8 | import org.apache.spark.mllib.regression.LabeledPoint
9 | import org.apache.spark.rdd.RDD
10 | import breeze.linalg.DenseVector
11 | import breeze.linalg.SparseVector
12 | import breeze.linalg.Vector
13 | import ch.ethz.dalab.dissolve.optimization.DBCFWSolverTuned
14 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
15 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
16 | import ch.ethz.dalab.dissolve.optimization.SolverUtils
17 | import ch.ethz.dalab.dissolve.regression.LabeledObject
18 | import breeze.linalg.VectorBuilder
19 | import scala.collection.mutable.HashMap
20 | import org.apache.spark.rdd.PairRDDFunctions
21 |
22 | /**
23 | * @author tribhu
24 | *
25 | */
26 | object BinarySVMWithDBCFW extends DissolveFunctions[Vector[Double], Double] {
27 |
28 | val labelToWeight = HashMap[Double, Double]()
29 |
30 | override def classWeights(label: Double): Double = {
31 | labelToWeight.get(label).getOrElse(3.0)
32 | }
33 |
34 | def generateClassWeights(data: RDD[LabeledPoint]): Unit = {
35 | val labels: Array[Double] = data.map { x => x.label }.distinct().collect()
36 |
37 | val classOccur: PairRDDFunctions[Double, Double] = data.map(x => (x.label, 1.0))
38 | val labelOccur: PairRDDFunctions[Double, Double] = classOccur.reduceByKey((x, y) => x + y)
39 | val labelWeight: PairRDDFunctions[Double, Double] = labelOccur.mapValues { x => 1 / x }
40 |
41 | val weightSum: Double = labelWeight.values.sum()
42 | val nClasses: Int = 2
43 | val scaleValue: Double = nClasses / weightSum
44 |
45 | for ((label, weight) <- labelWeight.collectAsMap()) {
46 | labelToWeight.put(label, scaleValue * weight)
47 | }
48 | }
49 |
50 | /**
51 | * Feature function
52 | *
53 | * Analogous to phi(y) in (2)
54 | * Returns y_i * x_i
55 | *
56 | */
57 | def featureFn(x: Vector[Double], y: Double): Vector[Double] = {
58 | x * y
59 | }
60 |
61 | /**
62 | * Loss function
63 | *
64 | * Returns 0 if yTruth == yPredict, 1 otherwise
65 | * (the standard 0-1 loss)
66 | */
67 | def lossFn(yTruth: Double, yPredict: Double): Double =
68 | if (yTruth == yPredict)
69 | 0.0
70 | else
71 | 1.0
72 |
73 | /**
74 | * Maximization Oracle
75 | *
76 | * Want: argmax_y [ L(y_i, y) + <w, phi(x_i, y)> ]
77 | * This returns the most violating (Loss-augmented) label.
78 | */
79 | override def oracleFn(model: StructSVMModel[Vector[Double], Double], xi: Vector[Double], yi: Double): Double = {
80 |
81 | val weights = model.getWeights()
82 |
83 | var score_neg1 = weights dot featureFn(xi, -1.0)
84 | var score_pos1 = weights dot featureFn(xi, 1.0)
85 |
86 | // Loss augment the scores
87 | score_neg1 += 1.0
88 | score_pos1 += 1.0
89 |
90 | if (yi == -1.0)
91 | score_neg1 -= 1.0
92 | else if (yi == 1.0)
93 | score_pos1 -= 1.0
94 | else
95 | throw new IllegalArgumentException("yi not in {-1, +1}, yi = " + yi)
96 |
97 | if (score_neg1 > score_pos1)
98 | -1.0
99 | else
100 | 1.0
101 | }
102 |
103 | /**
104 | * Prediction function
105 | */
106 | def predictFn(model: StructSVMModel[Vector[Double], Double], xi: Vector[Double]): Double = {
107 |
108 | val weights = model.getWeights()
109 |
110 | val score_neg1 = weights dot featureFn(xi, -1.0)
111 | val score_pos1 = weights dot featureFn(xi, 1.0)
112 |
113 | if (score_neg1 > score_pos1)
114 | -1.0
115 | else
116 | +1.0
117 |
118 | }
119 |
120 | /**
121 | * Classifying with in-built functions
122 | */
123 | def train(
124 | data: RDD[LabeledPoint],
125 | solverOptions: SolverOptions[Vector[Double], Double]): StructSVMModel[Vector[Double], Double] = {
126 |
127 | train(data, this, solverOptions)
128 |
129 | }
130 |
131 | /**
132 | * Classifying with user-submitted functions
133 | */
134 | def train(
135 | data: RDD[LabeledPoint],
136 | dissolveFunctions: DissolveFunctions[Vector[Double], Double],
137 | solverOptions: SolverOptions[Vector[Double], Double]): StructSVMModel[Vector[Double], Double] = {
138 |
139 | if (solverOptions.classWeights) {
140 | generateClassWeights(data)
141 | }
142 |
143 | // Convert the RDD[LabeledPoint] to RDD[LabeledObject]
144 | val objectifiedData: RDD[LabeledObject[Vector[Double], Double]] =
145 | data.map {
146 | case x: LabeledPoint =>
147 | new LabeledObject[Vector[Double], Double](x.label,
148 | if (solverOptions.sparse) {
149 | val features: Vector[Double] = x.features match {
150 | case features: org.apache.spark.mllib.linalg.SparseVector =>
151 | val builder: VectorBuilder[Double] = new VectorBuilder(features.indices, features.values, features.indices.length, x.features.size)
152 | builder.toSparseVector
153 | case _ => SparseVector(x.features.toArray)
154 | }
155 | features
156 | } else
157 | DenseVector(x.features.toArray))
158 | }
159 |
160 | val repartData =
161 | if (solverOptions.enableManualPartitionSize)
162 | objectifiedData.repartition(solverOptions.NUM_PART)
163 | else
164 | objectifiedData
165 |
166 | println("Running BinarySVMWithDBCFW solver")
167 | println(solverOptions)
168 |
169 | val (trainedModel, debugInfo) = new DBCFWSolverTuned[Vector[Double], Double](
170 | repartData,
171 | dissolveFunctions,
172 | solverOptions,
173 | miniBatchEnabled = false).optimize()
174 |
175 | // Dump debug information into a file
176 | val fw = new FileWriter(solverOptions.debugInfoPath)
177 | // Write the current parameters being used
178 | fw.write(solverOptions.toString())
179 | fw.write("\n")
180 |
181 | // Write spark-specific parameters
182 | fw.write(SolverUtils.getSparkConfString(data.context.getConf))
183 | fw.write("\n")
184 |
185 | // Write values noted from the run
186 | fw.write(debugInfo)
187 | fw.close()
188 |
189 | println(debugInfo)
190 |
191 | trainedModel
192 |
193 | }
194 |
195 | }
--------------------------------------------------------------------------------
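Note on the oracle in BinarySVMWithDBCFW above: a small hand-computed case (Breeze only; the numbers are made up for illustration) shows how loss augmentation can flip the argmax away from the true label:

    import breeze.linalg.DenseVector

    val w  = DenseVector(0.5, -1.0)
    val x  = DenseVector(2.0, 3.0)
    val yi = 1.0

    // Scores <w, y * x>, loss-augmented by +1 for the label that differs from yi
    val scorePos = (w dot (x * 1.0))  + (if (yi == 1.0) 0.0 else 1.0)  // -2.0
    val scoreNeg = (w dot (x * -1.0)) + (if (yi == -1.0) 0.0 else 1.0) //  3.0

    val mostViolating = if (scoreNeg > scorePos) -1.0 else 1.0         // -1.0

Here the true label is +1, but the loss-augmented score of -1 is higher, so the oracle returns -1 as the most violating label.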
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/BinarySVMWithSSG.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.classification
2 |
3 | import java.io.FileWriter
4 |
5 | import org.apache.spark.mllib.regression.LabeledPoint
6 | import org.apache.spark.rdd.RDD
7 |
8 | import breeze.linalg.DenseVector
9 | import breeze.linalg.SparseVector
10 | import breeze.linalg.Vector
11 | import ch.ethz.dalab.dissolve.optimization.SSGSolver
12 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
13 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
14 | import ch.ethz.dalab.dissolve.optimization.SolverUtils
15 | import ch.ethz.dalab.dissolve.regression.LabeledObject
16 |
17 | /**
18 | * @author thijs
19 | * Adapted from BinarySVMWithDBCFW
20 | *
21 | */
22 | object BinarySVMWithSSG extends DissolveFunctions[Vector[Double], Double] {
23 |
24 | /**
25 | * Feature function
26 | *
27 | * Analogous to phi(y) in (2)
28 | * Returns y_i * x_i
29 | *
30 | */
31 | def featureFn(x: Vector[Double], y: Double): Vector[Double] = {
32 | x * y
33 | }
34 |
35 | /**
36 | * Loss function
37 | *
38 | * Returns 0 if yTruth == yPredict, 1 otherwise
39 | * (the standard 0-1 loss)
40 | */
41 | def lossFn(yTruth: Double, yPredict: Double): Double =
42 | if (yTruth == yPredict)
43 | 0.0
44 | else
45 | 1.0
46 |
47 | /**
48 | * Maximization Oracle
49 | *
50 | * Want: argmax_y [ L(y_i, y) + <w, phi(x_i, y)> ]
51 | * This returns the most violating (Loss-augmented) label.
52 | */
53 | override def oracleFn(model: StructSVMModel[Vector[Double], Double], xi: Vector[Double], yi: Double): Double = {
54 |
55 | val weights = model.getWeights()
56 |
57 | var score_neg1 = weights dot featureFn(xi, -1.0)
58 | var score_pos1 = weights dot featureFn(xi, 1.0)
59 |
60 | // Loss augment the scores
61 | score_neg1 += 1.0
62 | score_pos1 += 1.0
63 |
64 | if (yi == -1.0)
65 | score_neg1 -= 1.0
66 | else if (yi == 1.0)
67 | score_pos1 -= 1.0
68 | else
69 | throw new IllegalArgumentException("yi not in {-1, +1}, yi = " + yi)
70 |
71 | if (score_neg1 > score_pos1)
72 | -1.0
73 | else
74 | 1.0
75 | }
76 |
77 | /**
78 | * Prediction function
79 | */
80 | def predictFn(model: StructSVMModel[Vector[Double], Double], xi: Vector[Double]): Double = {
81 |
82 | val weights = model.getWeights()
83 |
84 | val score_neg1 = weights dot featureFn(xi, -1.0)
85 | val score_pos1 = weights dot featureFn(xi, 1.0)
86 |
87 | if (score_neg1 > score_pos1)
88 | -1.0
89 | else
90 | +1.0
91 |
92 | }
93 |
94 | /**
95 | * Classifying with in-built functions
96 | */
97 | def train(
98 | data: Seq[LabeledPoint],
99 | solverOptions: SolverOptions[Vector[Double], Double]): StructSVMModel[Vector[Double], Double] = {
100 |
101 | train(data,this,solverOptions)
102 |
103 | }
104 |
105 | /**
106 | * Classifying with user-submitted functions
107 | */
108 | def train(
109 | data: Seq[LabeledPoint],
110 | dissolveFunctions: DissolveFunctions[Vector[Double], Double],
111 | solverOptions: SolverOptions[Vector[Double], Double]): StructSVMModel[Vector[Double], Double] = {
112 |
113 | // Convert the RDD[LabeledPoint] to RDD[LabeledObject]
114 | val objectifiedData: Seq[LabeledObject[Vector[Double], Double]] =
115 | data.map {
116 | case x: LabeledPoint =>
117 | new LabeledObject[Vector[Double], Double](x.label,
118 | if (solverOptions.sparse)
119 | SparseVector(x.features.toArray)
120 | else
121 | DenseVector(x.features.toArray))
122 | }
123 |
124 | println("Running BinarySVMWithSSGsolver")
125 | println(solverOptions)
126 |
127 | val trainedModel = new SSGSolver[Vector[Double], Double](
128 | objectifiedData,
129 | dissolveFunctions,
130 | solverOptions).optimize()
131 |
132 | // Dump debug information into a file
133 | val fw = new FileWriter(solverOptions.debugInfoPath)
134 | // Write the current parameters being used
135 | fw.write(solverOptions.toString())
136 | fw.write("\n")
137 |
138 |
139 | // Write values noted from the run
140 | fw.close()
141 |
142 |
143 | trainedModel
144 |
145 | }
146 |
147 | }
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/ClassificationUtils.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.classification
2 |
3 | import org.apache.spark.rdd.RDD
4 | import org.apache.spark.rdd.PairRDDFunctions
5 | import ch.ethz.dalab.dissolve.regression.LabeledObject
6 | import scala.collection.mutable.HashMap
7 | import ch.ethz.dalab.dissolve.regression.LabeledObject
8 | import scala.reflect.ClassTag
9 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
10 | import ch.ethz.dalab.dissolve.regression.LabeledObject
11 | import scala.collection.mutable.MutableList
12 | import ch.ethz.dalab.dissolve.regression.LabeledObject
13 | import org.apache.spark.mllib.regression.LabeledPoint
14 |
15 | object ClassificationUtils {
16 |
17 | /**
18 | * Generates class weights. If classWeights is false, every class gets the default weight of 1.0. Otherwise, if the user supplied a
19 | * custom weight map, those weights are used; if not, the inverse class frequency (rescaled to sum to the number of classes) is used.
20 | */
21 | def generateClassWeights[X, Y: ClassTag](data: RDD[LabeledObject[X, Y]], classWeights: Boolean = true, customWeights: Option[HashMap[Y,Double]] = None): HashMap[Y, Double] = {
22 | val map = HashMap[Y, Double]()
23 | val labels: Array[Y] = data.map { x: LabeledObject[X, Y] => x.label }.distinct().collect()
24 | if (classWeights) {
25 | if (customWeights.getOrElse(null) == null) {
26 | //inverse class frequency as weight
27 | val classOccur: PairRDDFunctions[Y, Double] = data.map(x => (x.label, 1.0))
28 | val labelOccur: PairRDDFunctions[Y, Double] = classOccur.reduceByKey((x, y) => x + y)
29 | val labelWeight: PairRDDFunctions[Y, Double] = labelOccur.mapValues { x => 1 / x }
30 |
31 | val weightSum: Double = labelWeight.values.sum()
32 | val nClasses: Int = labels.length
33 | val scaleValue: Double = nClasses / weightSum
34 |
35 | var sum: Double = 0.0
36 | for ((label, weight) <- labelWeight.collectAsMap()) {
37 | val clWeight = scaleValue * weight
38 | sum += clWeight
39 | map.put(label, clWeight)
40 | }
41 |
42 | assert(math.abs(sum - nClasses) < 1e-6, "Class weights should sum to the number of classes")
43 | } else {
44 | //use custom weights
45 | assert(labels.length == customWeights.get.size)
46 | for (label <- labels) {
47 | map.put(label, customWeights.get(label))
48 | }
49 | }
50 | } else {
51 | // default weight of 1.0
52 | for (label <- labels) {
53 | map.put(label, 1.0)
54 | }
55 | }
56 | map
57 | }
58 |
59 | def resample[X,Y:ClassTag](data: RDD[LabeledObject[X, Y]],nSamples:HashMap[Y,Int],nSlices:Int): RDD[LabeledObject[X, Y]] = {
60 | val buckets: HashMap[Y, RDD[LabeledObject[X, Y]]] = HashMap()
61 | val newData = MutableList[LabeledObject[X, Y]]()
62 |
63 | val labels: Array[Y] = data.map { x => x.label }.distinct().collect()
64 |
65 | labels.foreach { x => buckets.put(x, data.filter { point => point.label == x }) }
66 |
67 | for (cls <- buckets.keySet) {
68 | val sampledData = buckets.get(cls).get.takeSample(true, nSamples.get(cls).get)
69 | for (x: LabeledObject[X, Y] <- sampledData) {
70 | newData.+=(x)
71 | }
72 | }
73 | data.context.parallelize(newData, nSlices)
74 | }
75 |
76 | def resample(data: RDD[LabeledPoint],nSamples:HashMap[Double,Int],nSlices:Int): RDD[LabeledPoint] = {
77 | val buckets: HashMap[Double, RDD[LabeledPoint]] = HashMap()
78 | val newData = MutableList[LabeledPoint]()
79 |
80 | val labels: Array[Double] = data.map { x => x.label }.distinct().collect()
81 |
82 | labels.foreach { x => buckets.put(x, data.filter { point => point.label == x }) }
83 |
84 | for (cls <- buckets.keySet) {
85 | val sampledData = buckets.get(cls).get.takeSample(true, nSamples.get(cls).get)
86 | for (x: LabeledPoint <- sampledData) {
87 | newData.+=(x)
88 | }
89 | }
90 | data.context.parallelize(newData, nSlices)
91 | }
92 |
93 | }
--------------------------------------------------------------------------------
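Note on generateClassWeights above: to make the inverse-frequency weighting concrete, suppose two classes occur 90 and 10 times. The raw weights are 1/90 ≈ 0.0111 and 1/10 = 0.1, which sum to ≈ 0.1111. The scale factor is nClasses / sum = 2 / 0.1111 = 18, giving final weights 0.2 and 1.8: rarer classes are weighted up, and the weights sum to the number of classes (2).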
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/MultiClassSVMWithDBCFW.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.classification
2 |
3 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
4 | import org.apache.spark.rdd.RDD
5 | import java.io.FileWriter
6 | import ch.ethz.dalab.dissolve.regression.LabeledObject
7 | import org.apache.spark.mllib.regression.LabeledPoint
8 | import breeze.linalg._
9 | import ch.ethz.dalab.dissolve.optimization.SolverUtils
10 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
11 | import ch.ethz.dalab.dissolve.optimization.DBCFWSolverTuned
12 | import scala.collection.mutable.HashMap
13 | import org.apache.spark.rdd.PairRDDFunctions
14 | import ch.ethz.dalab.dissolve.optimization.SSGSolver
15 | import ch.ethz.dalab.dissolve.optimization.SSGSolver
16 | import ch.ethz.dalab.dissolve.optimization.UseDBCFWSolver
17 | import ch.ethz.dalab.dissolve.optimization.UseSSGSolver
18 |
19 | case class MultiClassLabel(label: Double, numClasses: Int)
20 |
21 | object MultiClassSVMWithDBCFW extends DissolveFunctions[Vector[Double], MultiClassLabel] {
22 |
23 | var labelToWeight = HashMap[MultiClassLabel, Double]()
24 |
25 |
26 | override def classWeights(label: MultiClassLabel): Double = {
27 | labelToWeight.get(label).getOrElse(1.0)
28 | }
29 |
30 | /**
31 | * Feature function
32 | *
33 | * Analogous to phi(y) in (2)
34 | * Returns a vector of length x.size * numClasses, with x copied into the block indexed by y.label
35 | *
36 | */
37 | def featureFn(x: Vector[Double], y: MultiClassLabel): Vector[Double] = {
38 | assert(y.label.toInt < y.numClasses,
39 | "numClasses = %d. Found y_i.label = %d"
40 | .format(y.numClasses, y.label.toInt))
41 |
42 | val featureVector = Vector.zeros[Double](x.size * y.numClasses)
43 | val numDims = x.size
44 |
45 | // Populate the featureVector block-wise: [ 0 | ... | x | ... | 0 ], with x in the block for class y.label.
46 | val startIdx = y.label.toInt * numDims
47 | val endIdx = startIdx + numDims
48 |
49 | featureVector(startIdx until endIdx) := x
50 |
51 | featureVector
52 | }
53 |
54 | /**
55 | * Loss function
56 | *
57 | * Returns 0 if yTruth == yPredict, 1 otherwise
58 | * (the standard 0-1 loss)
59 | */
60 | def lossFn(yTruth: MultiClassLabel, yPredict: MultiClassLabel): Double =
61 | if (yTruth.label == yPredict.label)
62 | 0.0
63 | else
64 | 1.0
65 |
66 | /**
67 | * Maximization Oracle
68 | *
69 | * Want: argmax_y [ L(y_i, y) + <w, phi(x_i, y)> ]
70 | * This returns the most violating (Loss-augmented) label.
71 | */
72 | override def oracleFn(model: StructSVMModel[Vector[Double], MultiClassLabel], xi: Vector[Double], yi: MultiClassLabel): MultiClassLabel = {
73 |
74 | val weights = model.getWeights()
75 | val numClasses = yi.numClasses
76 |
77 | // Obtain a list of scores for each class
78 | val mostViolatedConstraint: (Double, Double) =
79 | (0 until numClasses).map {
80 | case cl =>
81 | (cl, weights dot featureFn(xi, MultiClassLabel(cl, numClasses)))
82 | }.map {
83 | case (cl, score) =>
84 | (cl.toDouble, score + 1.0)
85 | }.map { // Loss-augment the scores
86 | case (cl, score) =>
87 | if (yi.label == cl)
88 | (cl, score - 1.0)
89 | else
90 | (cl, score)
91 | }.maxBy { // Obtain the class with the maximum value
92 | case (cl, score) => score
93 | }
94 |
95 | MultiClassLabel(mostViolatedConstraint._1, numClasses)
96 | }
97 |
98 | /**
99 | * Prediction function
100 | */
101 | def predictFn(model: StructSVMModel[Vector[Double], MultiClassLabel], xi: Vector[Double]): MultiClassLabel = {
102 |
103 | val weights = model.getWeights()
104 | val numClasses = model.numClasses
105 |
106 | assert(numClasses > 1)
107 |
108 | val prediction =
109 | (0 until numClasses).map {
110 | case cl =>
111 | (cl.toDouble, weights dot featureFn(xi, MultiClassLabel(cl, numClasses)))
112 | }.maxBy { // Obtain the class with the maximum value
113 | case (cl, score) => score
114 | }
115 |
116 | MultiClassLabel(prediction._1, numClasses)
117 |
118 | }
119 |
120 | /**
121 | * Classifying with in-built functions
122 | *
123 | * data needs to be 0-indexed
124 | */
125 | def train(
126 | data: RDD[LabeledPoint],
127 | numClasses: Int,
128 | solverOptions: SolverOptions[Vector[Double], MultiClassLabel],
129 | customWeights:Option[HashMap[MultiClassLabel,Double]]=None): StructSVMModel[Vector[Double], MultiClassLabel] = {
130 |
131 | solverOptions.numClasses = numClasses
132 |
133 | // Convert the RDD[LabeledPoint] to RDD[LabeledObject]
134 | val objectifiedData: RDD[LabeledObject[Vector[Double], MultiClassLabel]] =
135 | data.map {
136 | case x: LabeledPoint =>
137 | val features: Vector[Double] = x.features match {
138 | case features: org.apache.spark.mllib.linalg.SparseVector =>
139 | val builder: VectorBuilder[Double] = new VectorBuilder(features.indices, features.values, features.indices.length, x.features.size)
140 | builder.toSparseVector
141 | case _ => SparseVector(x.features.toArray)
142 | }
143 | new LabeledObject[Vector[Double], MultiClassLabel](MultiClassLabel(x.label, numClasses), features)
144 | }
145 |
146 | labelToWeight = ClassificationUtils.generateClassWeights(objectifiedData,solverOptions.classWeights,customWeights)
147 |
148 | val repartData =
149 | if (solverOptions.enableManualPartitionSize)
150 | objectifiedData.repartition(solverOptions.NUM_PART)
151 | else
152 | objectifiedData
153 |
154 | println(solverOptions)
155 |
156 |
157 | val (trainedModel,debugInfo) = solverOptions.solver match {
158 | case UseDBCFWSolver => new DBCFWSolverTuned[Vector[Double], MultiClassLabel](
159 | repartData,
160 | this,
161 | solverOptions,
162 | miniBatchEnabled = false).optimize()
163 | case UseSSGSolver => (new SSGSolver[Vector[Double], MultiClassLabel](
164 | repartData.collect(),
165 | this,
166 | solverOptions
167 | ).optimize(),"")
168 | }
169 |
170 | println(debugInfo)
171 |
172 | // Dump debug information into a file
173 | val fw = new FileWriter(solverOptions.debugInfoPath)
174 | // Write the current parameters being used
175 | fw.write(solverOptions.toString())
176 | fw.write("\n")
177 |
178 | // Write spark-specific parameters
179 | fw.write(SolverUtils.getSparkConfString(data.context.getConf))
180 | fw.write("\n")
181 |
182 | // Write values noted from the run
183 | fw.write(debugInfo)
184 | fw.close()
185 |
186 | trainedModel
187 |
188 | }
189 |
190 | /**
191 | * Classifying with user-submitted functions
192 | */
193 | def train(
194 | data: RDD[LabeledPoint],
195 | dissolveFunctions: DissolveFunctions[Vector[Double], MultiClassLabel],
196 | solverOptions: SolverOptions[Vector[Double], MultiClassLabel]): StructSVMModel[Vector[Double], MultiClassLabel] = {
197 |
198 | val numClasses = solverOptions.numClasses
199 | assert(numClasses > 1)
200 |
201 | val minlabel = data.map(_.label).min()
202 | val maxlabel = data.map(_.label).max()
203 | assert(minlabel == 0, "Label classes need to be 0-indexed")
204 | assert(maxlabel - minlabel + 1 == numClasses,
205 | "Number of classes in data do not tally with passed argument")
206 |
207 | // Convert the RDD[LabeledPoint] to RDD[LabeledObject]
208 | val objectifiedData: RDD[LabeledObject[Vector[Double], MultiClassLabel]] =
209 | data.map {
210 | case x: LabeledPoint =>
211 | new LabeledObject[Vector[Double], MultiClassLabel](MultiClassLabel(x.label, numClasses),
212 | if (solverOptions.sparse) {
213 | val features: Vector[Double] = x.features match {
214 | case features: org.apache.spark.mllib.linalg.SparseVector =>
215 | val builder: VectorBuilder[Double] = new VectorBuilder(features.indices, features.values, features.indices.length, x.features.size)
216 | builder.toSparseVector
217 | case _ => SparseVector(x.features.toArray)
218 | }
219 | features
220 | } else
221 | Vector(x.features.toArray))
222 | }
223 |
224 | val repartData =
225 | if (solverOptions.enableManualPartitionSize)
226 | objectifiedData.repartition(solverOptions.NUM_PART)
227 | else
228 | objectifiedData
229 |
230 | println(solverOptions)
231 |
232 | //choose optimizer
233 | val (trainedModel,debugInfo) = solverOptions.solver match {
234 | case UseDBCFWSolver => new DBCFWSolverTuned[Vector[Double], MultiClassLabel](
235 | repartData,
236 | dissolveFunctions,
237 | solverOptions,
238 | miniBatchEnabled = false).optimize()
239 | case UseSSGSolver => (new SSGSolver[Vector[Double], MultiClassLabel](
240 | repartData.collect(),
241 | dissolveFunctions,
242 | solverOptions
243 | ).optimize(),"")
244 | }
245 |
246 | // Dump debug information into a file
247 | val fw = new FileWriter(solverOptions.debugInfoPath)
248 | // Write the current parameters being used
249 | fw.write(solverOptions.toString())
250 | fw.write("\n")
251 |
252 | // Write spark-specific parameters
253 | fw.write(SolverUtils.getSparkConfString(data.context.getConf))
254 | fw.write("\n")
255 |
256 | // Write values noted from the run
257 | fw.write(debugInfo)
258 | fw.close()
259 |
260 | println(debugInfo)
261 |
262 | trainedModel
263 |
264 | }
265 | }
--------------------------------------------------------------------------------
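Note on the multi-class feature map above: with numClasses = 3 and a 2-dimensional input, featureFn lays out a 6-dimensional vector with x copied into the block of its class (values here are illustrative):

    import breeze.linalg.DenseVector
    import ch.ethz.dalab.dissolve.classification.{MultiClassLabel, MultiClassSVMWithDBCFW}

    val x   = DenseVector(7.0, 9.0)
    val phi = MultiClassSVMWithDBCFW.featureFn(x, MultiClassLabel(label = 1.0, numClasses = 3))
    // phi = [ 0.0, 0.0 | 7.0, 9.0 | 0.0, 0.0 ]
    //        class 0     class 1    class 2

The score <w, phi(x, y)> therefore reduces to the inner product of x with the per-class slice of w, which is what the oracle and predictFn iterate over.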
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMModel.scala:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 | package ch.ethz.dalab.dissolve.classification
5 |
6 | import breeze.linalg._
7 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
8 |
9 | /**
10 | * This is analogous to the model:
11 | * a) returned by Andrea Vedaldi's svm-struct-matlab
12 | * b) used by BCFWStruct. See (Lacoste-Julien, Jaggi, Schmidt, Pletscher; ICML 2013)
13 | *
14 | * @constructor Create a new StructSVM model
15 | * @param weights Primal variable. Corresponds to w in Algorithm 4
16 | * @param ell Corresponds to l in Algorithm 4
17 | * @param ellMat Corresponds to l_i in Algorithm 4
18 | * @param pred Prediction function
19 | */
20 | class StructSVMModel[X, Y](
21 | var weights: Vector[Double],
22 | var ell: Double,
23 | val ellMat: Vector[Double],
24 | val dissolveFunctions: DissolveFunctions[X, Y],
25 | val numClasses: Int) extends Serializable {
26 |
27 | def this(
28 | weights: Vector[Double],
29 | ell: Double,
30 | ellMat: Vector[Double],
31 | dissolveFunctions: DissolveFunctions[X, Y]) =
32 | this(weights, ell, ellMat, dissolveFunctions, -1)
33 |
34 |
35 | def getWeights(): Vector[Double] = {
36 | weights
37 | }
38 |
39 | def setWeights(newWeights: Vector[Double]) = {
40 | weights = newWeights
41 | }
42 |
43 | def getEll(): Double =
44 | ell
45 |
46 | def setEll(newEll: Double) =
47 | ell = newEll
48 |
49 | def predict(pattern: X): Y = {
50 | dissolveFunctions.predictFn(this, pattern)
51 | }
52 |
53 | override def clone(): StructSVMModel[X, Y] = {
54 | new StructSVMModel(this.weights.copy,
55 | ell,
56 | this.ellMat.copy,
57 | dissolveFunctions,
58 | numClasses)
59 | }
60 | }
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMWithBCFW.scala:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 | package ch.ethz.dalab.dissolve.classification
5 |
6 | import java.io.FileWriter
7 |
8 | import scala.reflect.ClassTag
9 |
10 | import ch.ethz.dalab.dissolve.optimization.BCFWSolver
11 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
12 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
13 | import ch.ethz.dalab.dissolve.regression.LabeledObject
14 |
15 | /**
16 | * Analogous to BCFWSolver
17 | *
18 | *
19 | */
20 | class StructSVMWithBCFW[X, Y](
21 | val data: Seq[LabeledObject[X, Y]],
22 | val dissolveFunctions: DissolveFunctions[X, Y],
23 | val solverOptions: SolverOptions[X, Y]) {
24 |
25 | def trainModel()(implicit m: ClassTag[Y]): StructSVMModel[X, Y] = {
26 | val (trainedModel, debugInfo) = new BCFWSolver(data,
27 | dissolveFunctions,
28 | solverOptions).optimize()
29 |
30 | // Dump debug information into a file
31 | val fw = new FileWriter(solverOptions.debugInfoPath)
32 | // Write the current parameters being used
33 | fw.write("# BCFW\n")
34 | fw.write(solverOptions.toString())
35 | fw.write("\n")
36 |
37 | // Write values noted from the run
38 | fw.write(debugInfo)
39 | fw.close()
40 |
41 | trainedModel
42 | }
43 |
44 | }
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMWithDBCFW.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.classification
2 |
3 | import ch.ethz.dalab.dissolve.regression.LabeledObject
4 | import breeze.linalg._
5 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
6 | import org.apache.spark.SparkContext
7 | import org.apache.spark.rdd.RDD
8 | import java.io.FileWriter
9 | import ch.ethz.dalab.dissolve.optimization.SolverUtils
10 | import scala.reflect.ClassTag
11 | import ch.ethz.dalab.dissolve.optimization.DBCFWSolverTuned
12 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
13 |
14 | class StructSVMWithDBCFW[X, Y](
15 | val data: RDD[LabeledObject[X, Y]],
16 | val dissolveFunctions: DissolveFunctions[X, Y],
17 | val solverOptions: SolverOptions[X, Y]) {
18 |
19 | def trainModel()(implicit m: ClassTag[Y]): StructSVMModel[X, Y] = {
20 | val (trainedModel, debugInfo) = new DBCFWSolverTuned[X, Y](
21 | data,
22 | dissolveFunctions,
23 | solverOptions,
24 | miniBatchEnabled = false).optimize()
25 |
26 | // Dump debug information into a file
27 | val fw = new FileWriter(solverOptions.debugInfoPath)
28 | // Write the current parameters being used
29 | fw.write(solverOptions.toString())
30 | fw.write("\n")
31 |
32 | // Write spark-specific parameters
33 | fw.write(SolverUtils.getSparkConfString(data.context.getConf))
34 | fw.write("\n")
35 |
36 | // Write values noted from the run
37 | fw.write(debugInfo)
38 | fw.close()
39 |
40 | print(debugInfo)
41 |
42 | // Return the trained model
43 | trainedModel
44 | }
45 | }
46 |
47 | object StructSVMWithDBCFW {
48 | def train[X, Y](data: RDD[LabeledObject[X, Y]],
49 | dissolveFunctions: DissolveFunctions[X, Y],
50 | solverOptions: SolverOptions[X, Y])(implicit m: ClassTag[Y]): StructSVMModel[X, Y] = {
51 | val (trainedModel, debugInfo) = new DBCFWSolverTuned[X, Y](
52 | data,
53 | dissolveFunctions,
54 | solverOptions,
55 | miniBatchEnabled = false).optimize()
56 |
57 | // Dump debug information into a file
58 | val fw = new FileWriter(solverOptions.debugInfoPath)
59 | // Write the current parameters being used
60 | fw.write(solverOptions.toString())
61 | fw.write("\n")
62 |
63 | // Write spark-specific parameters
64 | fw.write(SolverUtils.getSparkConfString(data.context.getConf))
65 | fw.write("\n")
66 |
67 | // Write values noted from the run
68 | fw.write(debugInfo)
69 | fw.close()
70 |
71 | print(debugInfo)
72 |
73 | // Return the trained model
74 | trainedModel
75 |
76 | }
77 | }
--------------------------------------------------------------------------------
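Note on StructSVMWithDBCFW above: a minimal sketch of how an application might drive it; the problem-specific DissolveFunctions instance, the data loading and the SparkContext setup are placeholders supplied by the user, not part of the library:

    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.RDD
    import ch.ethz.dalab.dissolve.classification.{StructSVMModel, StructSVMWithDBCFW}
    import ch.ethz.dalab.dissolve.optimization.{DissolveFunctions, SolverOptions}
    import ch.ethz.dalab.dissolve.regression.LabeledObject

    object DBCFWUsageSketch {
      // Hypothetical driver: trains a model on already-loaded data with default options
      def run[X](sc: SparkContext,
                 train: Seq[LabeledObject[X, Double]],
                 fns: DissolveFunctions[X, Double]): StructSVMModel[X, Double] = {
        val options = new SolverOptions[X, Double]()
        options.lambda = 0.01
        options.roundLimit = 25

        val trainRDD: RDD[LabeledObject[X, Double]] = sc.parallelize(train)
        StructSVMWithDBCFW.train(trainRDD, fns, options)
      }
    }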
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMWithMiniBatch.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.classification
2 |
3 | import scala.reflect.ClassTag
4 |
5 | import org.apache.spark.rdd.RDD
6 |
7 | import ch.ethz.dalab.dissolve.optimization.DBCFWSolverTuned
8 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
9 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
10 | import ch.ethz.dalab.dissolve.regression.LabeledObject
11 |
12 | class StructSVMWithMiniBatch[X, Y](
13 | val data: RDD[LabeledObject[X, Y]],
14 | val dissolveFunctions: DissolveFunctions[X, Y],
15 | val solverOptions: SolverOptions[X, Y]) {
16 |
17 | def trainModel()(implicit m: ClassTag[Y]): StructSVMModel[X, Y] =
18 | new DBCFWSolverTuned(
19 | data,
20 | dissolveFunctions,
21 | solverOptions,
22 | miniBatchEnabled = true).optimize()._1
23 | }
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/StructSVMWithSSG.scala:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 | package ch.ethz.dalab.dissolve.classification
5 |
6 | import scala.reflect.ClassTag
7 |
8 | import ch.ethz.dalab.dissolve.optimization.DissolveFunctions
9 | import ch.ethz.dalab.dissolve.optimization.SSGSolver
10 | import ch.ethz.dalab.dissolve.optimization.SolverOptions
11 | import ch.ethz.dalab.dissolve.regression.LabeledObject
12 |
13 | /**
14 | *
15 | */
16 | class StructSVMWithSSG[X, Y](
17 | val data: Seq[LabeledObject[X, Y]],
18 | val dissolveFunctions: DissolveFunctions[X, Y],
19 | val solverOptions: SolverOptions[X, Y]) {
20 |
21 | def trainModel()(implicit m: ClassTag[Y]): StructSVMModel[X, Y] =
22 | new SSGSolver(data,
23 | dissolveFunctions,
24 | solverOptions).optimize()
25 |
26 | }
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/classification/Types.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.classification
2 |
3 | import breeze.linalg.{Vector, DenseVector}
4 | import scala.collection.mutable.MutableList
5 |
6 | object Types {
7 |
8 | type Index = Int
9 | type Level = Int
10 | type PrimalInfo = Tuple2[Vector[Double], Double]
11 | type BoundedCacheList[Y] = MutableList[Y]
12 |
13 | }
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/optimization/DissolveFunctions.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.optimization
2 |
3 | import breeze.linalg.Vector
4 | import ch.ethz.dalab.dissolve.classification.StructSVMModel
5 |
6 | trait DissolveFunctions[X, Y] extends Serializable {
7 |
8 | def featureFn(x: X, y: Y): Vector[Double]
9 |
10 | def lossFn(yPredicted: Y, yTruth: Y): Double
11 |
12 | // Override either `oracleFn` or `oracleCandidateStream`
13 | def oracleFn(model: StructSVMModel[X, Y], x: X, y: Y): Y =
14 | oracleCandidateStream(model, x, y).head
15 |
16 | def oracleCandidateStream(model: StructSVMModel[X, Y], x: X, y: Y, initLevel: Int = 0): Stream[Y] =
17 | oracleFn(model, x, y) #:: Stream.empty
18 |
19 | def predictFn(model: StructSVMModel[X, Y], x: X): Y
20 |
21 | def classWeights(y:Y): Double = 1.0
22 |
23 | }
--------------------------------------------------------------------------------
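Note on the DissolveFunctions trait above: the functions a user implements are featureFn, lossFn, oracleFn (or oracleCandidateStream) and predictFn. A minimal sketch for plain ±1 binary labels; it mirrors BinarySVMWithDBCFW elsewhere in this repository and is shown only to illustrate the shape of an implementation:

    import breeze.linalg.Vector
    import ch.ethz.dalab.dissolve.classification.StructSVMModel
    import ch.ethz.dalab.dissolve.optimization.DissolveFunctions

    object ToyBinaryFunctions extends DissolveFunctions[Vector[Double], Double] {

      // Joint feature map: phi(x, y) = y * x
      def featureFn(x: Vector[Double], y: Double): Vector[Double] = x * y

      // Standard 0-1 loss
      def lossFn(yPredicted: Double, yTruth: Double): Double =
        if (yPredicted == yTruth) 0.0 else 1.0

      // Loss-augmented decoding over the two candidate labels
      override def oracleFn(model: StructSVMModel[Vector[Double], Double],
                            x: Vector[Double], y: Double): Double = {
        val w = model.getWeights()
        Seq(-1.0, 1.0).maxBy { cand =>
          lossFn(cand, y) + (w dot featureFn(x, cand))
        }
      }

      // Plain argmax prediction
      def predictFn(model: StructSVMModel[Vector[Double], Double],
                    x: Vector[Double]): Double = {
        val w = model.getWeights()
        Seq(-1.0, 1.0).maxBy(cand => w dot featureFn(x, cand))
      }
    }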
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/optimization/SSGSolver.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.optimization
2 |
3 | import java.io.File
4 | import java.io.PrintWriter
5 |
6 | import breeze.linalg._
7 | import breeze.linalg.DenseVector
8 | import breeze.linalg.Vector
9 | import breeze.linalg.csvwrite
10 | import breeze.numerics._
11 | import ch.ethz.dalab.dissolve.classification.StructSVMModel
12 | import ch.ethz.dalab.dissolve.regression.LabeledObject
13 |
14 | /**
15 | * Train a structured SVM using standard Stochastic (Sub)Gradient Descent (SGD).
16 | * The implementation here is single machine, not distributed.
17 | *
18 | * Input:
19 | * Each data point (x_i, y_i) is composed of:
20 | * x_i, the data example
21 | * y_i, the label
22 | *
23 | * @tparam X Type of the data examples
24 | * @tparam Y Type of the labels of each example
25 | */
26 | class SSGSolver[X, Y](
27 | val data: Seq[LabeledObject[X, Y]],
28 | val dissolveFunctions: DissolveFunctions[X, Y],
29 | val solverOptions: SolverOptions[X, Y]) {
30 |
31 | val roundLimit = solverOptions.roundLimit
32 | val lambda = solverOptions.lambda
33 | val debugOn: Boolean = solverOptions.debug
34 | val gamma0 = solverOptions.ssg_gamma0
35 |
36 | val maxOracle = dissolveFunctions.oracleFn _
37 | val phi = dissolveFunctions.featureFn _
38 | val lossFn = dissolveFunctions.lossFn _
39 | val cWeight = dissolveFunctions.classWeights _
40 | // Number of dimensions of \phi(x, y)
41 | val ndims: Int = phi(data(0).pattern, data(0).label).size
42 |
43 | // Filenames
44 | val lossWriterFileName = "data/debug/ssg-loss.csv"
45 |
46 | /**
47 | * SSG optimizer
48 | */
49 | def optimize(): StructSVMModel[X, Y] = {
50 |
51 | var k: Integer = 0
52 | val n: Int = data.length
53 | val d: Int = phi(data(0).pattern, data(0).label).size
54 | // Use first example to determine dimension of w
55 | val model: StructSVMModel[X, Y] = new StructSVMModel(DenseVector.zeros(phi(data(0).pattern, data(0).label).size),
56 | 0.0,
57 | DenseVector.zeros(ndims),
58 | dissolveFunctions)
59 |
60 | // Initialization in case of Weighted Averaging
61 | var wAvg: DenseVector[Double] =
62 | if (solverOptions.doWeightedAveraging)
63 | DenseVector.zeros(d)
64 | else null
65 |
66 | var debugIter = if (solverOptions.debugMultiplier == 0) {
67 | solverOptions.debugMultiplier = 100
68 | n
69 | } else {
70 | 1
71 | }
72 | val debugModel: StructSVMModel[X, Y] = new StructSVMModel(DenseVector.zeros(d), 0.0, DenseVector.zeros(ndims), dissolveFunctions)
73 |
74 | val lossWriter = if (solverOptions.debug) new PrintWriter(new File(lossWriterFileName)) else null
75 | if (solverOptions.debug) {
76 | if (solverOptions.testData != null)
77 | lossWriter.write("pass_num,iter,primal,dual,duality_gap,train_error,test_error\n")
78 | else
79 | lossWriter.write("pass_num,iter,primal,dual,duality_gap,train_error\n")
80 | }
81 |
82 | if (debugOn) {
83 | println("Beginning training of %d data points in %d passes with lambda=%f".format(n, roundLimit, lambda))
84 | }
85 |
86 | for (passNum <- 0 until roundLimit) {
87 |
88 | if (debugOn)
89 | println("Starting pass #%d".format(passNum))
90 |
91 | for (dummy <- 0 until n) {
92 | // 1) Pick example
93 | val i: Int = dummy
94 | val pattern: X = data(i).pattern
95 | val label: Y = data(i).label
96 |
97 | // 2) Solve loss-augmented inference for point i
98 | val ystar_i: Y = maxOracle(model, pattern, label)
99 |
100 | // 3) Get the subgradient
101 | val psi_i: Vector[Double] = (phi(pattern, label) - phi(pattern, ystar_i))*cWeight(label)
102 | val w_s: Vector[Double] = psi_i :* (1 / (n * lambda))
103 |
104 | if (debugOn && dummy == (n - 1))
105 | csvwrite(new File("data/debug/scala-w-%d.csv".format(passNum + 1)), w_s.toDenseVector.toDenseMatrix)
106 |
107 | // 4) Step size gamma
108 | val gamma: Double = 1.0 / (gamma0*(k + 1.0))
109 |
110 | // 5) Update the weights of the model
111 | val newWeights: Vector[Double] = (model.getWeights() :* (1 - gamma)) + (w_s :* (gamma * n))
112 | model.setWeights(newWeights)
113 |
114 | k = k + 1
115 |
116 | if (solverOptions.doWeightedAveraging) {
117 | val rho: Double = 2.0 / (k + 2.0)
118 | wAvg = wAvg * (1.0 - rho) + model.getWeights() * rho
119 | }
120 |
121 | if (debugOn && k >= debugIter) {
122 |
123 | if (solverOptions.doWeightedAveraging) {
124 | debugModel.setWeights(wAvg)
125 | } else {
126 | debugModel.setWeights(model.getWeights)
127 | }
128 |
129 | val primal = SolverUtils.primalObjective(data, dissolveFunctions, debugModel, lambda)
130 | val trainError = SolverUtils.averageLoss(data, dissolveFunctions, debugModel)._1
131 |
132 | if (solverOptions.testData != null) {
133 | val testError =
134 | if (solverOptions.testData.isDefined)
135 | SolverUtils.averageLoss(solverOptions.testData.get, dissolveFunctions, debugModel)._1
136 | else
137 | 0.00
138 | println("Pass %d Iteration %d, SVM primal = %f, Train error = %f, Test error = %f"
139 | .format(passNum + 1, k, primal, trainError, testError))
140 |
141 | if (solverOptions.debug)
142 | lossWriter.write("%d,%d,%f,%f,%f\n".format(passNum + 1, k, primal, trainError, testError))
143 | } else {
144 | println("Pass %d Iteration %d, SVM primal = %f, Train error = %f"
145 | .format(passNum + 1, k, primal, trainError))
146 | if (solverOptions.debug)
147 | lossWriter.write("%d,%d,%f,%f,\n".format(passNum + 1, k, primal, trainError))
148 | }
149 |
150 | debugIter = min(debugIter + n, ceil(debugIter * (1 + solverOptions.debugMultiplier / 100)))
151 |
152 | }
153 |
154 | }
155 | if (debugOn)
156 | println("Completed pass #%d".format(passNum))
157 |
158 | }
159 |
160 | return model
161 | }
162 |
163 | }
--------------------------------------------------------------------------------
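Note on the SSG update above: written out in the notation of the code, steps 3-5 of the inner loop perform

    ψ_i = classWeight(y_i) · (φ(x_i, y_i) − φ(x_i, y*_i))   // subgradient direction
    w_s = ψ_i / (λ n)
    γ_k = 1 / (γ0 · (k + 1))                                 // decaying step size
    w  ← (1 − γ_k) · w + γ_k · n · w_s

where y*_i is the most violating label returned by the oracle. With the default γ0 = 1000 this is a conservative step-size schedule; the weighted-averaging variant additionally maintains wAvg with weight ρ = 2 / (k + 2).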
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/optimization/SolverOptions.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.optimization
2 |
3 | import ch.ethz.dalab.dissolve.regression.LabeledObject
4 | import breeze.linalg.Vector
5 | import org.apache.spark.rdd.RDD
6 | import java.io.File
7 |
8 | sealed trait StoppingCriterion
9 |
10 | // Option A - Limit number of communication rounds
11 | case object RoundLimitCriterion extends StoppingCriterion {
12 | override def toString(): String = { "RoundLimitCriterion" }
13 | }
14 |
15 | // Option B - Check gap
16 | case object GapThresholdCriterion extends StoppingCriterion {
17 | override def toString(): String = { "GapThresholdCriterion" }
18 | }
19 |
20 | // Option C - Run for this amount of time (in secs)
21 | case object TimeLimitCriterion extends StoppingCriterion {
22 | override def toString(): String = { "TimeLimitCriterion" }
23 | }
24 |
25 | sealed trait Solver
26 |
27 | case object UseDBCFWSolver extends Solver{
28 | override def toString(): String = {"DBCFWSolver"}
29 | }
30 |
31 | case object UseSSGSolver extends Solver{
32 | override def toString(): String = {"SSGSolver"}
33 | }
34 |
35 | class SolverOptions[X, Y] extends Serializable {
36 | var doWeightedAveraging: Boolean = false
37 |
38 | var randSeed: Int = 42
39 | /**
40 | * BCFW - "uniform", "perm" or "iter"
41 | * DBCFW - "count", "frac"
42 | */
43 | var sample: String = "frac"
44 | var lambda: Double = 0.01 // FIXME This is 1/n in Matlab code
45 |
46 | var testData: Option[Seq[LabeledObject[X, Y]]] = Option.empty[Seq[LabeledObject[X, Y]]]
47 | var testDataRDD: Option[RDD[LabeledObject[X, Y]]] = Option.empty[RDD[LabeledObject[X, Y]]]
48 |
49 | var doLineSearch: Boolean = true
50 |
51 | // Checkpoint once every this many rounds
52 | var checkpointFreq: Int = 50
53 |
54 | // In case of multi-class
55 | var numClasses = -1
56 |
57 | var classWeights:Boolean = true
58 |
59 | // Cache params
60 | var enableOracleCache: Boolean = false
61 | var oracleCacheSize: Int = 10
62 |
63 | // DBCFW specific params
64 | var H: Int = 5 // Number of data points to sample in each round of CoCoA (= number of local coordinate updates)
65 | var sampleFrac: Double = 0.5
66 | var sampleWithReplacement: Boolean = false
67 |
68 | var enableManualPartitionSize: Boolean = false
69 | var NUM_PART: Int = 1 // Number of partitions of the RDD
70 |
71 | // SSG specific params
72 | var ssg_gamma0: Int = 1000
73 |
74 | // For debugging/Testing purposes
75 | // Basic debugging flag
76 | var debug: Boolean = false
77 | // Obtain statistics (primal value, duality gap, train error, test error, etc.) once every this many rounds.
78 | // If 1, obtains statistics in each round
79 | var debugMultiplier: Int = 1
80 |
81 | // Option A - Limit number of communication rounds
82 | var roundLimit: Int = 25
83 |
84 | // Option B - Check gap
85 | var gapThreshold: Double = 0.1
86 | var gapCheck: Int = 1 // Check the gap once every this many rounds
87 |
88 | // Option C - Run for this amount of time (in secs)
89 | var timeLimit: Int = 300
90 |
91 | var stoppingCriterion: StoppingCriterion = RoundLimitCriterion
92 |
93 | // Sparse representation of w_i's
94 | var sparse: Boolean = false
95 |
96 | // Path to write the CSVs
97 | var debugInfoPath: String = new File(".").getCanonicalPath() + "/debugInfo-%d.csv".format(System.currentTimeMillis())
98 |
99 | var solver: Solver = UseDBCFWSolver
100 |
101 | override def toString(): String = {
102 | val sb: StringBuilder = new StringBuilder()
103 |
104 | sb ++= "# numRounds=%s\n".format(roundLimit)
105 | sb ++= "# doWeightedAveraging=%s\n".format(doWeightedAveraging)
106 |
107 | sb ++= "# randSeed=%d\n".format(randSeed)
108 |
109 | sb ++= "# sample=%s\n".format(sample)
110 | sb ++= "# lambda=%f\n".format(lambda)
111 | sb ++= "# doLineSearch=%s\n".format(doLineSearch)
112 |
113 | sb ++= "# enableManualPartitionSize=%s\n".format(enableManualPartitionSize)
114 | sb ++= "# NUM_PART=%s\n".format(NUM_PART)
115 |
116 | sb ++= "# enableOracleCache=%s\n".format(enableOracleCache)
117 | sb ++= "# oracleCacheSize=%d\n".format(oracleCacheSize)
118 |
119 | sb ++= "# H=%d\n".format(H)
120 | sb ++= "# sampleFrac=%f\n".format(sampleFrac)
121 | sb ++= "# sampleWithReplacement=%s\n".format(sampleWithReplacement)
122 |
123 | sb ++= "# debugInfoPath=%s\n".format(debugInfoPath)
124 |
125 | sb ++= "# checkpointFreq=%d\n".format(checkpointFreq)
126 |
127 | sb ++= "# stoppingCriterion=%s\n".format(stoppingCriterion)
128 | this.stoppingCriterion match {
129 | case RoundLimitCriterion => sb ++= "# roundLimit=%d\n".format(roundLimit)
130 | case GapThresholdCriterion => sb ++= "# gapThreshold=%f\n".format(gapThreshold)
131 | case TimeLimitCriterion => sb ++= "# timeLimit=%d\n".format(timeLimit)
132 | case _ => throw new Exception("Unrecognized Stopping Criterion")
133 | }
134 |
135 | sb ++= "# solver=%s\n".format(solver)
136 |
137 | sb ++= "# class weighting=%s\n".format(classWeights)
138 |
139 | sb ++= "# debugMultiplier=%d\n".format(debugMultiplier)
140 |
141 | sb.toString()
142 | }
143 |
144 | }
145 |
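
Usage note: the fields above are plain mutable members, set directly on an instance before it is handed to a solver. A minimal sketch, assuming purely illustrative type parameters and values (not taken from any example in this repository):

    import breeze.linalg.Vector
    import ch.ethz.dalab.dissolve.optimization.{GapThresholdCriterion, SolverOptions}

    val opts = new SolverOptions[Vector[Double], Double]()
    opts.lambda = 0.001                            // regularization constant
    opts.doWeightedAveraging = true                // keep a weighted average of the iterates
    opts.stoppingCriterion = GapThresholdCriterion // stop on the duality gap instead of a round limit
    opts.gapThreshold = 0.05                       // terminate once the gap falls below this value
    opts.gapCheck = 5                              // evaluate the gap once every 5 rounds
    opts.debug = true                              // write per-round statistics to debugInfoPath
    println(opts)                                  // toString() dumps the configuration as "# key=value" lines
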
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/optimization/SolverUtils.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.optimization
2 |
3 | import ch.ethz.dalab.dissolve.classification.StructSVMModel
4 | import breeze.linalg._
5 | import ch.ethz.dalab.dissolve.regression.LabeledObject
6 | import org.apache.spark.SparkContext
7 | import org.apache.spark.SparkConf
8 | import org.apache.spark.rdd.RDD
9 | import scala.reflect.ClassTag
10 | import org.apache.spark.broadcast.Broadcast
11 |
12 | object SolverUtils {
13 |
14 | /**
15 | * Average loss
16 | */
17 | def averageLoss[X, Y](data: Seq[LabeledObject[X, Y]],
18 | dissolveFunctions: DissolveFunctions[X, Y],
19 | model: StructSVMModel[X, Y]): (Double, Double) = {
20 |
21 | var errorTerm: Double = 0.0
22 | var structuredHingeLoss: Double = 0.0
23 |
24 | for (i <- 0 until data.size) {
25 | val ystar_i = dissolveFunctions.predictFn(model, data(i).pattern)
26 | val loss = dissolveFunctions.lossFn(data(i).label, ystar_i)
27 | errorTerm += loss
28 |
29 |       val wFeatureDotProduct = model.getWeights().t * (dissolveFunctions.featureFn(data(i).pattern, data(i).label) - dissolveFunctions.featureFn(data(i).pattern, ystar_i))
30 |       structuredHingeLoss += loss - wFeatureDotProduct
31 | }
32 |
33 | // Return average of loss terms
34 | (errorTerm / (data.size.toDouble), structuredHingeLoss / (data.size.toDouble))
35 | }
36 |
37 | def averageLoss[X, Y](data: RDD[LabeledObject[X, Y]],
38 | dissolveFunctions: DissolveFunctions[X, Y],
39 | model: StructSVMModel[X, Y],
40 | dataSize: Int): (Double, Double) = {
41 |
42 | val (loss, hloss) =
43 | data.map {
44 | case datapoint =>
45 | val ystar_i = dissolveFunctions.predictFn(model, datapoint.pattern)
46 | val loss = dissolveFunctions.lossFn(ystar_i, datapoint.label)
47 | val wFeatureDotProduct = model.getWeights().t * (dissolveFunctions.featureFn(datapoint.pattern, datapoint.label)
48 | - dissolveFunctions.featureFn(datapoint.pattern, ystar_i))
49 | val structuredHingeLoss: Double = loss - wFeatureDotProduct
50 |
51 | (loss, structuredHingeLoss)
52 | }.fold((0.0, 0.0)) {
53 | case ((lossAccum, hlossAccum), (loss, hloss)) =>
54 | (lossAccum + loss, hlossAccum + hloss)
55 | }
56 |
57 | (loss / dataSize, hloss / dataSize)
58 | }
59 |
60 | /**
61 | * Objective function (SVM dual, assuming we know the vector b of all losses. See BCFW paper)
62 | */
63 | def objectiveFunction(w: Vector[Double],
64 | b_alpha: Double,
65 | lambda: Double): Double = {
66 | // Return the value of f(alpha)
67 | 0.5 * lambda * (w.t * w) - b_alpha
68 | }
69 |
70 | /**
71 | * Compute Duality gap
72 | * Requires one full pass of decoding over all data examples.
73 | */
74 | def dualityGap[X, Y](data: Seq[LabeledObject[X, Y]],
75 | featureFn: (X, Y) => Vector[Double],
76 | lossFn: (Y, Y) => Double,
77 | oracleFn: (StructSVMModel[X, Y], X, Y) => Y,
78 | model: StructSVMModel[X, Y],
79 | lambda: Double)(implicit m: ClassTag[Y]): (Double, Vector[Double], Double) = {
80 |
81 | val phi = featureFn
82 | val maxOracle = oracleFn
83 |
84 | val w: Vector[Double] = model.getWeights()
85 | val ell: Double = model.getEll()
86 |
87 | val n: Int = data.size
88 | val d: Int = model.getWeights().size
89 | val yStars = new Array[Y](n)
90 |
91 | for (i <- 0 until n) {
92 | yStars(i) = maxOracle(model, data(i).pattern, data(i).label)
93 | }
94 |
95 | var w_s: DenseVector[Double] = DenseVector.zeros[Double](d)
96 | var ell_s: Double = 0.0
97 | for (i <- 0 until n) {
98 | w_s += phi(data(i).pattern, data(i).label) - phi(data(i).pattern, yStars(i))
99 | ell_s += lossFn(yStars(i), data(i).label)
100 | }
101 |
102 | w_s = w_s / (lambda * n)
103 | ell_s = ell_s / n
104 |
105 | val gap: Double = w.t * (w - w_s) * lambda - ell + ell_s
106 |
107 | (gap, w_s, ell_s)
108 | }
109 |
110 | /**
111 | * Alternative implementation, using fold. TODO: delete this or the above
112 | * Requires one full pass of decoding over all data examples.
113 | */
114 | def dualityGap[X, Y](data: RDD[LabeledObject[X, Y]],
115 | dissolveFunctions: DissolveFunctions[X, Y],
116 | model: StructSVMModel[X, Y],
117 | lambda: Double,
118 | dataSize: Int)(implicit m: ClassTag[Y]): (Double, Vector[Double], Double) = {
119 |
120 | val phi = dissolveFunctions.featureFn _
121 | val maxOracle = dissolveFunctions.oracleFn _
122 | val lossFn = dissolveFunctions.lossFn _
123 | val classWeight = dissolveFunctions.classWeights _
124 |
125 | val w: Vector[Double] = model.getWeights()
126 | val ell: Double = model.getEll()
127 |
128 | val n: Int = dataSize.toInt
129 | val d: Int = model.getWeights().size
130 |
131 | var (w_s, ell_s) = data.map {
132 | case datapoint =>
133 | val yStar = maxOracle(model, datapoint.pattern, datapoint.label)
134 | val w_s = (phi(datapoint.pattern, datapoint.label) - phi(datapoint.pattern, yStar))*classWeight(datapoint.label)
135 | val ell_s = lossFn(yStar, datapoint.label)*classWeight(datapoint.label)
136 |
137 | (w_s, ell_s)
138 | }.fold((Vector.zeros[Double](d), 0.0)) {
139 | case ((w_acc, ell_acc), (w_i, ell_i)) =>
140 | (w_acc + w_i, ell_acc + ell_i)
141 | }
142 |
143 | w_s = w_s / (lambda * n)
144 | ell_s = ell_s / n
145 |
146 | val gap: Double = w.t * (w - w_s) * lambda - ell + ell_s
147 |
148 | (gap, w_s, ell_s)
149 | }
150 |
151 | /**
152 | * Primal objective.
153 | * Requires one full pass of decoding over all data examples.
154 | */
155 | def primalObjective[X, Y](data: Seq[LabeledObject[X, Y]],
156 | dissolveFunctions: DissolveFunctions[X, Y],
157 | model: StructSVMModel[X, Y],
158 | lambda: Double): Double = {
159 |
160 | val featureFn = dissolveFunctions.featureFn _
161 | val oracleFn = dissolveFunctions.oracleFn _
162 | val lossFn = dissolveFunctions.lossFn _
163 | val classWeight = dissolveFunctions.classWeights _
164 |
165 | var hingeLosses: Double = 0.0
166 | for (i <- 0 until data.size) {
167 | val yStar_i = oracleFn(model, data(i).pattern, data(i).label)
168 | val loss_i = lossFn(yStar_i, data(i).label)*classWeight(data(i).label)
169 |       val psi_i = (featureFn(data(i).pattern, data(i).label) - featureFn(data(i).pattern, yStar_i)) * classWeight(data(i).label)
170 |
171 | val hingeloss_i = loss_i - model.getWeights().t * psi_i
172 | // println("loss_i = %f, other_loss = %f".format(loss_i, model.getWeights().t * psi_i))
173 | // assert(hingeloss_i >= 0.0)
174 |
175 | hingeLosses += hingeloss_i
176 | }
177 |
178 | // Compute the primal and return it
179 | 0.5 * lambda * (model.getWeights.t * model.getWeights) + hingeLosses / data.size
180 |
181 | }
182 |
183 | case class DataEval(gap: Double,
184 | avgDelta: Double,
185 | avgHLoss: Double)
186 |
187 | case class PartialTrainDataEval(sum_w_s: Vector[Double],
188 | sum_ell_s: Double,
189 | sum_Delta: Double,
190 | sum_HLoss: Double) {
191 |
192 | def +(that: PartialTrainDataEval): PartialTrainDataEval = {
193 |
194 | val sum_w_s = this.sum_w_s + that.sum_w_s
195 | val sum_ell_s: Double = this.sum_ell_s + that.sum_ell_s
196 | val sum_Delta: Double = this.sum_Delta + that.sum_Delta
197 | val sum_HLoss: Double = this.sum_HLoss + that.sum_HLoss
198 |
199 | PartialTrainDataEval(sum_w_s,
200 | sum_ell_s,
201 | sum_Delta,
202 | sum_HLoss)
203 |
204 | }
205 | }
206 | /**
207 | * Makes an additional pass over the data to compute the following:
208 | * 1. Duality Gap
209 | * 2. Average \Delta
210 | * 3. Average Structured Hinge Loss
211 |    * (Per-class and global pixel losses are not aggregated here;
212 |    * only items 1-3 are computed and returned.)
213 | */
214 | def trainDataEval[X, Y](data: RDD[LabeledObject[X, Y]],
215 | dissolveFunctions: DissolveFunctions[X, Y],
216 | model: StructSVMModel[X, Y],
217 | lambda: Double,
218 | dataSize: Int)(implicit m: ClassTag[Y]): DataEval = {
219 |
220 | val phi = dissolveFunctions.featureFn _
221 | val maxOracle = dissolveFunctions.oracleFn _
222 | val lossFn = dissolveFunctions.lossFn _
223 | val predictFn = dissolveFunctions.predictFn _
224 | val classWeight = dissolveFunctions.classWeights _
225 |
226 | val n: Int = dataSize.toInt
227 | val d: Int = model.getWeights().size
228 |
229 | val bcModel: Broadcast[StructSVMModel[X, Y]] = data.context.broadcast(model)
230 |
231 | val initEval =
232 | PartialTrainDataEval(DenseVector.zeros[Double](d),
233 | 0.0,
234 | 0.0,
235 | 0.0)
236 |
237 | val partialEval = data.map {
238 | case datapoint =>
239 | /**
240 | * Gap and Structured HingeLoss
241 | */
242 | val lossAug_yStar = maxOracle(bcModel.value, datapoint.pattern, datapoint.label)
243 | val w_s = (phi(datapoint.pattern, datapoint.label) - phi(datapoint.pattern, lossAug_yStar))*classWeight(datapoint.label)
244 | val ell_s = lossFn(datapoint.label, lossAug_yStar)*classWeight(datapoint.label)
245 | val lossAug_wFeatureDotProduct = lossFn(datapoint.label, lossAug_yStar) -
246 | (bcModel.value.getWeights().t * (phi(datapoint.pattern, datapoint.label)
247 | - phi(datapoint.pattern, lossAug_yStar)))
248 | val structuredHingeLoss: Double = lossAug_wFeatureDotProduct
249 |
250 | /**
251 | * \Delta
252 | */
253 | val predict_yStar = predictFn(bcModel.value, datapoint.pattern)
254 | val loss = lossFn(datapoint.label, predict_yStar)
255 |
256 | /**
257 | * Per-class loss
258 | */
259 | val y_truth = datapoint.label
260 | val y_predicted = predict_yStar
261 |
262 | PartialTrainDataEval(w_s,
263 | ell_s,
264 | loss,
265 | structuredHingeLoss)
266 |
267 | }.reduce(_ + _)
268 |
269 | val w: Vector[Double] = model.getWeights()
270 | val ell: Double = model.getEll()
271 |
272 | // Gap
273 | val sum_w_s = partialEval.sum_w_s
274 | val w_s = sum_w_s / (lambda * n)
275 | val ell_s = partialEval.sum_ell_s / n
276 | val gap: Double = w.t * (w - w_s) * lambda - ell + ell_s
277 |
278 | // Loss
279 | val avgLoss = partialEval.sum_Delta / n
280 | val avgHLoss = partialEval.sum_HLoss / n
281 |
282 | DataEval(gap,
283 | avgLoss,
284 | avgHLoss)
285 | }
286 |
287 | /**
288 | * Get Spark's properties
289 | */
290 | def getSparkConfString(sc: SparkConf): String = {
291 | val keys = List("spark.app.name", "spark.executor.memory", "spark.task.cpus", "spark.local.dir", "spark.default.parallelism")
292 | val sb: StringBuilder = new StringBuilder()
293 |
294 | for (key <- keys)
295 | sb ++= "# %s=%s\n".format(key, sc.get(key, "NA"))
296 |
297 | sb.toString()
298 | }
299 |
300 | }
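
For reference, the value returned as gap by dualityGap and trainDataEval above is the structural-SVM duality gap in the form used by the BCFW paper, sketched here in the notation of the code (y_i^\star is the loss-augmented decoding returned by oracleFn, \phi is featureFn and \Delta is lossFn):

    w_s    = \frac{1}{\lambda n} \sum_{i=1}^{n} \bigl( \phi(x_i, y_i) - \phi(x_i, y_i^\star) \bigr)
    \ell_s = \frac{1}{n} \sum_{i=1}^{n} \Delta(y_i^\star, y_i)
    \mathrm{gap}(w) = \lambda \, w^\top (w - w_s) - \ell + \ell_s

where \ell is the running loss term held by the model (model.getEll()); the class-weighted variants simply scale each summand by classWeights(y_i).
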
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/regression/LabeledObject.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.regression
2 |
3 | import org.apache.spark.mllib.regression.LabeledPoint
4 |
5 | import breeze.linalg._
6 |
7 | case class LabeledObject[X, Y](
8 | val label: Y,
9 | val pattern: X) extends Serializable {
10 | }
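
Usage note: training examples are passed to the solvers wrapped in this case class. A minimal sketch with illustrative types and a hypothetical helper name (toLabeledRDD is not part of the library):

    import breeze.linalg.DenseVector
    import ch.ethz.dalab.dissolve.regression.LabeledObject
    import org.apache.spark.SparkContext

    // Wrap (label, feature-vector) pairs into LabeledObjects and distribute them as an RDD
    def toLabeledRDD(sc: SparkContext, pairs: Seq[(Double, DenseVector[Double])]) =
      sc.parallelize(pairs.map { case (y, x) => LabeledObject(label = y, pattern = x) })
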
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/utils/cli/CLAParser.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.utils.cli
2 |
3 | import ch.ethz.dalab.dissolve.optimization._
4 | import scopt.OptionParser
5 |
6 | object CLAParser {
7 |
8 | def getParser(): OptionParser[Config] =
9 | new scopt.OptionParser[Config]("spark-submit ... ") {
10 | head("dissolve^struct", "0.1-SNAPSHOT")
11 |
12 | help("help") text ("prints this usage text")
13 |
14 | opt[Double]("lambda") action { (x, c) =>
15 | c.copy(lambda = x)
16 | } text ("Regularization constant. Default = 0.01")
17 |
18 | opt[Int]("randseed") action { (x, c) =>
19 | c.copy(randSeed = x)
20 | } text ("Random Seed. Default = 42")
21 |
22 | opt[Unit]("linesearch") action { (_, c) =>
23 | c.copy(lineSearch = true)
24 | } text ("Enable Line Search. Default = false")
25 |
26 | opt[Double]("samplefrac") action { (x, c) =>
27 | c.copy(sampleFrac = x)
28 | } text ("Fraction of original dataset to be sampled in each round. Default = 0.5")
29 |
30 | opt[Int]("minpart") action { (x, c) =>
31 | c.copy(minPartitions = x)
32 | } text ("Repartition data RDD to given number of partitions before training. Default = Auto")
33 |
34 | opt[Unit]("sparse") action { (_, c) =>
35 | c.copy(sparse = true)
36 | } text ("Maintain vectors as sparse vectors. Default = Dense.")
37 |
38 | opt[Int]("oraclecachesize") action { (x, c) =>
39 | c.copy(oracleCacheSize = x)
40 | } text ("Oracle Cache Size (caching answers of the maximization oracle for this datapoint). Default = Disabled")
41 |
42 | opt[Int]("cpfreq") action { (x, c) =>
43 | c.copy(checkpointFreq = x)
44 | } text ("Checkpoint Frequency (in rounds). Default = 50")
45 |
46 | opt[String]("stopcrit") action { (x, c) =>
47 | x match {
48 | case "round" =>
49 | c.copy(stoppingCriterion = RoundLimitCriterion)
50 | case "gap" =>
51 | c.copy(stoppingCriterion = GapThresholdCriterion)
52 | case "time" =>
53 | c.copy(stoppingCriterion = TimeLimitCriterion)
54 | }
55 | } validate { x =>
56 | x match {
57 | case "round" => success
58 | case "gap" => success
59 | case "time" => success
60 | case _ => failure("Stopping criterion has to be one of: round | gap | time")
61 | }
62 | } text ("Stopping Criterion. (round | gap | time). Default = round")
63 |
64 | opt[Int]("roundlimit") action { (x, c) =>
65 | c.copy(roundLimit = x)
66 | } text ("Round Limit. Default = 25")
67 |
68 | opt[Double]("gapthresh") action { (x, c) =>
69 | c.copy(gapThreshold = x)
70 | } text ("Gap Threshold. Default = 0.1")
71 |
72 | opt[Int]("gapcheck") action { (x, c) =>
73 | c.copy(gapCheck = x)
74 | } text ("Checks for gap every these many rounds. Default = 25")
75 |
76 | opt[Int]("timelimit") action { (x, c) =>
77 | c.copy(timeLimit = x)
78 | } text ("Time Limit (in secs). Default = 300 secs")
79 |
80 | opt[Unit]("debug") action { (_, c) =>
81 | c.copy(debug = true)
82 | } text ("Enable debugging. Default = false")
83 |
84 | opt[Int]("debugmult") action { (x, c) =>
85 | c.copy(debugMultiplier = x)
86 | } text ("Frequency of debugging. Obtains gap, train and test errors. Default = 1")
87 |
88 | opt[String]("debugfile") action { (x, c) =>
89 | c.copy(debugPath = x)
90 | } text ("Path to debug file. Default = current-dir")
91 |
92 | opt[Map[String, String]]("kwargs") valueName ("k1=v1,k2=v2...") action { (x, c) =>
93 | c.copy(kwargs = x)
94 | } text ("other arguments")
95 | }
96 |
97 | def argsToOptions[X, Y](args: Array[String]): (SolverOptions[X, Y], Map[String, String]) =
98 |
99 | getParser().parse(args, Config()) match {
100 | case Some(config) =>
101 | val solverOptions: SolverOptions[X, Y] = new SolverOptions[X, Y]()
102 | // Copy all config parameters to a Solver Options instance
103 | solverOptions.lambda = config.lambda
104 | solverOptions.randSeed = config.randSeed
105 | solverOptions.doLineSearch = config.lineSearch
106 | solverOptions.doWeightedAveraging = config.wavg
107 |
108 | solverOptions.sampleFrac = config.sampleFrac
109 | if (config.minPartitions > 0) {
110 | solverOptions.enableManualPartitionSize = true
111 | solverOptions.NUM_PART = config.minPartitions
112 | }
113 | solverOptions.sparse = config.sparse
114 |
115 | if (config.oracleCacheSize > 0) {
116 | solverOptions.enableOracleCache = true
117 | solverOptions.oracleCacheSize = config.oracleCacheSize
118 | }
119 |
120 | solverOptions.checkpointFreq = config.checkpointFreq
121 |
122 | solverOptions.stoppingCriterion = config.stoppingCriterion
123 | solverOptions.roundLimit = config.roundLimit
124 | solverOptions.gapCheck = config.gapCheck
125 | solverOptions.gapThreshold = config.gapThreshold
126 | solverOptions.timeLimit = config.timeLimit
127 |
128 | solverOptions.debug = config.debug
129 | solverOptions.debugMultiplier = config.debugMultiplier
130 | solverOptions.debugInfoPath = config.debugPath
131 |
132 | (solverOptions, config.kwargs)
133 |
134 | case None =>
135 |         // Argument parsing failed. Fall back to default options.
136 | val solverOptions: SolverOptions[X, Y] = new SolverOptions[X, Y]()
137 | val kwargs = Map[String, String]()
138 |
139 | (solverOptions, kwargs)
140 | }
141 |
142 | def main(args: Array[String]): Unit = {
143 | val foo = argsToOptions(args)
144 | println(foo._1.toString())
145 | println(foo._2)
146 | }
147 |
148 | }
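
Usage note: an application typically calls argsToOptions from its main to split the command line into a populated SolverOptions and the free-form --kwargs map. A minimal sketch (the object name, type parameters and kwargs key are hypothetical):

    import breeze.linalg.Vector
    import ch.ethz.dalab.dissolve.utils.cli.CLAParser

    object ExampleApp { // hypothetical application entry point
      def main(args: Array[String]): Unit = {
        // Everything recognized by the parser lands in solverOptions;
        // anything passed via --kwargs k1=v1,k2=v2 lands in kwargs.
        val (solverOptions, kwargs) = CLAParser.argsToOptions[Vector[Double], Double](args)

        val inputPath = kwargs.getOrElse("input_path", "data/generated/a1a") // hypothetical key and default
        println(solverOptions) // dump the resolved configuration
        println("input_path = " + inputPath)
      }
    }
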
--------------------------------------------------------------------------------
/dissolve-struct-lib/src/main/scala/ch/ethz/dalab/dissolve/utils/cli/Config.scala:
--------------------------------------------------------------------------------
1 | package ch.ethz.dalab.dissolve.utils.cli
2 |
3 | import ch.ethz.dalab.dissolve.optimization._
4 | import java.io.File
5 |
6 | case class Config(
7 |
8 | // BCFW parameters
9 | lambda: Double = 0.01,
10 | randSeed: Int = 42,
11 | lineSearch: Boolean = false,
12 | wavg: Boolean = false,
13 |
14 | // dissolve^struct parameters
15 | sampleFrac: Double = 0.5,
16 | minPartitions: Int = 0,
17 | sparse: Boolean = false,
18 |
19 | // Oracle
20 | oracleCacheSize: Int = 0,
21 |
22 | // Spark
23 | checkpointFreq: Int = 50,
24 |
25 | // Stopping criteria
26 | stoppingCriterion: StoppingCriterion = RoundLimitCriterion,
27 | // A - RoundLimit
28 | roundLimit: Int = 25,
29 | // B - Gap Check
30 | gapThreshold: Double = 0.1,
31 | gapCheck: Int = 10,
32 | // C - Time Limit
33 | timeLimit: Int = 300, // (In seconds)
34 |
35 | // Debug parameters
36 | debug: Boolean = false,
37 | debugMultiplier: Int = 1,
38 | debugPath: String = new File(".", "debug-%d.csv".format(System.currentTimeMillis())).getAbsolutePath,
39 |
40 | // Other parameters
41 | kwargs: Map[String, String] = Map())
--------------------------------------------------------------------------------
/helpers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dalab/dissolve-struct/67f37377b74c32cf05d8f43a0e3658a10864f9bf/helpers/__init__.py
--------------------------------------------------------------------------------
/helpers/benchmark_runner.py:
--------------------------------------------------------------------------------
1 | """Given experimental parameters, runs the required experiments and obtains the data
2 | """
3 | import argparse
4 | import ConfigParser
5 | import datetime
6 | import os
7 |
8 | from benchmark_utils import *
9 |
10 | VALID_PARAMS = {"lambda", "minpart", "samplefrac", "oraclesize"}
11 | VAL_PARAMS = {"lambda", "minpart", "samplefrac", "oraclesize", "stopcrit", "roundlimit", "gaplimit", "gapcheck",
12 | "timelimit", "debugmult"}
13 | BOOL_PARAMS = {"sparse", "debug", "linesearch"}
14 |
15 | WDIR = "/home/ec2-user" # Working directory
16 |
17 |
18 | def str_to_bool(s):
19 | if s in ['True', 'true']:
20 | return True
21 | elif s in ['False', 'false']:
22 | return False
23 | else:
24 |         raise ValueError("Boolean value in config '%s' unrecognized" % s)
25 |
26 |
27 | def main():
28 | parser = argparse.ArgumentParser(description='Run benchmark')
29 | parser.add_argument("identity_file", help="SSH private key to log into spark nodes")
30 | parser.add_argument("master_uri", help="URI of master node")
31 | parser.add_argument("expt_config", help="Experimental config file")
32 | args = parser.parse_args()
33 |
34 | master_host = args.master_uri
35 | identity_file = args.identity_file
36 |
37 | def ssh_spark(command, user="root", cwd=WDIR):
38 | command = "source /root/.bash_profile; cd %s; %s" % (cwd, command)
39 | ssh(master_host, user, command, identity_file)
40 |
41 | # Check if setup has been executed
42 | ssh_spark("if [ ! -f /home/ec2-user/onesmallstep ]; then echo \"Run benchmark_setup and try again\"; exit 1; fi",
43 | cwd=WDIR)
44 |
45 | dtf = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
46 | appname_format = "{dtf}-{expt_name}-R{rep_num}-{param}-{paramval}"
47 | spark_submit_cmd_format = ("{spark_submit} "
48 | "--jars {lib_jar_path} "
49 | "--class \"{class_name}\" "
50 | "{spark_args} "
51 | "{examples_jar_path} "
52 | "{solver_options_args} "
53 | "--kwargs k1=v1,{app_args}")
54 |
55 | config = ConfigParser.ConfigParser()
56 | config.read(args.expt_config)
57 |
58 | expt_name = config.get("general", "experiment_name")
59 | class_name = config.get("general", "class_name")
60 | num_repetitions = config.getint("general", "repetitions")
61 |
62 | # Pivot values
63 | pivot_param = config.get("pivot", "param")
64 | assert (pivot_param in VALID_PARAMS)
65 | pivot_values_raw = config.get("pivot", "values")
66 | pivot_values = map(lambda x: x.strip(), pivot_values_raw.split(","))
67 |
68 | # Paths
69 | examples_jar_path = config.get("paths", "examples_jar_path")
70 | spark_dir = config.get("paths", "spark_dir")
71 | spark_submit_path = os.path.join(spark_dir, "bin", "spark-submit")
72 | hdfs_input_path = config.get("paths", "hdfs_input_path")
73 |
74 | local_output_dir = config.get("paths", "local_output_dir")
75 | local_output_expt_dir = os.path.join(local_output_dir, "%s %s" % (expt_name, dtf))
76 | if not os.path.exists(local_output_expt_dir):
77 | os.makedirs(local_output_expt_dir)
78 |
79 | dissolve_lib_jar_path = config.get("paths", "lib_jar_path")
80 | scopt_jar_path = "/root/.ivy2/cache/com.github.scopt/scopt_2.10/jars/scopt_2.10-3.3.0.jar"
81 | lib_jar_path = ','.join([dissolve_lib_jar_path, scopt_jar_path])
82 |
83 | for rep_num in range(1, num_repetitions + 1):
84 | print "==========================="
85 | print "====== Repetition %d ======" % rep_num
86 | print "==========================="
87 | for pivot_val in pivot_values:
88 | print "=== %s = %s ===" % (pivot_param, pivot_val)
89 | '''
90 | Construct command to execute on spark cluster
91 | '''
92 | appname = appname_format.format(dtf=dtf,
93 | expt_name=expt_name,
94 | rep_num=rep_num,
95 | param=pivot_param,
96 | paramval=pivot_val)
97 |
98 | # === Construct Spark arguments ===
99 |             spark_args = ' '.join(["--%s %s" % (k, v) for k, v in config.items("spark_args")])
100 |
101 | # === Construct Solver Options arguments ===
102 | valued_parameter_args = ' '.join(
103 | ["--%s %s" % (k, v) for k, v in config.items("parameters") if k in VAL_PARAMS])
104 | boolean_parameter_args = ' '.join(
105 | ["--%s" % k for k, v in config.items("parameters") if k in BOOL_PARAMS and str_to_bool(v)])
106 | valued_dissolve_args = ' '.join(
107 | ["--%s %s" % (k, v) for k, v in config.items("dissolve_args") if k in VAL_PARAMS])
108 | boolean_dissolve_args = ' '.join(
109 | ["--%s" % k for k, v in config.items("dissolve_args") if k in BOOL_PARAMS and str_to_bool(v)])
110 |
111 | # === Add the pivotal parameter ===
112 | assert (pivot_param not in config.options("parameters"))
113 | pivot_param_arg = "--%s %s" % (pivot_param, pivot_val)
114 |
115 | solver_options_args = ' '.join(
116 | [valued_parameter_args, boolean_parameter_args, valued_dissolve_args, boolean_dissolve_args,
117 | pivot_param_arg])
118 |
119 | # == Construct App-specific arguments ===
120 | debug_filename = "%s.csv" % appname
121 | debug_file_path = os.path.join(WDIR, debug_filename)
122 | default_app_args = ("appname={appname},"
123 | "input_path={input_path},"
124 | "debug_file={debug_file_path}").format(appname=appname,
125 | input_path=hdfs_input_path,
126 | debug_file_path=debug_file_path)
127 | extra_app_args = ','.join(["%s=%s" % (k, v) for k, v in config.items("app_args")])
128 |
129 | app_args = ','.join([default_app_args, extra_app_args])
130 |
131 | spark_submit_cmd = spark_submit_cmd_format.format(spark_submit=spark_submit_path,
132 | lib_jar_path=lib_jar_path,
133 | class_name=class_name,
134 | examples_jar_path=examples_jar_path,
135 | spark_args=spark_args,
136 | solver_options_args=solver_options_args,
137 | app_args=app_args)
138 |
139 | '''
140 | Execute Command
141 | '''
142 | print "Executing on %s:\n%s" % (master_host, spark_submit_cmd)
143 | ssh_spark(spark_submit_cmd)
144 |
145 | '''
146 | Obtain required files
147 | '''
148 | scp_from(master_host, identity_file, "root", debug_file_path, local_output_expt_dir)
149 |
150 | '''
151 | Perform clean-up
152 | '''
153 |
154 |
155 | if __name__ == '__main__':
156 | main()
--------------------------------------------------------------------------------
/helpers/benchmark_setup.py:
--------------------------------------------------------------------------------
1 | """Prepare the EC2 cluster for dissolve^struct experiments.
2 |
3 | Given a cluster created using spark-ec2 setup, this script will:
4 | - retrieve the datasets and place them into HDFS
5 | - build the required packages and place them into appropriate folders
6 | - set up the execution environment
7 |
8 | Reuses ssh code from AmpLab's Big Data Benchmarks
9 | """
10 | import argparse
11 | import os
12 |
13 | from benchmark_utils import *
14 |
15 | WDIR = "/home/ec2-user" # Working directory
16 |
17 |
18 | def main():
19 | parser = argparse.ArgumentParser(description='Setup benchmark cluster')
20 | parser.add_argument("identity_file", help="SSH private key to log into spark nodes")
21 | parser.add_argument("master_uri", help="URI of master node")
22 | args = parser.parse_args()
23 |
24 | master_host = args.master_uri
25 | identity_file = args.identity_file
26 |
27 | dissolve_dir = os.path.join(WDIR, "dissolve-struct")
28 | dissolve_lib_dir = os.path.join(dissolve_dir, "dissolve-struct-lib")
29 | dissolve_examples_dir = os.path.join(dissolve_dir, "dissolve-struct-examples")
30 |
31 | def ssh_spark(command, user="root", cwd=WDIR):
32 | command = "source /root/.bash_profile; cd %s; %s" % (cwd, command)
33 | ssh(master_host, user, command, identity_file)
34 |
35 | # === Install all required dependencies ===
36 | # sbt
37 | ssh_spark("curl https://bintray.com/sbt/rpm/rpm | sudo tee /etc/yum.repos.d/bintray-sbt-rpm.repo")
38 | ssh_spark("yum install sbt -y")
39 | # python pip
40 | ssh_spark("yum install python27 -y")
41 | ssh_spark("yum install python-pip -y")
42 |
43 | # === Checkout git repo ===
44 | ssh_spark("git clone https://github.com/dalab/dissolve-struct.git %s" % dissolve_dir)
45 |
46 | # === Build packages ===
47 | # Build lib
48 | # Jar location:
49 | # /root/.ivy2/local/ch.ethz.dalab/dissolvestruct_2.10/0.1-SNAPSHOT/jars/dissolvestruct_2.10.jar
50 | ssh_spark("sbt publish-local", cwd=dissolve_lib_dir)
51 | # Build examples
52 | # Jar location:
53 | # /home/ec2-user/dissolve-struct/dissolve-struct-examples/target/scala-2.10/dissolvestructexample_2.10-0.1-SNAPSHOT.jar
54 | ssh_spark("sbt package", cwd=dissolve_examples_dir)
55 |
56 | # === Data setup ===
57 | # Install pip dependencies
58 | ssh_spark("pip install -r requirements.txt", cwd=dissolve_dir)
59 |
60 | # Execute data retrieval script
61 | ssh_spark("python helpers/retrieve_datasets.py -d", cwd=dissolve_dir)
62 |
63 | # === Setup environment ===
64 | conf_dir = os.path.join(dissolve_examples_dir, "conf")
65 | ssh_spark("cp -r %s %s" % (conf_dir, WDIR))
66 |
67 | # === Move data to HDFS ===
68 | data_dir = os.path.join(dissolve_dir, "data")
69 | ssh_spark("/root/ephemeral-hdfs/bin/hadoop fs -put %s data" % data_dir)
70 |
71 | # === Create a file to mark everything is setup ===
72 | ssh_spark("touch onesmallstep", cwd=WDIR)
73 |
74 |
75 | if __name__ == '__main__':
76 | main()
77 |
--------------------------------------------------------------------------------
/helpers/benchmark_utils.py:
--------------------------------------------------------------------------------
1 | """Common utils by benchmark scripts
2 | """
3 | import subprocess
4 |
5 |
6 | def ssh(host, username, command, identity_file=None):
7 | if not identity_file:
8 | subprocess.check_call(
9 | "ssh -t -o StrictHostKeyChecking=no %s@%s '%s'" %
10 | (username, host, command), shell=True)
11 | else:
12 | subprocess.check_call(
13 | "ssh -t -o StrictHostKeyChecking=no -i %s %s@%s '%s'" %
14 | (identity_file, username, host, command), shell=True)
15 |
16 |
17 | # Copy a file to a given host through scp, throwing an exception if scp fails
18 | def scp_to(host, identity_file, username, local_file, remote_file):
19 | subprocess.check_call(
20 | "scp -q -o StrictHostKeyChecking=no -i %s '%s' '%s@%s:%s'" %
21 | (identity_file, local_file, username, host, remote_file), shell=True)
22 |
23 |
24 | # Copy a file from a given host through scp, throwing an exception if scp fails
25 | def scp_from(host, identity_file, username, remote_file, local_file):
26 | subprocess.check_call(
27 | "scp -q -o StrictHostKeyChecking=no -i %s '%s@%s:%s' '%s'" %
28 | (identity_file, username, host, remote_file, local_file), shell=True)
--------------------------------------------------------------------------------
/helpers/brutus_runner.py:
--------------------------------------------------------------------------------
1 | """Given experimental parameters, runs the required experiments and obtains the data.
2 | To be executed on the Hadoop main node.
3 | """
4 | import argparse
5 | import ConfigParser
6 | import datetime
7 | import re
8 |
9 | from benchmark_utils import *
10 | from paths import *
11 |
12 | VALID_PARAMS = {"lambda", "minpart", "samplefrac", "oraclesize", "num-executors"}
13 | VAL_PARAMS = {"lambda", "minpart", "samplefrac", "oraclesize", "stopcrit", "roundlimit", "gaplimit", "gapcheck", "gapthresh",
14 | "timelimit", "debugmult"}
15 | BOOL_PARAMS = {"sparse", "debug", "linesearch"}
16 |
17 | HOME_DIR = os.getenv("HOME")
18 | PROJ_DIR = os.path.join(HOME_DIR, "dissolve-struct")
19 |
20 | DEFAULT_CORES = 4
21 |
22 |
23 | def execute(command, cwd=PROJ_DIR):
24 | subprocess.check_call(command, cwd=cwd, shell=True)
25 |
26 |
27 | def str_to_bool(s):
28 | if s in ['True', 'true']:
29 | return True
30 | elif s in ['False', 'false']:
31 | return False
32 | else:
33 |         raise ValueError("Boolean value in config '%s' unrecognized" % s)
34 |
35 |
36 | def main():
37 | parser = argparse.ArgumentParser(description='Run benchmark')
38 | parser.add_argument("expt_config", help="Experimental config file")
39 | parser.add_argument("--ds", help="Run with debugging separately. Forces execution of two spark jobs",
40 | action='store_true')
41 | args = parser.parse_args()
42 |
43 | # Check if setup has been executed
44 | touchfile_path = os.path.join(HOME_DIR, 'onesmallstep')
45 | execute("if [ ! -f %s ]; then echo \"Run benchmark_setup and try again\"; exit 1; fi" % touchfile_path,
46 | cwd=HOME_DIR)
47 |
48 | dtf = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
49 | appname_format = "{dtf}-{expt_name}-{param}-{paramval}"
50 | spark_submit_cmd_format = ("{spark_submit} "
51 | "--jars {lib_jar_path} "
52 | "--class \"{class_name}\" "
53 | "{spark_args} "
54 | "{examples_jar_path} "
55 | "{solver_options_args} "
56 | "--kwargs k1=v1,{app_args}")
57 |
58 | config = ConfigParser.ConfigParser()
59 | config.read(args.expt_config)
60 |
61 | expt_name = config.get("general", "experiment_name")
62 | class_name = config.get("general", "class_name")
63 |
64 | if args.ds:
65 | assert ('debug' in config.options('dissolve_args'))
66 | # Want debug set to true, in case of double-debugging
67 | assert (str_to_bool(config.get('dissolve_args', 'debug')))
68 |
69 | # Pivot values
70 | pivot_param = config.get("pivot", "param")
71 | assert (pivot_param in VALID_PARAMS)
72 | pivot_values_raw = config.get("pivot", "values")
73 | pivot_values = map(lambda x: x.strip(), pivot_values_raw.split(","))
74 |
75 | # Paths
76 | examples_jar_path = EXAMPLES_JAR_PATH
77 | spark_submit_path = "spark-submit"
78 | input_path = config.get("paths", "input_path")
79 |
80 | output_dir = EXPT_OUTPUT_DIR
81 | local_output_expt_dir = os.path.join(output_dir, "%s_%s" % (expt_name, dtf))
82 | if not os.path.exists(local_output_expt_dir):
83 | os.makedirs(local_output_expt_dir)
84 |
85 | dissolve_lib_jar_path = LIB_JAR_PATH
86 | scopt_jar_path = SCOPT_JAR_PATH
87 | lib_jar_path = ','.join([dissolve_lib_jar_path, scopt_jar_path])
88 |
89 | '''
90 | Execute experiment
91 | '''
92 | for pivot_val in pivot_values:
93 | print "=== %s = %s ===" % (pivot_param, pivot_val)
94 | '''
95 | Construct command to execute on spark cluster
96 | '''
97 | appname = appname_format.format(dtf=dtf,
98 | expt_name=expt_name,
99 | param=pivot_param,
100 | paramval=pivot_val)
101 |
102 | # === Construct Solver Options arguments ===
103 | valued_parameter_args = ' '.join(
104 | ["--%s %s" % (k, v) for k, v in config.items("parameters") if k in VAL_PARAMS and k not in ['minpart']])
105 | # Treat 'minpart' as a special case. If minpart = 'auto', set minpart = num_cores * num_executors
106 | if 'minpart' in config.options('parameters'):
107 | if config.get('parameters', 'minpart') == 'auto':
108 | if pivot_param == 'num-executors':
109 | num_executors = int(pivot_val)
110 | else:
111 | num_executors = config.getint('spark_args', 'num-executors')
112 | minpart = DEFAULT_CORES * num_executors
113 | else:
114 | minpart = config.getint('parameters', 'minpart')
115 | minpart_arg = '--minpart %d' % minpart
116 | valued_parameter_args = ' '.join([valued_parameter_args, minpart_arg])
117 | boolean_parameter_args = ' '.join(
118 | ["--%s" % k for k, v in config.items("parameters") if k in BOOL_PARAMS and str_to_bool(v)])
119 | valued_dissolve_args = ' '.join(
120 | ["--%s %s" % (k, v) for k, v in config.items("dissolve_args") if k in VAL_PARAMS])
121 | boolean_dissolve_args = ' '.join(
122 | ["--%s" % k for k, v in config.items("dissolve_args") if k in BOOL_PARAMS and str_to_bool(v)])
123 |
124 | solver_options_args = ' '.join(
125 | [valued_parameter_args, boolean_parameter_args, valued_dissolve_args, boolean_dissolve_args])
126 |
127 | # === Construct Spark arguments ===
128 | spark_args = ' '.join(["--%s %s" % (k, v) for k, v in config.items("spark_args")])
129 |
130 | # === Add the pivotal parameter ===
131 | assert (pivot_param not in config.options("parameters"))
132 | assert (pivot_param not in config.options("spark_args"))
133 | pivot_param_arg = "--%s %s" % (pivot_param, pivot_val)
134 |
135 |         # Is this pivot parameter a Spark argument or a dissolve argument?
136 | if pivot_param in ['num-executors', ]:
137 | spark_args = ' '.join([spark_args, pivot_param_arg])
138 | else:
139 | solver_options_args = ' '.join([solver_options_args, pivot_param_arg])
140 |
141 | # == Construct App-specific arguments ===
142 | debug_filename = "%s.csv" % appname
143 | debug_file_path = os.path.join('', debug_filename)
144 | default_app_args = ("appname={appname},"
145 | "input_path={input_path},"
146 | "debug_file={debug_file_path}").format(appname=appname,
147 | input_path=input_path,
148 | debug_file_path=debug_file_path)
149 | extra_app_args = ','.join(["%s=%s" % (k, v) for k, v in config.items("app_args")])
150 |
151 | app_args = ','.join([default_app_args, extra_app_args])
152 |
153 | spark_submit_cmd = spark_submit_cmd_format.format(spark_submit=spark_submit_path,
154 | lib_jar_path=lib_jar_path,
155 | class_name=class_name,
156 | examples_jar_path=examples_jar_path,
157 | spark_args=spark_args,
158 | solver_options_args=solver_options_args,
159 | app_args=app_args)
160 |
161 | '''
162 | Execute Command
163 | '''
164 | print "Executing:\n%s" % spark_submit_cmd
165 | execute(spark_submit_cmd)
166 |
167 | '''
168 | If enabled, execute command again, but without the debug flag
169 | '''
170 | if args.ds:
171 | no_debug_appname = appname + '.no_debug'
172 | debug_filename = "%s.csv" % no_debug_appname
173 | debug_file_path = os.path.join('', debug_filename)
174 | default_app_args = ("appname={appname},"
175 | "input_path={input_path},"
176 | "debug_file={debug_file_path}").format(appname=no_debug_appname,
177 | input_path=input_path,
178 | debug_file_path=debug_file_path)
179 |
180 | extra_app_args = ','.join(["%s=%s" % (k, v) for k, v in config.items("app_args")])
181 |
182 | app_args = ','.join([default_app_args, extra_app_args])
183 |
184 | # Get rid of the debugging flag
185 | solver_options_args = re.sub(' --debug$', ' ', solver_options_args)
186 | solver_options_args = re.sub(' --debug ', ' ', solver_options_args)
187 |
188 | no_debug_spark_submit_cmd = spark_submit_cmd_format.format(spark_submit=spark_submit_path,
189 | lib_jar_path=lib_jar_path,
190 | class_name=class_name,
191 | examples_jar_path=examples_jar_path,
192 | spark_args=spark_args,
193 | solver_options_args=solver_options_args,
194 | app_args=app_args)
195 | print "Executing WITHOUT debugging:\n%s" % no_debug_spark_submit_cmd
196 | execute(no_debug_spark_submit_cmd)
197 |
198 |
199 | if __name__ == '__main__':
200 | main()
--------------------------------------------------------------------------------
/helpers/brutus_sample.cfg:
--------------------------------------------------------------------------------
1 | [general]
2 | experiment_name: cov_binary_k2
3 | class_name: ch.ethz.dalab.dissolve.examples.binaryclassification.COVBinary
4 |
5 | [paths]
6 | input_path: /user/torekond/data/generated/covtype.libsvm.binary.scale
7 |
8 | [parameters]
9 | ; numeric parameters
10 | lambda: 0.01
11 | minpart: auto
12 | samplefrac: 0.5
13 | oraclesize: 0
14 | ; boolean parameters
15 | sparse: false
16 |
17 |
18 | [pivot]
19 | ; param is one of: lambda, minpart, samplefrac, oraclesize, num-executors
20 | param: num-executors
21 | values: 4, 8, 16, 32
22 |
23 |
24 | [dissolve_args]
25 | stopcrit: round
26 | roundlimit: 25
27 | debug: true
28 | debugmult: 10
29 |
30 |
31 | [spark_args]
32 | driver-memory: 2G
33 | executor-memory: 7G
34 |
35 | [app_args]
36 | ; Any key-value pairs mentioned here are sent as --kwargs k1=v1,k2=v2
37 | foo:bar
38 |
--------------------------------------------------------------------------------
/helpers/brutus_setup.py:
--------------------------------------------------------------------------------
1 | '''Setup dissolve^struct environment on Brutus. (To be executed ON Brutus in the home directory.)
2 | '''
3 | __author__ = 'tribhu'
4 |
5 | import os
6 | import argparse
7 |
8 | from benchmark_utils import *
9 |
10 | from retrieve_datasets import retrieve, download_to_gen_dir
11 | from paths import PROJECT_DIR, JARS_DIR, DATA_DIR
12 |
13 | LIB_JAR_URL = 'https://dl.dropboxusercontent.com/u/12851272/dissolvestruct_2.10.jar'
14 | EXAMPLES_JAR_PATH = 'https://dl.dropboxusercontent.com/u/12851272/dissolvestructexample_2.10-0.1-SNAPSHOT.jar'
15 | SCOPT_JAR_PATH = 'https://dl.dropboxusercontent.com/u/12851272/scopt_2.10-3.3.0.jar'
16 |
17 | HOME_DIR = os.getenv("HOME")
18 |
19 |
20 | def main():
21 | parser = argparse.ArgumentParser(
22 | description='Setup Brutus cluster for dissolve^struct. Needs to be executed on Brutus login node.')
23 | parser.add_argument("username", help="Username (which has access to Hadoop Cluster)")
24 | args = parser.parse_args()
25 |
26 | username = args.username
27 |
28 | def ssh_brutus(command):
29 | ssh('hadoop', username, command)
30 |
31 | # === Obtain data ===
32 | retrieve(download_all=True)
33 |
34 | # === Obtain the jars ===
35 | jars_dir = JARS_DIR
36 | if not os.path.exists(jars_dir):
37 | os.makedirs(jars_dir)
38 |
39 | print "=== Downloading executables to: ", jars_dir, "==="
40 | for jar_url in [LIB_JAR_URL, EXAMPLES_JAR_PATH, SCOPT_JAR_PATH]:
41 | print "== Retrieving ", jar_url.split('/')[-1], '=='
42 | download_to_gen_dir(jar_url, jars_dir)
43 |
44 | # === Move data to HDFS ===
45 | print "=== Moving data to HDFS ==="
46 | put_data_cmd = "hadoop fs -put -f %s /user/%s/data" % (DATA_DIR, username)
47 | ssh_brutus(put_data_cmd)
48 |
49 | # === Create a file to mark everything is setup ===
50 | touch_file_path = os.path.join(HOME_DIR, 'onesmallstep')
51 | ssh_brutus("touch %s" % touch_file_path)
52 |
53 |
54 | if __name__ == '__main__':
55 | main()
--------------------------------------------------------------------------------
/helpers/buildall.py:
--------------------------------------------------------------------------------
1 | '''Builds the lib and examples packages. Then, syncs to Dropbox
2 | '''
3 | import os
4 | import subprocess
5 | import shutil
6 |
7 | from paths import PROJECT_DIR
8 |
9 | SYNC_JARS = True
10 |
11 | HOME_DIR = os.getenv("HOME")
12 |
13 | OUTPUT_DIR = os.path.join(HOME_DIR, 'Dropbox/Public/')
14 |
15 | LIB_JAR_PATH = os.path.join(HOME_DIR,
16 | '.ivy2/local/ch.ethz.dalab/dissolvestruct_2.10/0.1-SNAPSHOT/jars/dissolvestruct_2.10.jar')
17 | EXAMPLES_JAR_PATH = os.path.join(PROJECT_DIR, 'dissolve-struct-examples', 'target/scala-2.10/',
18 | 'dissolvestructexample_2.10-0.1-SNAPSHOT.jar')
19 | SCOPT_JAR_PATH = os.path.join(HOME_DIR, '.ivy2/cache/com.github.scopt/scopt_2.10/jars/scopt_2.10-3.3.0.jar')
20 |
21 |
22 | def execute(command, cwd='.'):
23 | subprocess.check_call(command, cwd=cwd)
24 |
25 |
26 | def main():
27 | dissolve_lib_dir = os.path.join(PROJECT_DIR, 'dissolve-struct-lib')
28 | dissolve_examples_dir = os.path.join(PROJECT_DIR, 'dissolve-struct-examples')
29 |
30 | # Build lib package
31 | print "=== Building dissolve-struct-lib ==="
32 | lib_build_cmd = ["sbt", "publish-local"]
33 | execute(lib_build_cmd, cwd=dissolve_lib_dir)
34 |
35 | # Build examples package
36 | print "=== Building dissolve-struct-examples ==="
37 | examples_build_cmd = ["sbt", "package"]
38 | execute(examples_build_cmd, cwd=dissolve_examples_dir)
39 |
40 |
41 | # Sync all packages to specified output directory
42 | if SYNC_JARS:
43 | print "=== Syncing Jars to Dropbox Public Folder ==="
44 | print LIB_JAR_PATH
45 | shutil.copy(LIB_JAR_PATH, OUTPUT_DIR)
46 | print EXAMPLES_JAR_PATH
47 | shutil.copy(EXAMPLES_JAR_PATH, OUTPUT_DIR)
48 | print SCOPT_JAR_PATH
49 | shutil.copy(SCOPT_JAR_PATH, OUTPUT_DIR)
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
--------------------------------------------------------------------------------
/helpers/ocr_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io
3 | import os
4 | from paths import *
5 |
6 | FILTER_FOLD = True
7 | TST_FOLD_NUMS = [0,] # Datapoints from these folds will be treated as test dataset
8 |
9 | def convert_ocr_data():
10 | idx_trn = 0
11 | idx_tst = 0
12 |
13 | ocr_mat_path = os.path.join(DATA_DIR, 'ocr.mat')
14 | patterns_train_path = os.path.join(GEN_DATA_DIR, 'patterns_train.csv')
15 | patterns_test_path = os.path.join(GEN_DATA_DIR, 'patterns_test.csv')
16 | labels_train_path = os.path.join(GEN_DATA_DIR, 'labels_train.csv')
17 | labels_test_path = os.path.join(GEN_DATA_DIR, 'labels_test.csv')
18 | folds_train_path = os.path.join(GEN_DATA_DIR, 'folds_train.csv')
19 | folds_test_path = os.path.join(GEN_DATA_DIR, 'folds_test.csv')
20 |
21 | print "Processing features available in %s" % ocr_mat_path
22 |
23 | mat = scipy.io.loadmat(ocr_mat_path, struct_as_record=False, squeeze_me=True)
24 | n = np.shape(mat['dataset'])[0]
25 | with open(patterns_train_path, 'w') as fpat_trn, open(labels_train_path, 'w') as flab_trn, open(folds_train_path, 'w') as ffold_trn, \
26 | open(patterns_test_path, 'w') as fpat_tst, open(labels_test_path, 'w') as flab_tst, open(folds_test_path, 'w') as ffold_tst:
27 | for i in range(n):
28 | ### Write folds
29 | fold = mat['dataset'][i].__dict__['fold']
30 | if fold in TST_FOLD_NUMS:
31 | fpat, flab, ffold = fpat_tst, flab_tst, ffold_tst
32 | idx_tst += 1
33 | idx = idx_tst
34 | else:
35 | fpat, flab, ffold = fpat_trn, flab_trn, ffold_trn
36 | idx_trn += 1
37 | idx = idx_trn
38 | # FORMAT: id,fold
39 | ffold.write('%d,%d\n' % (idx, fold))
40 |
41 | ### Write patterns (x_i's)
42 | pixels = mat['dataset'][i].__dict__['pixels']
43 | num_letters = np.shape(pixels)[0]
44 | letter_shape = np.shape(pixels[0])
45 | # Create a matrix of size num_pixels+bias_var x num_letters
46 | xi = np.zeros((letter_shape[0] * letter_shape[1] + 1, num_letters))
47 | for letter_id in range(num_letters):
48 | letter = pixels[letter_id] # Returns a 16x8 matrix
49 | xi[:, letter_id] = np.append(letter.flatten(order='F'), [1.])
50 | # Vectorize the above matrix and store it
51 | # After flattening, order is column-major
52 |                 xi_str = ','.join([repr(s) for s in xi.flatten('F')])
53 | # FORMAT: id,#rows,#cols,x_0_0,x_0_1,...x_n_m
54 | fpat.write('%d,%d,%d,%s\n' % (idx, np.shape(xi)[0], np.shape(xi)[1], xi_str))
55 |
56 | ### Write labels (y_i's)
57 | labels = mat['dataset'][i].__dict__['word']
58 |                 labels_str = ','.join([repr(a) for a in labels])
59 | # FORMAT: id,#letters,letter_0,letter_1,...letter_n
60 | flab.write('%d,%d,%s\n' % (idx, np.shape(labels)[0], labels_str))
61 |
62 |
63 | def main():
64 | convert_ocr_data()
65 |
66 | if __name__ == '__main__':
67 | main()
68 |
--------------------------------------------------------------------------------
/helpers/paths.py:
--------------------------------------------------------------------------------
1 | """Keeps track of paths used by other submodules
2 | """
3 | import os
4 |
5 | current_file_path = os.path.abspath(__file__)
6 | PROJECT_DIR = os.path.join(os.path.dirname(current_file_path), os.pardir)
7 |
8 | DATA_DIR = os.path.join(PROJECT_DIR, 'data')
9 | GEN_DATA_DIR = os.path.join(DATA_DIR, 'generated')
10 |
11 | CHAIN_OCR_FILE = os.path.join(DATA_DIR, "ocr.mat")
12 |
13 | # If an sbt build cannot be performed, jars can be accessed here
14 | JARS_DIR = os.path.join(PROJECT_DIR, 'jars')
15 | LIB_JAR_PATH = os.path.join(JARS_DIR, 'dissolvestruct_2.10.jar')
16 | EXAMPLES_JAR_PATH = os.path.join(JARS_DIR, 'dissolvestructexample_2.10-0.1-SNAPSHOT.jar')
17 | SCOPT_JAR_PATH = os.path.join(JARS_DIR, 'scopt_2.10-3.3.0.jar')
18 |
19 | # Output dir. Any output is placed in a subdirectory within this folder.
20 | EXPT_OUTPUT_DIR = os.path.join(PROJECT_DIR, 'benchmark-data')
21 |
22 |
23 |
24 | # URLS for DATASETS
25 | A1A_URL = "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a1a"
26 |
27 | COV_BIN_URL = "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/covtype.libsvm.binary.scale.bz2"
28 |
29 | COV_MULT_URL = "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/covtype.scale.bz2"
30 |
31 | RCV1_URL = "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/rcv1_train.binary.bz2"
32 |
33 | MSRC_URL = "https://s3-eu-west-1.amazonaws.com/dissolve-struct/msrc/msrc.tar.gz"
34 |
35 | chain_files = ["folds_test.csv",
36 | "folds_train.csv",
37 | "patterns_train.csv",
38 | "patterns_test.csv",
39 | "labels_train.csv",
40 | "labels_test.csv",
41 | ]
42 | s3_chain_base_url = "https://s3-eu-west-1.amazonaws.com/dissolve-struct/chain"
43 | CHAIN_URLS = [os.path.join(s3_chain_base_url, fname) for fname in chain_files]
44 |
45 |
--------------------------------------------------------------------------------
/helpers/retrieve_datasets.py:
--------------------------------------------------------------------------------
1 | """Retrieves datasets from various sources
2 | """
3 | import urllib
4 | import subprocess
5 | from paths import *
6 | import argparse
7 |
8 |
9 | def decompress(filename):
10 | print "Decompressing: ", filename
11 | try:
12 | subprocess.check_call(['bzip2', '-d', filename])
13 | except subprocess.CalledProcessError as e:
14 | pass
15 |
16 |
17 | def download_to_gen_dir(url, dir=GEN_DATA_DIR):
18 | print "Downloading: ", url
19 | basename = os.path.basename(url)
20 | destname = os.path.join(dir, basename)
21 | urllib.urlretrieve(url, destname)
22 | return destname
23 |
24 |
25 | def download_and_decompress(url):
26 | destname = download_to_gen_dir(url)
27 | decompress(destname)
28 |
29 |
30 | def retrieve(download_all=False):
31 | # Retrieve the files
32 | print "=== A1A ==="
33 | download_to_gen_dir(A1A_URL)
34 |
35 | print "=== COV BINARY ==="
36 | download_and_decompress(COV_BIN_URL)
37 |
38 | print "=== COV MULTICLASS ==="
39 | download_and_decompress(COV_MULT_URL)
40 |
41 | print "=== RCV1 ==="
42 | download_and_decompress(RCV1_URL)
43 |
44 | print "=== CHAIN ==="
45 | if download_all:
46 | for url in CHAIN_URLS:
47 | download_to_gen_dir(url)
48 | else:
49 | import ocr_helpers
50 | ocr_helpers.convert_ocr_data()
51 |
52 |
53 | def main():
54 | parser = argparse.ArgumentParser(description='Retrieve datasets for dissolve^struct')
55 | parser.add_argument("-d", "--download", action="store_true",
56 | help="Download files instead of processing when possible")
57 | args = parser.parse_args()
58 |
59 | retrieve(args.download)
60 |
61 |
62 | if __name__ == '__main__':
63 | main()
64 |
--------------------------------------------------------------------------------
/paper/dissolve-jmlr-software.bbl:
--------------------------------------------------------------------------------
1 | \begin{thebibliography}{6}
2 | \providecommand{\natexlab}[1]{#1}
3 | \providecommand{\url}[1]{\texttt{#1}}
4 | \expandafter\ifx\csname urlstyle\endcsname\relax
5 | \providecommand{\doi}[1]{doi: #1}\else
6 | \providecommand{\doi}{doi: \begingroup \urlstyle{rm}\Url}\fi
7 |
8 | \bibitem[Jaggi et~al.(2014)Jaggi, Smith, Tak{\'a}{\v c}, Terhorst, Krishnan,
9 | Hofmann, and Jordan]{Jaggi:2014vi}
10 | Martin Jaggi, Virginia Smith, Martin Tak{\'a}{\v c}, Jonathan Terhorst, Sanjay
11 | Krishnan, Thomas Hofmann, and Michael~I Jordan.
12 | \newblock {Communication-Efficient Distributed Dual Coordinate Ascent}.
13 | \newblock In \emph{NIPS 2014 - Advances in Neural Information Processing
14 | Systems 27}, pages 3068--3076, 2014.
15 |
16 | \bibitem[Lacoste-Julien et~al.(2013)Lacoste-Julien, Jaggi, Schmidt, and
17 | Pletscher]{LacosteJulien:2013ue}
18 | Simon Lacoste-Julien, Martin Jaggi, Mark Schmidt, and Patrick Pletscher.
19 | \newblock {Block-Coordinate Frank-Wolfe Optimization for Structural SVMs}.
20 | \newblock In \emph{ICML 2013 - Proceedings of the 30th International Conference
21 | on Machine Learning}, 2013.
22 |
23 | \bibitem[Ratliff et~al.(2007)Ratliff, Bagnell, and Zinkevich]{Ratliff:2007ti}
24 | Nathan~D Ratliff, J~Andrew Bagnell, and Martin~A Zinkevich.
25 | \newblock {(Online) Subgradient Methods for Structured Prediction}.
26 | \newblock In \emph{AISTATS}, 2007.
27 |
28 | \bibitem[Shalev-Shwartz et~al.(2010)Shalev-Shwartz, Singer, Srebro, and
29 | Cotter]{ShalevShwartz:2010cg}
30 | Shai Shalev-Shwartz, Yoram Singer, Nathan Srebro, and Andrew Cotter.
31 | \newblock {Pegasos: Primal Estimated Sub-Gradient Solver for SVM}.
32 | \newblock \emph{Mathematical Programming}, 127\penalty0 (1):\penalty0 3--30,
33 | October 2010.
34 |
35 | \bibitem[Taskar et~al.(2003)Taskar, Guestrin, and Koller]{Taskar:2003tt}
36 | Ben Taskar, Carlos Guestrin, and Daphne Koller.
37 | \newblock {Max-Margin Markov Networks}.
38 | \newblock In \emph{NIPS 2003 - Advances in Neural Information Processing
39 |   Systems 16}, 2003.
40 |
41 | \bibitem[Tsochantaridis et~al.(2005)Tsochantaridis, Joachims, Hofmann, and
42 | Altun]{Tsochantaridis:2005ww}
43 | Ioannis Tsochantaridis, Thorsten Joachims, Thomas Hofmann, and Yasemin Altun.
44 | \newblock {Large Margin Methods for Structured and Interdependent Output
45 | Variables}.
46 | \newblock \emph{The Journal of Machine Learning Research}, 6, December 2005.
47 |
48 | \end{thebibliography}
49 |
--------------------------------------------------------------------------------
/paper/dissolve-jmlr-software.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dalab/dissolve-struct/67f37377b74c32cf05d8f43a0e3658a10864f9bf/paper/dissolve-jmlr-software.pdf
--------------------------------------------------------------------------------
/paper/references.bib:
--------------------------------------------------------------------------------
1 | %
2 |
3 | @inproceedings{LacosteJulien:2013ue,
4 | author = {Lacoste-Julien, Simon and Jaggi, Martin and Schmidt, Mark and Pletscher, Patrick},
5 | title = {{Block-Coordinate Frank-Wolfe Optimization for Structural SVMs}},
6 | booktitle = {ICML 2013 - Proceedings of the 30th International Conference on Machine Learning},
7 | year = {2013}
8 | }
9 |
10 | @article{Lucchi:2015co,
11 | author = {Lucchi, Aurelien and Marquez-Neila, Pablo and Becker, Carlos and Li, Yunpeng and Smith, Kevin and Knott, Graham and Fua, Pascal},
12 | title = {{Learning Structured Models for Segmentation of 2D and 3D Imagery}},
13 | journal = {IEEE Transactions on Medical Imaging},
14 | year = {2015},
15 | pages = {1096--1110}
16 | }
17 |
18 | @inproceedings{Ma:2015ti,
19 | author = {Ma, Chenxin and Smith, Virginia and Jaggi, Martin and Jordan, Michael I and Richt{\'a}rik, Peter and Tak{\'a}{\v c}, Martin},
20 | title = {{Adding vs. Averaging in Distributed Primal-Dual Optimization}},
21 | booktitle = {ICML 2015 - Proceedings of the 32nd International Conference on Machine Learning},
22 | year = {2015},
23 | pages = {1973--1982}
24 | }
25 |
26 | @inproceedings{Jaggi:2014vi,
27 | author = {Jaggi, Martin and Smith, Virginia and Tak{\'a}{\v c}, Martin and Terhorst, Jonathan and Krishnan, Sanjay and Hofmann, Thomas and Jordan, Michael I},
28 | title = {{Communication-Efficient Distributed Dual Coordinate Ascent}},
29 | booktitle = {NIPS 2014 - Advances in Neural Information Processing Systems 27},
30 | year = {2014},
31 | pages = {3068--3076}
32 | }
33 |
34 | @article{Tsochantaridis:2004tg,
35 | author = {Tsochantaridis, Ioannis and Hofmann, Thomas and Joachims, Thorsten and Altun, Yasemin},
36 | title = {{Support vector machine learning for interdependent and structured output spaces}},
37 | journal = {ICML '04: Proceedings of the twenty-first international conference on Machine learning},
38 | year = {2004},
39 | month = jul,
40 | annote = {original SVMstruct paper. also says how you get the Crammer/Singer multiclass as a special case, by copying the features}
41 | }
42 |
43 | @article{Joachims:2009ex,
44 | author = {Joachims, Thorsten and Finley, Thomas and Yu, Chun-Nam John},
45 | title = {{Cutting-Plane Training of Structural SVMs}},
46 | year = {2009},
47 | volume = {77},
48 | number = {1},
49 | pages = {27--59},
50 | month = oct
51 | }
52 |
53 | @inproceedings{Taskar:2003tt,
54 | author = {Taskar, Ben and Guestrin, Carlos and Koller, Daphne},
55 | title = {{Max-Margin Markov Networks}},
56 | booktitle = {NIPS 2003 - Advances in Neural Information Processing Systems 16},
57 | year = {2003},
58 | }
59 |
60 | @article{Tsochantaridis:2005ww,
61 | author = {Tsochantaridis, Ioannis and Joachims, Thorsten and Hofmann, Thomas and Altun, Yasemin},
62 | title = {{Large Margin Methods for Structured and Interdependent Output Variables}},
63 | journal = {The Journal of Machine Learning Research},
64 | year = {2005},
65 | volume = {6},
66 | month = dec,
67 | }
68 |
69 | @inproceedings{Ratliff:2007ti,
70 | author = {Ratliff, Nathan D and Bagnell, J Andrew and Zinkevich, Martin A},
71 | title = {{(Online) Subgradient Methods for Structured Prediction}},
72 | booktitle = {AISTATS},
73 | year = {2007}
74 | }
75 |
76 | @article{ShalevShwartz:2010cg,
77 | author = {Shalev-Shwartz, Shai and Singer, Yoram and Srebro, Nathan and Cotter, Andrew},
78 | title = {{Pegasos: Primal Estimated Sub-Gradient Solver for SVM}},
79 | journal = {Mathematical Programming},
80 | year = {2010},
81 | volume = {127},
82 | number = {1},
83 | pages = {3--30},
84 | month = oct
85 | }
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.9.2
2 | scipy==0.15.1
3 |
--------------------------------------------------------------------------------