├── .gitattributes ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md └── src ├── bin ├── cli.yml └── craftml.rs ├── data ├── mod.rs └── sparse_vector.rs ├── lib.rs ├── model ├── eval.rs ├── mod.rs ├── skmeans.rs └── train.rs └── util.rs /.gitattributes: -------------------------------------------------------------------------------- 1 | Cargo.lock -diff 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # These are backup files generated by rustfmt 6 | **/*.rs.bk 7 | 8 | /.idea 9 | .DS_Store 10 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "ansi_term" 3 | version = "0.11.0" 4 | source = "registry+https://github.com/rust-lang/crates.io-index" 5 | dependencies = [ 6 | "winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", 7 | ] 8 | 9 | [[package]] 10 | name = "arrayvec" 11 | version = "0.4.7" 12 | source = "registry+https://github.com/rust-lang/crates.io-index" 13 | dependencies = [ 14 | "nodrop 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)", 15 | ] 16 | 17 | [[package]] 18 | name = "atty" 19 | version = "0.2.11" 20 | source = "registry+https://github.com/rust-lang/crates.io-index" 21 | dependencies = [ 22 | "libc 0.2.43 (registry+https://github.com/rust-lang/crates.io-index)", 23 | "termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)", 24 | "winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", 25 | ] 26 | 27 | [[package]] 28 | name = "bincode" 29 | version = "1.0.1" 30 | source = "registry+https://github.com/rust-lang/crates.io-index" 31 | dependencies = [ 32 | "byteorder 1.2.6 
(registry+https://github.com/rust-lang/crates.io-index)", 33 | "serde 1.0.79 (registry+https://github.com/rust-lang/crates.io-index)", 34 | ] 35 | 36 | [[package]] 37 | name = "bitflags" 38 | version = "1.0.4" 39 | source = "registry+https://github.com/rust-lang/crates.io-index" 40 | 41 | [[package]] 42 | name = "byteorder" 43 | version = "1.2.6" 44 | source = "registry+https://github.com/rust-lang/crates.io-index" 45 | 46 | [[package]] 47 | name = "cfg-if" 48 | version = "0.1.5" 49 | source = "registry+https://github.com/rust-lang/crates.io-index" 50 | 51 | [[package]] 52 | name = "clap" 53 | version = "2.32.0" 54 | source = "registry+https://github.com/rust-lang/crates.io-index" 55 | dependencies = [ 56 | "ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", 57 | "atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", 58 | "bitflags 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", 59 | "strsim 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", 60 | "textwrap 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)", 61 | "unicode-width 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", 62 | "vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", 63 | "yaml-rust 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", 64 | ] 65 | 66 | [[package]] 67 | name = "cloudabi" 68 | version = "0.0.3" 69 | source = "registry+https://github.com/rust-lang/crates.io-index" 70 | dependencies = [ 71 | "bitflags 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", 72 | ] 73 | 74 | [[package]] 75 | name = "craftml" 76 | version = "0.0.1" 77 | dependencies = [ 78 | "bincode 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 79 | "clap 2.32.0 (registry+https://github.com/rust-lang/crates.io-index)", 80 | "fasthash 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)", 81 | "itertools 0.7.8 
(registry+https://github.com/rust-lang/crates.io-index)", 82 | "log 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)", 83 | "maplit 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 84 | "order-stat 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", 85 | "pbr 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", 86 | "rand 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", 87 | "rayon 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", 88 | "serde 1.0.79 (registry+https://github.com/rust-lang/crates.io-index)", 89 | "serde_derive 1.0.79 (registry+https://github.com/rust-lang/crates.io-index)", 90 | "simple_logger 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", 91 | "time 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", 92 | ] 93 | 94 | [[package]] 95 | name = "crossbeam-deque" 96 | version = "0.2.0" 97 | source = "registry+https://github.com/rust-lang/crates.io-index" 98 | dependencies = [ 99 | "crossbeam-epoch 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", 100 | "crossbeam-utils 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", 101 | ] 102 | 103 | [[package]] 104 | name = "crossbeam-epoch" 105 | version = "0.3.1" 106 | source = "registry+https://github.com/rust-lang/crates.io-index" 107 | dependencies = [ 108 | "arrayvec 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)", 109 | "cfg-if 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", 110 | "crossbeam-utils 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", 111 | "lazy_static 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 112 | "memoffset 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", 113 | "nodrop 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)", 114 | "scopeguard 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", 115 | ] 116 | 117 | [[package]] 118 | name = "crossbeam-utils" 
119 | version = "0.2.2" 120 | source = "registry+https://github.com/rust-lang/crates.io-index" 121 | dependencies = [ 122 | "cfg-if 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", 123 | ] 124 | 125 | [[package]] 126 | name = "either" 127 | version = "1.5.0" 128 | source = "registry+https://github.com/rust-lang/crates.io-index" 129 | 130 | [[package]] 131 | name = "fasthash" 132 | version = "0.3.2" 133 | source = "registry+https://github.com/rust-lang/crates.io-index" 134 | dependencies = [ 135 | "fasthash-sys 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)", 136 | "seahash 3.0.5 (registry+https://github.com/rust-lang/crates.io-index)", 137 | "xoroshiro128 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", 138 | ] 139 | 140 | [[package]] 141 | name = "fasthash-sys" 142 | version = "0.3.2" 143 | source = "registry+https://github.com/rust-lang/crates.io-index" 144 | dependencies = [ 145 | "gcc 0.3.54 (registry+https://github.com/rust-lang/crates.io-index)", 146 | ] 147 | 148 | [[package]] 149 | name = "fuchsia-zircon" 150 | version = "0.3.3" 151 | source = "registry+https://github.com/rust-lang/crates.io-index" 152 | dependencies = [ 153 | "bitflags 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", 154 | "fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", 155 | ] 156 | 157 | [[package]] 158 | name = "fuchsia-zircon-sys" 159 | version = "0.3.3" 160 | source = "registry+https://github.com/rust-lang/crates.io-index" 161 | 162 | [[package]] 163 | name = "gcc" 164 | version = "0.3.54" 165 | source = "registry+https://github.com/rust-lang/crates.io-index" 166 | 167 | [[package]] 168 | name = "itertools" 169 | version = "0.7.8" 170 | source = "registry+https://github.com/rust-lang/crates.io-index" 171 | dependencies = [ 172 | "either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)", 173 | ] 174 | 175 | [[package]] 176 | name = "kernel32-sys" 177 | version = "0.2.2" 178 | 
source = "registry+https://github.com/rust-lang/crates.io-index" 179 | dependencies = [ 180 | "winapi 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)", 181 | "winapi-build 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 182 | ] 183 | 184 | [[package]] 185 | name = "lazy_static" 186 | version = "1.1.0" 187 | source = "registry+https://github.com/rust-lang/crates.io-index" 188 | dependencies = [ 189 | "version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", 190 | ] 191 | 192 | [[package]] 193 | name = "libc" 194 | version = "0.2.43" 195 | source = "registry+https://github.com/rust-lang/crates.io-index" 196 | 197 | [[package]] 198 | name = "log" 199 | version = "0.4.5" 200 | source = "registry+https://github.com/rust-lang/crates.io-index" 201 | dependencies = [ 202 | "cfg-if 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", 203 | ] 204 | 205 | [[package]] 206 | name = "maplit" 207 | version = "1.0.1" 208 | source = "registry+https://github.com/rust-lang/crates.io-index" 209 | 210 | [[package]] 211 | name = "memoffset" 212 | version = "0.2.1" 213 | source = "registry+https://github.com/rust-lang/crates.io-index" 214 | 215 | [[package]] 216 | name = "nodrop" 217 | version = "0.1.12" 218 | source = "registry+https://github.com/rust-lang/crates.io-index" 219 | 220 | [[package]] 221 | name = "num_cpus" 222 | version = "1.8.0" 223 | source = "registry+https://github.com/rust-lang/crates.io-index" 224 | dependencies = [ 225 | "libc 0.2.43 (registry+https://github.com/rust-lang/crates.io-index)", 226 | ] 227 | 228 | [[package]] 229 | name = "order-stat" 230 | version = "0.1.3" 231 | source = "registry+https://github.com/rust-lang/crates.io-index" 232 | 233 | [[package]] 234 | name = "pbr" 235 | version = "1.0.1" 236 | source = "registry+https://github.com/rust-lang/crates.io-index" 237 | dependencies = [ 238 | "kernel32-sys 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", 239 | "libc 0.2.43 
(registry+https://github.com/rust-lang/crates.io-index)", 240 | "termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)", 241 | "time 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", 242 | "winapi 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)", 243 | ] 244 | 245 | [[package]] 246 | name = "proc-macro2" 247 | version = "0.4.19" 248 | source = "registry+https://github.com/rust-lang/crates.io-index" 249 | dependencies = [ 250 | "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 251 | ] 252 | 253 | [[package]] 254 | name = "quote" 255 | version = "0.6.8" 256 | source = "registry+https://github.com/rust-lang/crates.io-index" 257 | dependencies = [ 258 | "proc-macro2 0.4.19 (registry+https://github.com/rust-lang/crates.io-index)", 259 | ] 260 | 261 | [[package]] 262 | name = "rand" 263 | version = "0.4.3" 264 | source = "registry+https://github.com/rust-lang/crates.io-index" 265 | dependencies = [ 266 | "fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", 267 | "libc 0.2.43 (registry+https://github.com/rust-lang/crates.io-index)", 268 | "winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", 269 | ] 270 | 271 | [[package]] 272 | name = "rand" 273 | version = "0.6.0" 274 | source = "registry+https://github.com/rust-lang/crates.io-index" 275 | dependencies = [ 276 | "cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)", 277 | "fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", 278 | "libc 0.2.43 (registry+https://github.com/rust-lang/crates.io-index)", 279 | "rand_chacha 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 280 | "rand_core 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", 281 | "rand_hc 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 282 | "rand_isaac 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 283 | "rand_pcg 0.1.1 
(registry+https://github.com/rust-lang/crates.io-index)", 284 | "rand_xorshift 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 285 | "rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", 286 | "winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", 287 | ] 288 | 289 | [[package]] 290 | name = "rand_chacha" 291 | version = "0.1.0" 292 | source = "registry+https://github.com/rust-lang/crates.io-index" 293 | dependencies = [ 294 | "rand_core 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", 295 | "rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", 296 | ] 297 | 298 | [[package]] 299 | name = "rand_core" 300 | version = "0.3.0" 301 | source = "registry+https://github.com/rust-lang/crates.io-index" 302 | 303 | [[package]] 304 | name = "rand_hc" 305 | version = "0.1.0" 306 | source = "registry+https://github.com/rust-lang/crates.io-index" 307 | dependencies = [ 308 | "rand_core 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", 309 | ] 310 | 311 | [[package]] 312 | name = "rand_isaac" 313 | version = "0.1.0" 314 | source = "registry+https://github.com/rust-lang/crates.io-index" 315 | dependencies = [ 316 | "rand_core 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", 317 | ] 318 | 319 | [[package]] 320 | name = "rand_pcg" 321 | version = "0.1.1" 322 | source = "registry+https://github.com/rust-lang/crates.io-index" 323 | dependencies = [ 324 | "rand_core 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", 325 | "rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", 326 | ] 327 | 328 | [[package]] 329 | name = "rand_xorshift" 330 | version = "0.1.0" 331 | source = "registry+https://github.com/rust-lang/crates.io-index" 332 | dependencies = [ 333 | "rand_core 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", 334 | ] 335 | 336 | [[package]] 337 | name = "rayon" 338 | version = "1.0.2" 339 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 340 | dependencies = [ 341 | "crossbeam-deque 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", 342 | "either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)", 343 | "rayon-core 1.4.1 (registry+https://github.com/rust-lang/crates.io-index)", 344 | ] 345 | 346 | [[package]] 347 | name = "rayon-core" 348 | version = "1.4.1" 349 | source = "registry+https://github.com/rust-lang/crates.io-index" 350 | dependencies = [ 351 | "crossbeam-deque 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", 352 | "lazy_static 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 353 | "libc 0.2.43 (registry+https://github.com/rust-lang/crates.io-index)", 354 | "num_cpus 1.8.0 (registry+https://github.com/rust-lang/crates.io-index)", 355 | ] 356 | 357 | [[package]] 358 | name = "redox_syscall" 359 | version = "0.1.40" 360 | source = "registry+https://github.com/rust-lang/crates.io-index" 361 | 362 | [[package]] 363 | name = "redox_termios" 364 | version = "0.1.1" 365 | source = "registry+https://github.com/rust-lang/crates.io-index" 366 | dependencies = [ 367 | "redox_syscall 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", 368 | ] 369 | 370 | [[package]] 371 | name = "rustc_version" 372 | version = "0.2.3" 373 | source = "registry+https://github.com/rust-lang/crates.io-index" 374 | dependencies = [ 375 | "semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", 376 | ] 377 | 378 | [[package]] 379 | name = "scopeguard" 380 | version = "0.3.3" 381 | source = "registry+https://github.com/rust-lang/crates.io-index" 382 | 383 | [[package]] 384 | name = "seahash" 385 | version = "3.0.5" 386 | source = "registry+https://github.com/rust-lang/crates.io-index" 387 | 388 | [[package]] 389 | name = "semver" 390 | version = "0.9.0" 391 | source = "registry+https://github.com/rust-lang/crates.io-index" 392 | dependencies = [ 393 | "semver-parser 0.7.0 
(registry+https://github.com/rust-lang/crates.io-index)", 394 | ] 395 | 396 | [[package]] 397 | name = "semver-parser" 398 | version = "0.7.0" 399 | source = "registry+https://github.com/rust-lang/crates.io-index" 400 | 401 | [[package]] 402 | name = "serde" 403 | version = "1.0.79" 404 | source = "registry+https://github.com/rust-lang/crates.io-index" 405 | 406 | [[package]] 407 | name = "serde_derive" 408 | version = "1.0.79" 409 | source = "registry+https://github.com/rust-lang/crates.io-index" 410 | dependencies = [ 411 | "proc-macro2 0.4.19 (registry+https://github.com/rust-lang/crates.io-index)", 412 | "quote 0.6.8 (registry+https://github.com/rust-lang/crates.io-index)", 413 | "syn 0.15.7 (registry+https://github.com/rust-lang/crates.io-index)", 414 | ] 415 | 416 | [[package]] 417 | name = "simple_logger" 418 | version = "0.5.0" 419 | source = "registry+https://github.com/rust-lang/crates.io-index" 420 | dependencies = [ 421 | "log 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)", 422 | "time 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", 423 | ] 424 | 425 | [[package]] 426 | name = "strsim" 427 | version = "0.7.0" 428 | source = "registry+https://github.com/rust-lang/crates.io-index" 429 | 430 | [[package]] 431 | name = "syn" 432 | version = "0.15.7" 433 | source = "registry+https://github.com/rust-lang/crates.io-index" 434 | dependencies = [ 435 | "proc-macro2 0.4.19 (registry+https://github.com/rust-lang/crates.io-index)", 436 | "quote 0.6.8 (registry+https://github.com/rust-lang/crates.io-index)", 437 | "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", 438 | ] 439 | 440 | [[package]] 441 | name = "termion" 442 | version = "1.5.1" 443 | source = "registry+https://github.com/rust-lang/crates.io-index" 444 | dependencies = [ 445 | "libc 0.2.43 (registry+https://github.com/rust-lang/crates.io-index)", 446 | "redox_syscall 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", 447 | 
"redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", 448 | ] 449 | 450 | [[package]] 451 | name = "textwrap" 452 | version = "0.10.0" 453 | source = "registry+https://github.com/rust-lang/crates.io-index" 454 | dependencies = [ 455 | "unicode-width 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", 456 | ] 457 | 458 | [[package]] 459 | name = "time" 460 | version = "0.1.40" 461 | source = "registry+https://github.com/rust-lang/crates.io-index" 462 | dependencies = [ 463 | "libc 0.2.43 (registry+https://github.com/rust-lang/crates.io-index)", 464 | "redox_syscall 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", 465 | "winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", 466 | ] 467 | 468 | [[package]] 469 | name = "unicode-width" 470 | version = "0.1.5" 471 | source = "registry+https://github.com/rust-lang/crates.io-index" 472 | 473 | [[package]] 474 | name = "unicode-xid" 475 | version = "0.1.0" 476 | source = "registry+https://github.com/rust-lang/crates.io-index" 477 | 478 | [[package]] 479 | name = "vec_map" 480 | version = "0.8.1" 481 | source = "registry+https://github.com/rust-lang/crates.io-index" 482 | 483 | [[package]] 484 | name = "version_check" 485 | version = "0.1.5" 486 | source = "registry+https://github.com/rust-lang/crates.io-index" 487 | 488 | [[package]] 489 | name = "winapi" 490 | version = "0.2.8" 491 | source = "registry+https://github.com/rust-lang/crates.io-index" 492 | 493 | [[package]] 494 | name = "winapi" 495 | version = "0.3.6" 496 | source = "registry+https://github.com/rust-lang/crates.io-index" 497 | dependencies = [ 498 | "winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 499 | "winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", 500 | ] 501 | 502 | [[package]] 503 | name = "winapi-build" 504 | version = "0.1.1" 505 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 506 | 507 | [[package]] 508 | name = "winapi-i686-pc-windows-gnu" 509 | version = "0.4.0" 510 | source = "registry+https://github.com/rust-lang/crates.io-index" 511 | 512 | [[package]] 513 | name = "winapi-x86_64-pc-windows-gnu" 514 | version = "0.4.0" 515 | source = "registry+https://github.com/rust-lang/crates.io-index" 516 | 517 | [[package]] 518 | name = "xoroshiro128" 519 | version = "0.3.0" 520 | source = "registry+https://github.com/rust-lang/crates.io-index" 521 | dependencies = [ 522 | "rand 0.4.3 (registry+https://github.com/rust-lang/crates.io-index)", 523 | ] 524 | 525 | [[package]] 526 | name = "yaml-rust" 527 | version = "0.3.5" 528 | source = "registry+https://github.com/rust-lang/crates.io-index" 529 | 530 | [metadata] 531 | "checksum ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" 532 | "checksum arrayvec 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)" = "a1e964f9e24d588183fcb43503abda40d288c8657dfc27311516ce2f05675aef" 533 | "checksum atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "9a7d5b8723950951411ee34d271d99dddcc2035a16ab25310ea2c8cfd4369652" 534 | "checksum bincode 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9f2fb9e29e72fd6bc12071533d5dc7664cb01480c59406f656d7ac25c7bd8ff7" 535 | "checksum bitflags 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "228047a76f468627ca71776ecdebd732a3423081fcf5125585bcd7c49886ce12" 536 | "checksum byteorder 1.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "90492c5858dd7d2e78691cfb89f90d273a2800fc11d98f60786e5d87e2f83781" 537 | "checksum cfg-if 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "0c4e7bb64a8ebb0d856483e1e682ea3422f883c5f5615a90d51a2c82fe87fdd3" 538 | "checksum clap 2.32.0 (registry+https://github.com/rust-lang/crates.io-index)" = 
"b957d88f4b6a63b9d70d5f454ac8011819c6efa7727858f458ab71c756ce2d3e" 539 | "checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f" 540 | "checksum crossbeam-deque 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f739f8c5363aca78cfb059edf753d8f0d36908c348f3d8d1503f03d8b75d9cf3" 541 | "checksum crossbeam-epoch 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "927121f5407de9956180ff5e936fe3cf4324279280001cd56b669d28ee7e9150" 542 | "checksum crossbeam-utils 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "2760899e32a1d58d5abb31129f8fae5de75220bc2176e77ff7c627ae45c918d9" 543 | "checksum either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3be565ca5c557d7f59e7cfcf1844f9e3033650c929c6566f511e8005f205c1d0" 544 | "checksum fasthash 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "ce56b715df3559085323d0a6724eccfd52994ac5abac9e9ffc6093853163f3bb" 545 | "checksum fasthash-sys 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "b6de941abfe2e715cdd34009d90546f850597eb69ca628ddfbf616e53dda28f8" 546 | "checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82" 547 | "checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" 548 | "checksum gcc 0.3.54 (registry+https://github.com/rust-lang/crates.io-index)" = "5e33ec290da0d127825013597dbdfc28bee4964690c7ce1166cbc2a7bd08b1bb" 549 | "checksum itertools 0.7.8 (registry+https://github.com/rust-lang/crates.io-index)" = "f58856976b776fedd95533137617a02fb25719f40e7d9b01c7043cd65474f450" 550 | "checksum kernel32-sys 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d" 551 | 
"checksum lazy_static 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca488b89a5657b0a2ecd45b95609b3e848cf1755da332a0da46e2b2b1cb371a7" 552 | "checksum libc 0.2.43 (registry+https://github.com/rust-lang/crates.io-index)" = "76e3a3ef172f1a0b9a9ff0dd1491ae5e6c948b94479a3021819ba7d860c8645d" 553 | "checksum log 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)" = "d4fcce5fa49cc693c312001daf1d13411c4a5283796bac1084299ea3e567113f" 554 | "checksum maplit 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "08cbb6b4fef96b6d77bfc40ec491b1690c779e77b05cd9f07f787ed376fd4c43" 555 | "checksum memoffset 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "0f9dc261e2b62d7a622bf416ea3c5245cdd5d9a7fcc428c0d06804dfce1775b3" 556 | "checksum nodrop 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)" = "9a2228dca57108069a5262f2ed8bd2e82496d2e074a06d1ccc7ce1687b6ae0a2" 557 | "checksum num_cpus 1.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c51a3322e4bca9d212ad9a158a02abc6934d005490c054a2778df73a70aa0a30" 558 | "checksum order-stat 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "efa535d5117d3661134dbf1719b6f0ffe06f2375843b13935db186cd094105eb" 559 | "checksum pbr 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "deb73390ab68d81992bd994d145f697451bb0b54fd39738e72eef32458ad6907" 560 | "checksum proc-macro2 0.4.19 (registry+https://github.com/rust-lang/crates.io-index)" = "ffe022fb8c8bd254524b0b3305906c1921fa37a84a644e29079a9e62200c3901" 561 | "checksum quote 0.6.8 (registry+https://github.com/rust-lang/crates.io-index)" = "dd636425967c33af890042c483632d33fa7a18f19ad1d7ea72e8998c6ef8dea5" 562 | "checksum rand 0.4.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8356f47b32624fef5b3301c1be97e5944ecdd595409cc5da11d05f211db6cfbd" 563 | "checksum rand 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = 
"de3f08319b5395bd19b70e73c4c465329495db02dafeb8ca711a20f1c2bd058c" 564 | "checksum rand_chacha 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "771b009e3a508cb67e8823dda454aaa5368c7bc1c16829fb77d3e980440dd34a" 565 | "checksum rand_core 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "0905b6b7079ec73b314d4c748701f6931eb79fd97c668caa3f1899b22b32c6db" 566 | "checksum rand_hc 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7b40677c7be09ae76218dc623efbf7b18e34bced3f38883af07bb75630a21bc4" 567 | "checksum rand_isaac 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2d6ecfe9ebf36acd47a49d150990b047a5f7db0a7236ee2414b7ff5cc1097c7b" 568 | "checksum rand_pcg 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "086bd09a33c7044e56bb44d5bdde5a60e7f119a9e95b0775f545de759a32fe05" 569 | "checksum rand_xorshift 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "effa3fcaa47e18db002bdde6060944b6d2f9cfd8db471c30e873448ad9187be3" 570 | "checksum rayon 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "df7a791f788cb4c516f0e091301a29c2b71ef680db5e644a7d68835c8ae6dbfa" 571 | "checksum rayon-core 1.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b055d1e92aba6877574d8fe604a63c8b5df60f60e5982bf7ccbb1338ea527356" 572 | "checksum redox_syscall 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "c214e91d3ecf43e9a4e41e578973adeb14b474f2bee858742d127af75a0112b1" 573 | "checksum redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7e891cfe48e9100a70a3b6eb652fef28920c117d366339687bd5576160db0f76" 574 | "checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" 575 | "checksum scopeguard 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "94258f53601af11e6a49f722422f6e3425c52b06245a5cf9bc09908b174f5e27" 576 | "checksum 
seahash 3.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "e048636bed25842fcdc36e5ad1ec6295b72d4b5b8a4b759b64915a4ce2b9d09d" 577 | "checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" 578 | "checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" 579 | "checksum serde 1.0.79 (registry+https://github.com/rust-lang/crates.io-index)" = "84257ccd054dc351472528c8587b4de2dbf0dc0fe2e634030c1a90bfdacebaa9" 580 | "checksum serde_derive 1.0.79 (registry+https://github.com/rust-lang/crates.io-index)" = "31569d901045afbff7a9479f793177fe9259819aff10ab4f89ef69bbc5f567fe" 581 | "checksum simple_logger 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2c0619150c42143a91bd79aa00b5f01f9b0a3ec38b1a59bc0b2f5aa24fc4c9bd" 582 | "checksum strsim 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bb4f380125926a99e52bc279241539c018323fab05ad6368b56f93d9369ff550" 583 | "checksum syn 0.15.7 (registry+https://github.com/rust-lang/crates.io-index)" = "455a6ec9b368f8c479b0ae5494d13b22dc00990d2f00d68c9dc6a2dc4f17f210" 584 | "checksum termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "689a3bdfaab439fd92bc87df5c4c78417d3cbe537487274e9b0b2dce76e92096" 585 | "checksum textwrap 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "307686869c93e71f94da64286f9a9524c0f308a9e1c87a583de8e9c9039ad3f6" 586 | "checksum time 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "d825be0eb33fda1a7e68012d51e9c7f451dc1a69391e7fdc197060bb8c56667b" 587 | "checksum unicode-width 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "882386231c45df4700b275c7ff55b6f3698780a650026380e72dabe76fa46526" 588 | "checksum unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = 
"fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc" 589 | "checksum vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a" 590 | "checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd" 591 | "checksum winapi 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "167dc9d6949a9b857f3451275e911c3f44255842c1f7a76f33c55103a909087a" 592 | "checksum winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "92c1eb33641e276cfa214a0522acad57be5c56b10cb348b3c5117db75f3ac4b0" 593 | "checksum winapi-build 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "2d315eee3b34aca4797b2da6b13ed88266e6d612562a0c46390af8299fc699bc" 594 | "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 595 | "checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 596 | "checksum xoroshiro128 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e0eeda34baec49c4f1eb2c04d59b761582fd6330010f9330ca696ca1a355dfcd" 597 | "checksum yaml-rust 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "e66366e18dc58b46801afbf2ca7661a9f59cc8c5962c29892b6039b4f86fa992" 598 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "craftml" 3 | version = "0.0.1" 4 | authors = ["Tom Dong "] 5 | 6 | [dependencies] 7 | fasthash = "0.3" 8 | itertools = "0.7" 9 | rand = "0.6" 10 | serde = "1.0" 11 | serde_derive = "1.0" 12 | maplit = "1.0" 13 | time = "0.1" 14 | bincode = "1.0" 15 | log = "0.4" 16 
| simple_logger = "0.5" 17 | pbr = "1.0" 18 | clap = {version = "2.32", features = ["yaml"]} 19 | rayon = "1.0" 20 | order-stat = "0.1" 21 | 22 | [profile.dev] 23 | opt-level = 3 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yubing Dong (Tom) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # craftml-rs 2 | A Rust implementation of CRAFTML, an Efficient Clustering-based Random Forest for Extreme Multi-label Learning (Siblini et al., 2018). 
3 | 4 | ## Performance 5 | 6 | This implementation has been tested on datasets from the [Extreme Classification Repository](http://manikvarma.org/downloads/XC/XMLRepository.html). Each data set comes either with a single data file and separate files for train / test splits, or with two separate train / test data files. 7 | 8 | A data file starts with a header line with three space-separated integers: total number of examples, number of features, and number of labels. Following the header line, there is one line per example, starting with comma-separated labels, followed by space-separated feature:value pairs: 9 | ``` 10 | label1,label2,...labelk ft1:ft1_val ft2:ft2_val ft3:ft3_val .. ftd:ftd_val 11 | ``` 12 | 13 | A split file is an integer matrix, with one line per row, and columns separated by spaces. The integers are example indices (1-indexed) in the corresponding data file, and each column corresponds to a separate split. 14 | 15 | Precisions at 1, 3, and 5 are calculated for models trained with default hyper-parameters, e.g. 16 | - `craftml train Mediamill/data.txt --cv_splits_path Mediamill/train_split.txt` for Mediamill, which has a single data file and separate train / test split files; 17 | - `craftml train EURLex-4K/train.txt --test_data EURLex-4K/test.txt` for EURLex-4K, which has separate train / test data files. 18 | 19 | | Dataset | P@1 | P@3 | P@5 | 20 | | --- | --- | --- | --- | 21 | | Mediamill | 85.51 | 69.94 | 56.39 | 22 | | Bibtex | 61.47 | 37.20 | 27.32 | 23 | | Delicious | 67.78 | 62.15 | 57.63 | 24 | | EURLex-4K | 79.52 | 66.42 | 55.25 | 25 | | Wiki10-31K | 83.57 | 72.69 | 63.65 | 26 | | WikiLSHTC-325K | 51.79 | 32.41 | 23.43 | 27 | | Delicious-200K | 47.34 | 40.85 | 37.67 | 28 | | Amazon-670K | 38.40 | 34.21 | 31.41 | 29 | | AmazonCat-13K | 92.88 | 77.48 | 61.32 | 30 | 31 | These numbers are generally consistent with those reported in the original paper.
32 | 33 | Note that if there isn't enough memory to train on a large data set, the `--test_trees_singly` flag can be set to only train & test one tree at a time, and discard each tree when it's been tested. This allows one to obtain test results without being able to fit the entire model in memory. One can also tune the `--centroid_preserve_ratio` option to trade off between model size and accuracy. 34 | 35 | ## Build 36 | The project can be easily built with [Cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html): 37 | ``` 38 | $ cargo build --release 39 | ``` 40 | 41 | The compiled binary file will be available at `target/release/craftml`. 42 | 43 | ## Usage 44 | ``` 45 | $ craftml train --help 46 | 47 | craftml-train 48 | Train a new CRAFTML model 49 | 50 | USAGE: 51 | craftml train [FLAGS] [OPTIONS] 52 | 53 | FLAGS: 54 | -h, --help Prints help information 55 | --test_trees_singly Test forest tree by tree, freeing each before training the next to reduce memory usage. 56 | Model cannot be saved. 57 | -V, --version Prints version information 58 | 59 | OPTIONS: 60 | --centroid_min_n_preserve 61 | The minimum number of entries to preserve from pruning, regardless of the preserve ratio setting. [default: 10] 62 | 63 | --centroid_preserve_ratio 64 | A real number between 0 and 1, which is the ratio of entries with largest absolute values to preserve. The 65 | rest of the entries are pruned. [default: 0.1] 66 | --cluster_sample_size 67 | Number of examples drawn for clustering on a branching node [default: 20000] 68 | 69 | --cv_splits_path 70 | Path to the k-fold cross validation splits file, with k space-separated columns of indices (starting from 1) 71 | for training splits.
72 | --k_clusters Number of clusters on a branching node [default: 10] 73 | --leaf_max_size 74 | Maximum number of distinct examples on a leaf node [default: 10] 75 | 76 | --model_path Path to which the trained model will be saved if provided 77 | --n_cluster_iters 78 | Number of clustering iterations to run on each branching node [default: 2] 79 | 80 | --n_feature_buckets 81 | Number of buckets into which features are hashed [default: 10000] 82 | 83 | --n_label_buckets 84 | Number of buckets into which labels are hashed [default: 10000] 85 | 86 | --n_threads 87 | Number of worker threads. If 0, the number is selected automatically. [default: 0] 88 | 89 | --n_trees Number of trees in the random forest [default: 50] 90 | --out_path 91 | Path to which predictions will be written, if provided 92 | 93 | --test_data 94 | Path to test dataset file used to calculate metrics if provided (in the format of the Extreme Classification 95 | Repository) 96 | 97 | ARGS: 98 | Path to training dataset file (in the format of the Extreme Classification Repository) 99 | ``` 100 | 101 | ``` 102 | $ craftml test --help 103 | 104 | craftml-test 105 | Test an existing CRAFTML model 106 | 107 | USAGE: 108 | craftml test [OPTIONS] 109 | 110 | FLAGS: 111 | -h, --help Prints help information 112 | -V, --version Prints version information 113 | 114 | OPTIONS: 115 | --k_top Number of top predictions to write out for each test example [default: 5] 116 | --n_threads Number of worker threads. If 0, the number is selected automatically. [default: 0] 117 | --out_path Path to which predictions will be written, if provided 118 | 119 | ARGS: 120 | Path to the trained model 121 | Path to test dataset file (in the format of the Extreme Classification Repository) 122 | ``` 123 | 124 | ## References 125 | 126 | - Siblini, W., Kuntz, P., & Meyer, F. (2018).
*CRAFTML, an Efficient Clustering-based Random Forest for Extreme Multi-label Learning.* In Proceedings of the 35th International Conference on Machine Learning (Vol. 80, pp. 4664–4673). Stockholmsmässan, Stockholm Sweden: PMLR. http://proceedings.mlr.press/v80/siblini18a.html 127 | -------------------------------------------------------------------------------- /src/bin/cli.yml: -------------------------------------------------------------------------------- 1 | name: craftml 2 | about: Clustering-based RAndom Forest of predictive Trees for extreme Multi-label Learning 3 | 4 | subcommands: 5 | - train: 6 | about: Train a new CRAFTML model 7 | args: 8 | - training_data: 9 | help: Path to training dataset file (in the format of the Extreme Classification Repository) 10 | index: 1 11 | required: true 12 | - cv_splits_path: 13 | help: Path to the k-fold cross validation splits file, with k space-separated columns of indices (starting from 1) for training splits. 14 | long: cv_splits_path 15 | takes_value: true 16 | value_name: PATH 17 | conflicts_with: 18 | - test_data 19 | - model_path 20 | - test_data: 21 | help: Path to test dataset file used to calculate metrics if provided (in the format of the Extreme Classification Repository) 22 | long: test_data 23 | takes_value: true 24 | value_name: PATH 25 | - test_trees_singly: 26 | help: Test forest tree by tree, freeing each before training the next to reduce memory usage. Model cannot be saved.
27 | long: test_trees_singly 28 | takes_value: false 29 | requires: 30 | - test_data 31 | conflicts_with: 32 | - model_path 33 | - cv_splits_path 34 | - out_path: 35 | help: Path to which predictions will be written, if provided 36 | long: out_path 37 | takes_value: true 38 | value_name: PATH 39 | requires: 40 | - test_data 41 | conflicts_with: 42 | - cv_splits_path 43 | - model_path: 44 | help: Path to which the trained model will be saved if provided 45 | long: model_path 46 | takes_value: true 47 | value_name: PATH 48 | - n_threads: 49 | help: Number of worker threads. If 0, the number is selected automatically. 50 | long: n_threads 51 | takes_value: true 52 | default_value: "0" 53 | - n_trees: 54 | help: Number of trees in the random forest 55 | long: n_trees 56 | takes_value: true 57 | default_value: "50" 58 | - n_feature_buckets: 59 | help: Number of buckets into which features are hashed 60 | long: n_feature_buckets 61 | takes_value: true 62 | default_value: "10000" 63 | - n_label_buckets: 64 | help: Number of buckets into which labels are hashed 65 | long: n_label_buckets 66 | takes_value: true 67 | default_value: "10000" 68 | - leaf_max_size: 69 | help: Maximum number of distinct examples on a leaf node 70 | long: leaf_max_size 71 | takes_value: true 72 | default_value: "10" 73 | - k_clusters: 74 | help: Number of clusters on a branching node 75 | long: k_clusters 76 | takes_value: true 77 | default_value: "10" 78 | - cluster_sample_size: 79 | help: Number of examples drawn for clustering on a branching node 80 | long: cluster_sample_size 81 | takes_value: true 82 | default_value: "20000" 83 | - n_cluster_iters: 84 | help: Number of clustering iterations to run on each branching node 85 | long: n_cluster_iters 86 | takes_value: true 87 | default_value: "2" 88 | - centroid_preserve_ratio: 89 | help: A real number between 0 and 1, which is the ratio of entries with largest absolute values to preserve. The rest of the entries are pruned.
90 | long: centroid_preserve_ratio 91 | takes_value: true 92 | default_value: "0.1" 93 | - centroid_min_n_preserve: 94 | help: The minimum number of entries to preserve from puning, regardless preserve ratio setting. 95 | long: centroid_min_n_preserve 96 | takes_value: true 97 | default_value: "10" 98 | - test: 99 | about: Test an existing CRAFTML model 100 | args: 101 | - model_path: 102 | help: Path to the trained model 103 | index: 1 104 | required: true 105 | - test_data: 106 | help: Path to test dataset file (in the format of the Extreme Classification Repository) 107 | index: 2 108 | required: true 109 | - out_path: 110 | help: Path to the which predictions will be written, if provided 111 | long: out_path 112 | takes_value: true 113 | value_name: PATH 114 | - n_threads: 115 | help: Number of worker threads. If 0, the number is selected automatically. 116 | long: n_threads 117 | takes_value: true 118 | default_value: "0" 119 | - k_top: 120 | help: Number of top predictions to write out for each test example 121 | requires: out_path 122 | long: k_top 123 | takes_value: true 124 | default_value: "5" 125 | -------------------------------------------------------------------------------- /src/bin/craftml.rs: -------------------------------------------------------------------------------- 1 | extern crate craftml; 2 | #[macro_use] 3 | extern crate clap; 4 | extern crate rayon; 5 | 6 | use craftml::data::{DataSet, DataSplits, Label}; 7 | use craftml::model::{eval, CraftmlModel, CraftmlTrainer}; 8 | use std::fs::File; 9 | use std::io::{BufReader, BufWriter, Write}; 10 | 11 | fn set_num_threads(arg_matches: &clap::ArgMatches) { 12 | let n_threads = arg_matches 13 | .value_of("n_threads") 14 | .and_then(|s| s.parse::().ok()) 15 | .expect("Failed to parse n_threads"); 16 | 17 | rayon::ThreadPoolBuilder::new() 18 | .num_threads(n_threads) 19 | .build_global() 20 | .unwrap(); 21 | } 22 | 23 | macro_rules! 
parse_trainer { 24 | ($m:ident; $( $v:ident ),+) => {{ 25 | let mut trainer = CraftmlTrainer::default(); 26 | $( 27 | if let Some($v) = $m.value_of(stringify!($v)) { 28 | trainer.$v = $v.parse().expect(&format!("Failed to parse {}", stringify!($v))); 29 | } 30 | )* 31 | trainer 32 | }}; 33 | } 34 | 35 | fn maybe_write_predictions_file(arg_matches: &clap::ArgMatches, predictions: &[Vec<(Label, f32)>]) { 36 | if let Some(out_path) = arg_matches.value_of("out_path") { 37 | let k_top = arg_matches 38 | .value_of("k_top") 39 | .and_then(|s| s.parse::().ok()) 40 | .expect("Failed to parse k_top"); 41 | 42 | let mut writer = 43 | BufWriter::new(File::create(out_path).expect("Failed to create output file")); 44 | for prediction in predictions { 45 | for (i, &(ref label, score)) in prediction.iter().take(k_top).enumerate() { 46 | if i > 0 { 47 | write!(&mut writer, "\t"); 48 | } 49 | write!(&mut writer, "{} {:.3}", label, score); 50 | } 51 | writeln!(&mut writer); 52 | } 53 | } 54 | } 55 | 56 | fn train(arg_matches: &clap::ArgMatches) { 57 | set_num_threads(&arg_matches); 58 | let trainer = parse_trainer!(arg_matches; 59 | n_trees, n_feature_buckets, n_label_buckets, leaf_max_size, k_clusters, 60 | cluster_sample_size, n_cluster_iters, centroid_preserve_ratio, centroid_min_n_preserve); 61 | 62 | let training_path = arg_matches.value_of("training_data").unwrap(); 63 | let training_dataset = 64 | DataSet::load_xc_repo_data_file(training_path).expect("Failed to load training data"); 65 | 66 | if let Some(cv_splits_path) = arg_matches.value_of("cv_splits_path") { 67 | let data_splits = DataSplits::parse_xc_repo_data_split_file(cv_splits_path) 68 | .expect("Failed to load splits"); 69 | eval::cross_validate(&training_dataset, &data_splits, &trainer); 70 | } else if arg_matches.is_present("test_trees_singly") { 71 | let test_path = arg_matches.value_of("test_data").unwrap(); 72 | let test_dataset = 73 | DataSet::load_xc_repo_data_file(test_path).expect("Failed to load test 
data"); 74 | 75 | let (predictions, _) = eval::test_trees_singly(&training_dataset, &test_dataset, &trainer); 76 | maybe_write_predictions_file(arg_matches, &predictions); 77 | } else { 78 | let model = trainer.train(&training_dataset); 79 | 80 | if let Some(model_path) = arg_matches.value_of("model_path") { 81 | let model_file = File::create(model_path).expect("Failed to create model file"); 82 | model 83 | .save(BufWriter::new(model_file)) 84 | .expect("Failed to save model"); 85 | } 86 | 87 | if let Some(test_path) = arg_matches.value_of("test_data") { 88 | let test_dataset = 89 | DataSet::load_xc_repo_data_file(test_path).expect("Failed to load test data"); 90 | let (predictions, _) = eval::test_all(&model, &test_dataset); 91 | maybe_write_predictions_file(arg_matches, &predictions); 92 | } 93 | } 94 | } 95 | 96 | fn test(arg_matches: &clap::ArgMatches) { 97 | set_num_threads(&arg_matches); 98 | let model_path = arg_matches.value_of("model_path").unwrap(); 99 | let model_file = File::open(model_path).expect("Failed to open model file"); 100 | let model = CraftmlModel::load(BufReader::new(model_file)).expect("Failed to load model"); 101 | 102 | let test_path = arg_matches.value_of("test_data").unwrap(); 103 | let test_dataset = 104 | DataSet::load_xc_repo_data_file(test_path).expect("Failed to load test data"); 105 | let (predictions, _) = eval::test_all(&model, &test_dataset); 106 | maybe_write_predictions_file(arg_matches, &predictions); 107 | } 108 | 109 | fn main() { 110 | simple_logger::init().unwrap(); 111 | 112 | let yaml = load_yaml!("cli.yml"); 113 | let arg_matches = clap::App::from_yaml(yaml).get_matches(); 114 | 115 | if let Some(arg_matches) = arg_matches.subcommand_matches("train") { 116 | train(&arg_matches); 117 | } else if let Some(arg_matches) = arg_matches.subcommand_matches("test") { 118 | test(&arg_matches); 119 | } else { 120 | println!("{}", arg_matches.usage()); 121 | } 122 | } 123 | 
-------------------------------------------------------------------------------- /src/data/mod.rs: -------------------------------------------------------------------------------- 1 | use pbr::ProgressBar; 2 | use std::collections::HashSet; 3 | use std::fs::File; 4 | use std::io::prelude::*; 5 | use std::io::{BufReader, Error, ErrorKind, Result}; 6 | use time; 7 | 8 | mod sparse_vector; 9 | pub use self::sparse_vector::SparseVector; 10 | 11 | pub type Feature = u32; 12 | 13 | pub type Label = u32; 14 | 15 | #[derive(Clone, Debug, PartialEq)] 16 | pub struct Example { 17 | pub features: Vec<(Feature, f32)>, 18 | pub labels: Vec