├── .github └── workflows │ └── doc_deploy.yml ├── .gitignore ├── .pypirc ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── book.toml ├── book ├── .nojekyll ├── 404.html ├── FontAwesome │ ├── css │ │ └── font-awesome.css │ └── fonts │ │ ├── FontAwesome.ttf │ │ ├── fontawesome-webfont.eot │ │ ├── fontawesome-webfont.svg │ │ ├── fontawesome-webfont.ttf │ │ ├── fontawesome-webfont.woff │ │ └── fontawesome-webfont.woff2 ├── augmentations │ ├── cutmix.html │ ├── dsa.html │ ├── mixup.html │ └── overview.html ├── ayu-highlight.css ├── book.js ├── clipboard.min.js ├── config │ └── overview.html ├── contributing.html ├── css │ ├── chrome.css │ ├── general.css │ ├── print.css │ └── variables.css ├── datasets │ └── overview.html ├── elasticlunr.min.js ├── favicon.png ├── favicon.svg ├── fonts │ ├── OPEN-SANS-LICENSE.txt │ ├── SOURCE-CODE-PRO-LICENSE.txt │ ├── fonts.css │ ├── open-sans-v17-all-charsets-300.woff2 │ ├── open-sans-v17-all-charsets-300italic.woff2 │ ├── open-sans-v17-all-charsets-600.woff2 │ ├── open-sans-v17-all-charsets-600italic.woff2 │ ├── open-sans-v17-all-charsets-700.woff2 │ ├── open-sans-v17-all-charsets-700italic.woff2 │ ├── open-sans-v17-all-charsets-800.woff2 │ ├── open-sans-v17-all-charsets-800italic.woff2 │ ├── open-sans-v17-all-charsets-italic.woff2 │ ├── open-sans-v17-all-charsets-regular.woff2 │ └── source-code-pro-v11-all-charsets-500.woff2 ├── getting-started │ ├── installation.html │ └── quick-start.html ├── highlight.css ├── highlight.js ├── index.html ├── introduction.html ├── mark.min.js ├── metrics │ ├── ars.html │ ├── general.html │ ├── lrs-hard-label.html │ ├── lrs-soft-label.html │ └── overview.html ├── models │ ├── alexnet.html │ ├── convnet.html │ ├── lenet.html │ ├── mlp.html │ ├── overview.html │ ├── resnet.html │ └── vgg.html ├── print.html ├── searcher.js ├── searchindex.js ├── searchindex.json ├── static │ ├── configurations.png │ ├── history.png │ ├── logo.png │ └── team │ │ └── zekai.jpg ├── toc.html ├── toc.js └── tomorrow-night.css ├── configs ├── Demo_ARS.yaml ├── Demo_LRS_Hard_Label.yaml └── Demo_LRS_Soft_Label.yaml ├── ddranking ├── __init__.py ├── aug │ ├── __init__.py │ ├── cutmix.py │ ├── dsa.py │ ├── mixup.py │ └── zca.py ├── config │ ├── __init__.py │ └── user_config.py ├── loss │ ├── __init__.py │ ├── kl.py │ ├── mse_gt.py │ └── sce.py ├── metrics │ ├── __init__.py │ ├── aug_robust.py │ ├── general.py │ ├── hard_label.py │ └── soft_label.py └── utils │ ├── __init__.py │ ├── data.py │ ├── meter.py │ ├── misc.py │ ├── model.py │ ├── networks.py │ └── train_and_eval.py ├── demo_aug.py ├── demo_hard.py ├── demo_soft.py ├── dist ├── ddranking-0.2.0-py3-none-any.whl └── ddranking-0.2.0.tar.gz ├── doc ├── SUMMARY.md ├── augmentations │ ├── cutmix.md │ ├── dsa.md │ ├── mixup.md │ ├── overview.md │ └── torchvision.md ├── config │ └── overview.md ├── contributing.md ├── datasets │ └── overview.md ├── getting-started │ ├── installation.md │ └── quick-start.md ├── introduction.md ├── metrics │ ├── ars.md │ ├── general.md │ ├── lrs-hard-label.md │ ├── lrs-soft-label.md │ └── overview.md ├── models │ ├── alexnet.md │ ├── convnet.md │ ├── lenet.md │ ├── mlp.md │ ├── overview.md │ ├── resnet.md │ └── vgg.md └── static │ ├── configurations.png │ ├── history.png │ ├── logo.png │ └── team │ └── zekai.jpg ├── index.html ├── pyproject.toml ├── setup.py └── static ├── configurations.png ├── history.png └── logo.png /.github/workflows/doc_deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy 2 | 
on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | contents: write # To push a branch 12 | pages: write # To push to a GitHub Pages site 13 | id-token: write # To update the deployment status 14 | steps: 15 | - uses: actions/checkout@v4 16 | with: 17 | fetch-depth: 0 18 | - name: Install latest mdbook 19 | run: | 20 | tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name') 21 | url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz" 22 | mkdir mdbook 23 | curl -sSL $url | tar -xz --directory=./mdbook 24 | echo `pwd`/mdbook >> $GITHUB_PATH 25 | - name: Build Book 26 | run: | 27 | # This assumes your book is in the root of your repository. 28 | # Just add a `cd` here if you need to change to another directory. 29 | mdbook build 30 | - name: Setup Pages 31 | uses: actions/configure-pages@v4 32 | - name: Upload artifact 33 | uses: actions/upload-pages-artifact@v3 34 | with: 35 | # Upload entire repository 36 | path: 'book' 37 | - name: Deploy to GitHub Pages 38 | id: deployment 39 | uses: actions/deploy-pages@v4 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg 2 | *.eggs 3 | *.egg-info 4 | *.eggs/ 5 | egg-info/* 6 | dd_ranking.egg-info/* 7 | PKG-INFO 8 | build/ 9 | *.pyc 10 | *.txt -------------------------------------------------------------------------------- /.pypirc: -------------------------------------------------------------------------------- 1 | [distutils] 2 | index-servers =pypi 3 | 4 | [pypi] 5 | repository:https://upload.pypi.org/legacy/ 6 | username:zekai_li 7 | password:Lzk@Cwq020304 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Welcome! We are glad that you are willing to contribute to the field of dataset distillation. 2 | 3 | - **New Baselines**: If you would like to report new baselines, please submit them by creating a pull request. Please use the following format: name of the baseline, code link, paper link, and the score produced by this tool (see the example at the end of this file). 4 | 5 | - **New Components**: If you would like to integrate new components, such as new model architectures, new data augmentation methods, and new soft label strategies, please submit them by creating a pull request. 6 | 7 | - **Issues**: If you run into problems, you are encouraged to report them directly in the issues tab. 8 | 9 | - **Appeal**: If you want to appeal the score of your method, please submit an issue with your code and a detailed readme file explaining how to reproduce your results. We have tried our best to replicate all methods in the leaderboard based on their papers and open-source code. We are sorry if we missed any details and will be grateful if you can help us improve the leaderboard.
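For reference, a new-baseline entry could look like the following (the method name and links are placeholders for illustration, not a real submission):
- **YourMethod**: code: `<link to code>`, paper: `<link to paper>`, score: `<DD-Ranking score produced by this tool>`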
10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Data Intelligence Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | This project incorporates code from the following repositories: 24 | 25 | 1. [DatasetCondensation](https://github.com/VICO-UoE/DatasetCondensation) 26 | License: MIT 27 | 28 | 2. [mtt-distillation](https://github.com/GeorgeCazenavette/mtt-distillation) 29 | License: MIT 30 | 31 | 3. [SRe2L](https://github.com/VILA-Lab/SRe2L/tree/main/SRe2L) 32 | License: MIT 33 | 34 | 4. [RDED](https://github.com/LINs-lab/RDED) 35 | License: Apache License 2.0 36 | 37 | The licenses of the above repositories are compatible with the license of this project. 38 | 39 | Copyright (c) 2021 The Python Packaging Authority 40 | 41 | Permission is hereby granted, free of charge, to any person obtaining a copy 42 | of this software and associated documentation files (the "Software"), to deal 43 | in the Software without restriction, including without limitation the rights 44 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 45 | copies of the Software, and to permit persons to whom the Software is 46 | furnished to do so, subject to the following conditions: 47 | 48 | The above copyright notice and this permission notice shall be included in all 49 | copies or substantial portions of the Software. 50 | 51 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 52 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 53 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 54 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 55 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 56 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 57 | SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | recursive-include config * 4 | recursive-include dd_ranking * 5 | recursive-exclude __pycache__ * 6 | recursive-exclude *.log * -------------------------------------------------------------------------------- /book.toml: -------------------------------------------------------------------------------- 1 | # Documentation 2 | # * mdbook https://rust-lang.github.io/mdBook/ 3 | # * template https://github.com/kg4zow/mdbook-template/ 4 | 5 | [book] 6 | authors = ["DD-Ranking Team"] 7 | language = "en" 8 | multilingual = false 9 | src = "doc" 10 | title = "DD-Ranking API Documentation" 11 | 12 | [output.html] 13 | mathjax-support = true -------------------------------------------------------------------------------- /book/.nojekyll: -------------------------------------------------------------------------------- 1 | This file makes sure that Github Pages doesn't process mdBook's output. 2 | -------------------------------------------------------------------------------- /book/404.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Page not found - DD-Ranking API Documentation 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 42 | 43 | 44 | 45 | 46 |
47 | 48 | 62 | 63 | 64 | 73 | 74 | 75 | 76 | 77 | 90 | 91 | 101 | 102 |
103 | 104 |
105 | 106 | 135 | 136 | 146 | 147 | 148 | 155 | 156 |
157 |
158 |

Document not found (404)

159 |

This URL is invalid, sorry. Please use the navigation bar or search to continue.

160 | 161 |
162 | 163 | 169 |
170 |
171 | 172 | 175 | 176 |
177 | 178 | 179 | 194 | 195 | 196 | 197 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 |
214 | 215 | 216 | -------------------------------------------------------------------------------- /book/FontAwesome/fonts/FontAwesome.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/FontAwesome/fonts/FontAwesome.ttf -------------------------------------------------------------------------------- /book/FontAwesome/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/FontAwesome/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /book/FontAwesome/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/FontAwesome/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /book/FontAwesome/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/FontAwesome/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /book/FontAwesome/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/FontAwesome/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /book/ayu-highlight.css: -------------------------------------------------------------------------------- 1 | /* 2 | Based off of the Ayu theme 3 | Original by Dempfi (https://github.com/dempfi/ayu) 4 | */ 5 | 6 | .hljs { 7 | display: block; 8 | overflow-x: auto; 9 | background: #191f26; 10 | color: #e6e1cf; 11 | } 12 | 13 | .hljs-comment, 14 | .hljs-quote { 15 | color: #5c6773; 16 | font-style: italic; 17 | } 18 | 19 | .hljs-variable, 20 | .hljs-template-variable, 21 | .hljs-attribute, 22 | .hljs-attr, 23 | .hljs-regexp, 24 | .hljs-link, 25 | .hljs-selector-id, 26 | .hljs-selector-class { 27 | color: #ff7733; 28 | } 29 | 30 | .hljs-number, 31 | .hljs-meta, 32 | .hljs-builtin-name, 33 | .hljs-literal, 34 | .hljs-type, 35 | .hljs-params { 36 | color: #ffee99; 37 | } 38 | 39 | .hljs-string, 40 | .hljs-bullet { 41 | color: #b8cc52; 42 | } 43 | 44 | .hljs-title, 45 | .hljs-built_in, 46 | .hljs-section { 47 | color: #ffb454; 48 | } 49 | 50 | .hljs-keyword, 51 | .hljs-selector-tag, 52 | .hljs-symbol { 53 | color: #ff7733; 54 | } 55 | 56 | .hljs-name { 57 | color: #36a3d9; 58 | } 59 | 60 | .hljs-tag { 61 | color: #00568d; 62 | } 63 | 64 | .hljs-emphasis { 65 | font-style: italic; 66 | } 67 | 68 | .hljs-strong { 69 | font-weight: bold; 70 | } 71 | 72 | .hljs-addition { 73 | color: #91b362; 74 | } 75 | 76 | .hljs-deletion { 77 | color: #d96c75; 78 | } 79 | -------------------------------------------------------------------------------- /book/clipboard.min.js: -------------------------------------------------------------------------------- 1 | /*! 
2 | * clipboard.js v2.0.4 3 | * https://zenorocha.github.io/clipboard.js 4 | * 5 | * Licensed MIT © Zeno Rocha 6 | */ 7 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.ClipboardJS=e():t.ClipboardJS=e()}(this,function(){return function(n){var o={};function r(t){if(o[t])return o[t].exports;var e=o[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,r),e.l=!0,e.exports}return r.m=n,r.c=o,r.d=function(t,e,n){r.o(t,e)||Object.defineProperty(t,e,{enumerable:!0,get:n})},r.r=function(t){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(t,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(t,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return r.d(e,"a",e),e},r.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},r.p="",r(r.s=0)}([function(t,e,n){"use strict";var r="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol&&t!==Symbol.prototype?"symbol":typeof t},i=function(){function o(t,e){for(var n=0;n .buttons { 30 | z-index: 2; 31 | } 32 | 33 | a, a:visited, a:active, a:hover { 34 | color: #4183c4; 35 | text-decoration: none; 36 | } 37 | 38 | h1, h2, h3, h4, h5, h6 { 39 | page-break-inside: avoid; 40 | page-break-after: avoid; 41 | } 42 | 43 | pre, code { 44 | page-break-inside: avoid; 45 | white-space: pre-wrap; 46 | } 47 | 48 | .fa { 49 | display: none !important; 50 | } 51 | -------------------------------------------------------------------------------- /book/css/variables.css: -------------------------------------------------------------------------------- 1 | 2 | /* Globals */ 3 | 4 | :root { 5 | --sidebar-width: 300px; 6 | --sidebar-resize-indicator-width: 8px; 7 | --sidebar-resize-indicator-space: 2px; 8 | --page-padding: 15px; 9 | --content-max-width: 750px; 10 | --menu-bar-height: 50px; 11 | --mono-font: "Source Code Pro", Consolas, "Ubuntu Mono", Menlo, "DejaVu Sans Mono", monospace, monospace; 12 | --code-font-size: 0.875em /* please adjust the ace font size accordingly in editor.js */ 13 | } 14 | 15 | /* Themes */ 16 | 17 | .ayu { 18 | --bg: hsl(210, 25%, 8%); 19 | --fg: #c5c5c5; 20 | 21 | --sidebar-bg: #14191f; 22 | --sidebar-fg: #c8c9db; 23 | --sidebar-non-existant: #5c6773; 24 | --sidebar-active: #ffb454; 25 | --sidebar-spacer: #2d334f; 26 | 27 | --scrollbar: var(--sidebar-fg); 28 | 29 | --icons: #737480; 30 | --icons-hover: #b7b9cc; 31 | 32 | --links: #0096cf; 33 | 34 | --inline-code-color: #ffb454; 35 | 36 | --theme-popup-bg: #14191f; 37 | --theme-popup-border: #5c6773; 38 | --theme-hover: #191f26; 39 | 40 | --quote-bg: hsl(226, 15%, 17%); 41 | --quote-border: hsl(226, 15%, 22%); 42 | 43 | --warning-border: #ff8e00; 44 | 45 | --table-border-color: hsl(210, 25%, 13%); 46 | --table-header-bg: hsl(210, 25%, 28%); 47 | --table-alternate-bg: hsl(210, 25%, 11%); 48 | 49 | --searchbar-border-color: #848484; 50 | --searchbar-bg: #424242; 51 | --searchbar-fg: #fff; 52 | --searchbar-shadow-color: #d4c89f; 53 | --searchresults-header-fg: #666; 54 | 
--searchresults-border-color: #888; 55 | --searchresults-li-bg: #252932; 56 | --search-mark-bg: #e3b171; 57 | 58 | --color-scheme: dark; 59 | 60 | /* Same as `--icons` */ 61 | --copy-button-filter: invert(45%) sepia(6%) saturate(621%) hue-rotate(198deg) brightness(99%) contrast(85%); 62 | /* Same as `--sidebar-active` */ 63 | --copy-button-filter-hover: invert(68%) sepia(55%) saturate(531%) hue-rotate(341deg) brightness(104%) contrast(101%); 64 | } 65 | 66 | .coal { 67 | --bg: hsl(200, 7%, 8%); 68 | --fg: #98a3ad; 69 | 70 | --sidebar-bg: #292c2f; 71 | --sidebar-fg: #a1adb8; 72 | --sidebar-non-existant: #505254; 73 | --sidebar-active: #3473ad; 74 | --sidebar-spacer: #393939; 75 | 76 | --scrollbar: var(--sidebar-fg); 77 | 78 | --icons: #43484d; 79 | --icons-hover: #b3c0cc; 80 | 81 | --links: #2b79a2; 82 | 83 | --inline-code-color: #c5c8c6; 84 | 85 | --theme-popup-bg: #141617; 86 | --theme-popup-border: #43484d; 87 | --theme-hover: #1f2124; 88 | 89 | --quote-bg: hsl(234, 21%, 18%); 90 | --quote-border: hsl(234, 21%, 23%); 91 | 92 | --warning-border: #ff8e00; 93 | 94 | --table-border-color: hsl(200, 7%, 13%); 95 | --table-header-bg: hsl(200, 7%, 28%); 96 | --table-alternate-bg: hsl(200, 7%, 11%); 97 | 98 | --searchbar-border-color: #aaa; 99 | --searchbar-bg: #b7b7b7; 100 | --searchbar-fg: #000; 101 | --searchbar-shadow-color: #aaa; 102 | --searchresults-header-fg: #666; 103 | --searchresults-border-color: #98a3ad; 104 | --searchresults-li-bg: #2b2b2f; 105 | --search-mark-bg: #355c7d; 106 | 107 | --color-scheme: dark; 108 | 109 | /* Same as `--icons` */ 110 | --copy-button-filter: invert(26%) sepia(8%) saturate(575%) hue-rotate(169deg) brightness(87%) contrast(82%); 111 | /* Same as `--sidebar-active` */ 112 | --copy-button-filter-hover: invert(36%) sepia(70%) saturate(503%) hue-rotate(167deg) brightness(98%) contrast(89%); 113 | } 114 | 115 | .light, html:not(.js) { 116 | --bg: hsl(0, 0%, 100%); 117 | --fg: hsl(0, 0%, 0%); 118 | 119 | --sidebar-bg: #fafafa; 120 | --sidebar-fg: hsl(0, 0%, 0%); 121 | --sidebar-non-existant: #aaaaaa; 122 | --sidebar-active: #1f1fff; 123 | --sidebar-spacer: #f4f4f4; 124 | 125 | --scrollbar: #8F8F8F; 126 | 127 | --icons: #747474; 128 | --icons-hover: #000000; 129 | 130 | --links: #20609f; 131 | 132 | --inline-code-color: #301900; 133 | 134 | --theme-popup-bg: #fafafa; 135 | --theme-popup-border: #cccccc; 136 | --theme-hover: #e6e6e6; 137 | 138 | --quote-bg: hsl(197, 37%, 96%); 139 | --quote-border: hsl(197, 37%, 91%); 140 | 141 | --warning-border: #ff8e00; 142 | 143 | --table-border-color: hsl(0, 0%, 95%); 144 | --table-header-bg: hsl(0, 0%, 80%); 145 | --table-alternate-bg: hsl(0, 0%, 97%); 146 | 147 | --searchbar-border-color: #aaa; 148 | --searchbar-bg: #fafafa; 149 | --searchbar-fg: #000; 150 | --searchbar-shadow-color: #aaa; 151 | --searchresults-header-fg: #666; 152 | --searchresults-border-color: #888; 153 | --searchresults-li-bg: #e4f2fe; 154 | --search-mark-bg: #a2cff5; 155 | 156 | --color-scheme: light; 157 | 158 | /* Same as `--icons` */ 159 | --copy-button-filter: invert(45.49%); 160 | /* Same as `--sidebar-active` */ 161 | --copy-button-filter-hover: invert(14%) sepia(93%) saturate(4250%) hue-rotate(243deg) brightness(99%) contrast(130%); 162 | } 163 | 164 | .navy { 165 | --bg: hsl(226, 23%, 11%); 166 | --fg: #bcbdd0; 167 | 168 | --sidebar-bg: #282d3f; 169 | --sidebar-fg: #c8c9db; 170 | --sidebar-non-existant: #505274; 171 | --sidebar-active: #2b79a2; 172 | --sidebar-spacer: #2d334f; 173 | 174 | --scrollbar: var(--sidebar-fg); 175 | 176 | --icons: 
#737480; 177 | --icons-hover: #b7b9cc; 178 | 179 | --links: #2b79a2; 180 | 181 | --inline-code-color: #c5c8c6; 182 | 183 | --theme-popup-bg: #161923; 184 | --theme-popup-border: #737480; 185 | --theme-hover: #282e40; 186 | 187 | --quote-bg: hsl(226, 15%, 17%); 188 | --quote-border: hsl(226, 15%, 22%); 189 | 190 | --warning-border: #ff8e00; 191 | 192 | --table-border-color: hsl(226, 23%, 16%); 193 | --table-header-bg: hsl(226, 23%, 31%); 194 | --table-alternate-bg: hsl(226, 23%, 14%); 195 | 196 | --searchbar-border-color: #aaa; 197 | --searchbar-bg: #aeaec6; 198 | --searchbar-fg: #000; 199 | --searchbar-shadow-color: #aaa; 200 | --searchresults-header-fg: #5f5f71; 201 | --searchresults-border-color: #5c5c68; 202 | --searchresults-li-bg: #242430; 203 | --search-mark-bg: #a2cff5; 204 | 205 | --color-scheme: dark; 206 | 207 | /* Same as `--icons` */ 208 | --copy-button-filter: invert(51%) sepia(10%) saturate(393%) hue-rotate(198deg) brightness(86%) contrast(87%); 209 | /* Same as `--sidebar-active` */ 210 | --copy-button-filter-hover: invert(46%) sepia(20%) saturate(1537%) hue-rotate(156deg) brightness(85%) contrast(90%); 211 | } 212 | 213 | .rust { 214 | --bg: hsl(60, 9%, 87%); 215 | --fg: #262625; 216 | 217 | --sidebar-bg: #3b2e2a; 218 | --sidebar-fg: #c8c9db; 219 | --sidebar-non-existant: #505254; 220 | --sidebar-active: #e69f67; 221 | --sidebar-spacer: #45373a; 222 | 223 | --scrollbar: var(--sidebar-fg); 224 | 225 | --icons: #737480; 226 | --icons-hover: #262625; 227 | 228 | --links: #2b79a2; 229 | 230 | --inline-code-color: #6e6b5e; 231 | 232 | --theme-popup-bg: #e1e1db; 233 | --theme-popup-border: #b38f6b; 234 | --theme-hover: #99908a; 235 | 236 | --quote-bg: hsl(60, 5%, 75%); 237 | --quote-border: hsl(60, 5%, 70%); 238 | 239 | --warning-border: #ff8e00; 240 | 241 | --table-border-color: hsl(60, 9%, 82%); 242 | --table-header-bg: #b3a497; 243 | --table-alternate-bg: hsl(60, 9%, 84%); 244 | 245 | --searchbar-border-color: #aaa; 246 | --searchbar-bg: #fafafa; 247 | --searchbar-fg: #000; 248 | --searchbar-shadow-color: #aaa; 249 | --searchresults-header-fg: #666; 250 | --searchresults-border-color: #888; 251 | --searchresults-li-bg: #dec2a2; 252 | --search-mark-bg: #e69f67; 253 | 254 | /* Same as `--icons` */ 255 | --copy-button-filter: invert(51%) sepia(10%) saturate(393%) hue-rotate(198deg) brightness(86%) contrast(87%); 256 | /* Same as `--sidebar-active` */ 257 | --copy-button-filter-hover: invert(77%) sepia(16%) saturate(1798%) hue-rotate(328deg) brightness(98%) contrast(83%); 258 | } 259 | 260 | @media (prefers-color-scheme: dark) { 261 | html:not(.js) { 262 | --bg: hsl(200, 7%, 8%); 263 | --fg: #98a3ad; 264 | 265 | --sidebar-bg: #292c2f; 266 | --sidebar-fg: #a1adb8; 267 | --sidebar-non-existant: #505254; 268 | --sidebar-active: #3473ad; 269 | --sidebar-spacer: #393939; 270 | 271 | --scrollbar: var(--sidebar-fg); 272 | 273 | --icons: #43484d; 274 | --icons-hover: #b3c0cc; 275 | 276 | --links: #2b79a2; 277 | 278 | --inline-code-color: #c5c8c6; 279 | 280 | --theme-popup-bg: #141617; 281 | --theme-popup-border: #43484d; 282 | --theme-hover: #1f2124; 283 | 284 | --quote-bg: hsl(234, 21%, 18%); 285 | --quote-border: hsl(234, 21%, 23%); 286 | 287 | --warning-border: #ff8e00; 288 | 289 | --table-border-color: hsl(200, 7%, 13%); 290 | --table-header-bg: hsl(200, 7%, 28%); 291 | --table-alternate-bg: hsl(200, 7%, 11%); 292 | 293 | --searchbar-border-color: #aaa; 294 | --searchbar-bg: #b7b7b7; 295 | --searchbar-fg: #000; 296 | --searchbar-shadow-color: #aaa; 297 | --searchresults-header-fg: 
#666; 298 | --searchresults-border-color: #98a3ad; 299 | --searchresults-li-bg: #2b2b2f; 300 | --search-mark-bg: #355c7d; 301 | 302 | --color-scheme: dark; 303 | 304 | /* Same as `--icons` */ 305 | --copy-button-filter: invert(26%) sepia(8%) saturate(575%) hue-rotate(169deg) brightness(87%) contrast(82%); 306 | /* Same as `--sidebar-active` */ 307 | --copy-button-filter-hover: invert(36%) sepia(70%) saturate(503%) hue-rotate(167deg) brightness(98%) contrast(89%); 308 | } 309 | } 310 | -------------------------------------------------------------------------------- /book/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/favicon.png -------------------------------------------------------------------------------- /book/favicon.svg: -------------------------------------------------------------------------------- 1 | 2 | 7 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /book/fonts/SOURCE-CODE-PRO-LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2010, 2012 Adobe Systems Incorporated (http://www.adobe.com/), with Reserved Font Name 'Source'. All Rights Reserved. Source is a trademark of Adobe Systems Incorporated in the United States and/or other countries. 2 | 3 | This Font Software is licensed under the SIL Open Font License, Version 1.1. 4 | This license is copied below, and is also available with a FAQ at: 5 | http://scripts.sil.org/OFL 6 | 7 | 8 | ----------------------------------------------------------- 9 | SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007 10 | ----------------------------------------------------------- 11 | 12 | PREAMBLE 13 | The goals of the Open Font License (OFL) are to stimulate worldwide 14 | development of collaborative font projects, to support the font creation 15 | efforts of academic and linguistic communities, and to provide a free and 16 | open framework in which fonts may be shared and improved in partnership 17 | with others. 18 | 19 | The OFL allows the licensed fonts to be used, studied, modified and 20 | redistributed freely as long as they are not sold by themselves. The 21 | fonts, including any derivative works, can be bundled, embedded, 22 | redistributed and/or sold with any software provided that any reserved 23 | names are not used by derivative works. The fonts and derivatives, 24 | however, cannot be released under any other type of license. The 25 | requirement for fonts to remain under this license does not apply 26 | to any document created using the fonts or their derivatives. 27 | 28 | DEFINITIONS 29 | "Font Software" refers to the set of files released by the Copyright 30 | Holder(s) under this license and clearly marked as such. This may 31 | include source files, build scripts and documentation. 32 | 33 | "Reserved Font Name" refers to any names specified as such after the 34 | copyright statement(s). 35 | 36 | "Original Version" refers to the collection of Font Software components as 37 | distributed by the Copyright Holder(s). 38 | 39 | "Modified Version" refers to any derivative made by adding to, deleting, 40 | or substituting -- in part or in whole -- any of the components of the 41 | Original Version, by changing formats or by porting the Font Software to a 42 | new environment. 
43 | 44 | "Author" refers to any designer, engineer, programmer, technical 45 | writer or other person who contributed to the Font Software. 46 | 47 | PERMISSION & CONDITIONS 48 | Permission is hereby granted, free of charge, to any person obtaining 49 | a copy of the Font Software, to use, study, copy, merge, embed, modify, 50 | redistribute, and sell modified and unmodified copies of the Font 51 | Software, subject to the following conditions: 52 | 53 | 1) Neither the Font Software nor any of its individual components, 54 | in Original or Modified Versions, may be sold by itself. 55 | 56 | 2) Original or Modified Versions of the Font Software may be bundled, 57 | redistributed and/or sold with any software, provided that each copy 58 | contains the above copyright notice and this license. These can be 59 | included either as stand-alone text files, human-readable headers or 60 | in the appropriate machine-readable metadata fields within text or 61 | binary files as long as those fields can be easily viewed by the user. 62 | 63 | 3) No Modified Version of the Font Software may use the Reserved Font 64 | Name(s) unless explicit written permission is granted by the corresponding 65 | Copyright Holder. This restriction only applies to the primary font name as 66 | presented to the users. 67 | 68 | 4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font 69 | Software shall not be used to promote, endorse or advertise any 70 | Modified Version, except to acknowledge the contribution(s) of the 71 | Copyright Holder(s) and the Author(s) or with their explicit written 72 | permission. 73 | 74 | 5) The Font Software, modified or unmodified, in part or in whole, 75 | must be distributed entirely under this license, and must not be 76 | distributed under any other license. The requirement for fonts to 77 | remain under this license does not apply to any document created 78 | using the Font Software. 79 | 80 | TERMINATION 81 | This license becomes null and void if any of the above conditions are 82 | not met. 83 | 84 | DISCLAIMER 85 | THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 86 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF 87 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT 88 | OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE 89 | COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 90 | INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL 91 | DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 92 | FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM 93 | OTHER DEALINGS IN THE FONT SOFTWARE. 94 | -------------------------------------------------------------------------------- /book/fonts/fonts.css: -------------------------------------------------------------------------------- 1 | /* Open Sans is licensed under the Apache License, Version 2.0. See http://www.apache.org/licenses/LICENSE-2.0 */ 2 | /* Source Code Pro is under the Open Font License. 
See https://scripts.sil.org/cms/scripts/page.php?site_id=nrsi&id=OFL */ 3 | 4 | /* open-sans-300 - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 5 | @font-face { 6 | font-family: 'Open Sans'; 7 | font-style: normal; 8 | font-weight: 300; 9 | src: local('Open Sans Light'), local('OpenSans-Light'), 10 | url('open-sans-v17-all-charsets-300.woff2') format('woff2'); 11 | } 12 | 13 | /* open-sans-300italic - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 14 | @font-face { 15 | font-family: 'Open Sans'; 16 | font-style: italic; 17 | font-weight: 300; 18 | src: local('Open Sans Light Italic'), local('OpenSans-LightItalic'), 19 | url('open-sans-v17-all-charsets-300italic.woff2') format('woff2'); 20 | } 21 | 22 | /* open-sans-regular - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 23 | @font-face { 24 | font-family: 'Open Sans'; 25 | font-style: normal; 26 | font-weight: 400; 27 | src: local('Open Sans Regular'), local('OpenSans-Regular'), 28 | url('open-sans-v17-all-charsets-regular.woff2') format('woff2'); 29 | } 30 | 31 | /* open-sans-italic - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 32 | @font-face { 33 | font-family: 'Open Sans'; 34 | font-style: italic; 35 | font-weight: 400; 36 | src: local('Open Sans Italic'), local('OpenSans-Italic'), 37 | url('open-sans-v17-all-charsets-italic.woff2') format('woff2'); 38 | } 39 | 40 | /* open-sans-600 - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 41 | @font-face { 42 | font-family: 'Open Sans'; 43 | font-style: normal; 44 | font-weight: 600; 45 | src: local('Open Sans SemiBold'), local('OpenSans-SemiBold'), 46 | url('open-sans-v17-all-charsets-600.woff2') format('woff2'); 47 | } 48 | 49 | /* open-sans-600italic - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 50 | @font-face { 51 | font-family: 'Open Sans'; 52 | font-style: italic; 53 | font-weight: 600; 54 | src: local('Open Sans SemiBold Italic'), local('OpenSans-SemiBoldItalic'), 55 | url('open-sans-v17-all-charsets-600italic.woff2') format('woff2'); 56 | } 57 | 58 | /* open-sans-700 - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 59 | @font-face { 60 | font-family: 'Open Sans'; 61 | font-style: normal; 62 | font-weight: 700; 63 | src: local('Open Sans Bold'), local('OpenSans-Bold'), 64 | url('open-sans-v17-all-charsets-700.woff2') format('woff2'); 65 | } 66 | 67 | /* open-sans-700italic - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 68 | @font-face { 69 | font-family: 'Open Sans'; 70 | font-style: italic; 71 | font-weight: 700; 72 | src: local('Open Sans Bold Italic'), local('OpenSans-BoldItalic'), 73 | url('open-sans-v17-all-charsets-700italic.woff2') format('woff2'); 74 | } 75 | 76 | /* open-sans-800 - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 77 | @font-face { 78 | font-family: 'Open Sans'; 79 | font-style: normal; 80 | font-weight: 800; 81 | src: local('Open Sans ExtraBold'), local('OpenSans-ExtraBold'), 82 | url('open-sans-v17-all-charsets-800.woff2') format('woff2'); 83 | } 84 | 85 | /* open-sans-800italic - latin_vietnamese_latin-ext_greek-ext_greek_cyrillic-ext_cyrillic */ 86 | @font-face { 87 | font-family: 'Open Sans'; 88 | font-style: italic; 89 | font-weight: 800; 90 | src: local('Open Sans ExtraBold Italic'), local('OpenSans-ExtraBoldItalic'), 91 | url('open-sans-v17-all-charsets-800italic.woff2') format('woff2'); 92 | } 93 | 94 | /* source-code-pro-500 - 
latin_vietnamese_latin-ext_greek_cyrillic-ext_cyrillic */ 95 | @font-face { 96 | font-family: 'Source Code Pro'; 97 | font-style: normal; 98 | font-weight: 500; 99 | src: url('source-code-pro-v11-all-charsets-500.woff2') format('woff2'); 100 | } 101 | -------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-300.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-300.woff2 -------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-300italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-300italic.woff2 -------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-600.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-600.woff2 -------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-600italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-600italic.woff2 -------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-700.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-700.woff2 -------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-700italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-700italic.woff2 -------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-800.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-800.woff2 -------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-800italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-800italic.woff2 -------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-italic.woff2 
-------------------------------------------------------------------------------- /book/fonts/open-sans-v17-all-charsets-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/open-sans-v17-all-charsets-regular.woff2 -------------------------------------------------------------------------------- /book/fonts/source-code-pro-v11-all-charsets-500.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/fonts/source-code-pro-v11-all-charsets-500.woff2 -------------------------------------------------------------------------------- /book/highlight.css: -------------------------------------------------------------------------------- 1 | /* 2 | * An increased contrast highlighting scheme loosely based on the 3 | * "Base16 Atelier Dune Light" theme by Bram de Haan 4 | * (http://atelierbram.github.io/syntax-highlighting/atelier-schemes/dune) 5 | * Original Base16 color scheme by Chris Kempson 6 | * (https://github.com/chriskempson/base16) 7 | */ 8 | 9 | /* Comment */ 10 | .hljs-comment, 11 | .hljs-quote { 12 | color: #575757; 13 | } 14 | 15 | /* Red */ 16 | .hljs-variable, 17 | .hljs-template-variable, 18 | .hljs-attribute, 19 | .hljs-attr, 20 | .hljs-tag, 21 | .hljs-name, 22 | .hljs-regexp, 23 | .hljs-link, 24 | .hljs-name, 25 | .hljs-selector-id, 26 | .hljs-selector-class { 27 | color: #d70025; 28 | } 29 | 30 | /* Orange */ 31 | .hljs-number, 32 | .hljs-meta, 33 | .hljs-built_in, 34 | .hljs-builtin-name, 35 | .hljs-literal, 36 | .hljs-type, 37 | .hljs-params { 38 | color: #b21e00; 39 | } 40 | 41 | /* Green */ 42 | .hljs-string, 43 | .hljs-symbol, 44 | .hljs-bullet { 45 | color: #008200; 46 | } 47 | 48 | /* Blue */ 49 | .hljs-title, 50 | .hljs-section { 51 | color: #0030f2; 52 | } 53 | 54 | /* Purple */ 55 | .hljs-keyword, 56 | .hljs-selector-tag { 57 | color: #9d00ec; 58 | } 59 | 60 | .hljs { 61 | display: block; 62 | overflow-x: auto; 63 | background: #f6f7f6; 64 | color: #000; 65 | } 66 | 67 | .hljs-emphasis { 68 | font-style: italic; 69 | } 70 | 71 | .hljs-strong { 72 | font-weight: bold; 73 | } 74 | 75 | .hljs-addition { 76 | color: #22863a; 77 | background-color: #f0fff4; 78 | } 79 | 80 | .hljs-deletion { 81 | color: #b31d28; 82 | background-color: #ffeef0; 83 | } 84 | -------------------------------------------------------------------------------- /book/static/configurations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/static/configurations.png -------------------------------------------------------------------------------- /book/static/history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/static/history.png -------------------------------------------------------------------------------- /book/static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/static/logo.png -------------------------------------------------------------------------------- /book/static/team/zekai.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/book/static/team/zekai.jpg -------------------------------------------------------------------------------- /book/toc.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 |
  1. Introduction
  2. Contributing
  3. Getting Started
  4. Installation
  5. Quick Start
  6. Reference Guide
  7. Metrics
    1. LabelRobustScoreHard
    2. LabelRobustScoreSoft
    3. AugmentationRobustScore
    4. GeneralEvaluator
  8. Augmentations
    1. DSA
    2. CutMix
    3. Mixup
  9. Models
    1. ConvNet
    2. AlexNet
    3. ResNet
    4. LeNet
    5. VGG
    6. MLP
  10. Datasets
  11. Config
31 | 32 | 33 | -------------------------------------------------------------------------------- /book/toc.js: -------------------------------------------------------------------------------- 1 | // Populate the sidebar 2 | // 3 | // This is a script, and not included directly in the page, to control the total size of the book. 4 | // The TOC contains an entry for each page, so if each page includes a copy of the TOC, 5 | // the total size of the page becomes O(n**2). 6 | class MDBookSidebarScrollbox extends HTMLElement { 7 | constructor() { 8 | super(); 9 | } 10 | connectedCallback() { 11 | this.innerHTML = '
  1. Introduction
  2. Contributing
  3. Getting Started
  4. Installation
  5. Quick Start
  6. Reference Guide
  7. Metrics
    1. LabelRobustScoreHard
    2. LabelRobustScoreSoft
    3. AugmentationRobustScore
    4. GeneralEvaluator
  8. Augmentations
    1. DSA
    2. CutMix
    3. Mixup
  9. Models
    1. ConvNet
    2. AlexNet
    3. ResNet
    4. LeNet
    5. VGG
    6. MLP
  10. Datasets
  11. Config
'; 12 | // Set the current, active page, and reveal it if it's hidden 13 | let current_page = document.location.href.toString(); 14 | if (current_page.endsWith("/")) { 15 | current_page += "index.html"; 16 | } 17 | var links = Array.prototype.slice.call(this.querySelectorAll("a")); 18 | var l = links.length; 19 | for (var i = 0; i < l; ++i) { 20 | var link = links[i]; 21 | var href = link.getAttribute("href"); 22 | if (href && !href.startsWith("#") && !/^(?:[a-z+]+:)?\/\//.test(href)) { 23 | link.href = path_to_root + href; 24 | } 25 | // The "index" page is supposed to alias the first chapter in the book. 26 | if (link.href === current_page || (i === 0 && path_to_root === "" && current_page.endsWith("/index.html"))) { 27 | link.classList.add("active"); 28 | var parent = link.parentElement; 29 | if (parent && parent.classList.contains("chapter-item")) { 30 | parent.classList.add("expanded"); 31 | } 32 | while (parent) { 33 | if (parent.tagName === "LI" && parent.previousElementSibling) { 34 | if (parent.previousElementSibling.classList.contains("chapter-item")) { 35 | parent.previousElementSibling.classList.add("expanded"); 36 | } 37 | } 38 | parent = parent.parentElement; 39 | } 40 | } 41 | } 42 | // Track and set sidebar scroll position 43 | this.addEventListener('click', function(e) { 44 | if (e.target.tagName === 'A') { 45 | sessionStorage.setItem('sidebar-scroll', this.scrollTop); 46 | } 47 | }, { passive: true }); 48 | var sidebarScrollTop = sessionStorage.getItem('sidebar-scroll'); 49 | sessionStorage.removeItem('sidebar-scroll'); 50 | if (sidebarScrollTop) { 51 | // preserve sidebar scroll position when navigating via links within sidebar 52 | this.scrollTop = sidebarScrollTop; 53 | } else { 54 | // scroll sidebar to current active section when navigating via "next/previous chapter" buttons 55 | var activeSection = document.querySelector('#sidebar .active'); 56 | if (activeSection) { 57 | activeSection.scrollIntoView({ block: 'center' }); 58 | } 59 | } 60 | // Toggle buttons 61 | var sidebarAnchorToggles = document.querySelectorAll('#sidebar a.toggle'); 62 | function toggleSection(ev) { 63 | ev.currentTarget.parentElement.classList.toggle('expanded'); 64 | } 65 | Array.from(sidebarAnchorToggles).forEach(function (el) { 66 | el.addEventListener('click', toggleSection); 67 | }); 68 | } 69 | } 70 | window.customElements.define("mdbook-sidebar-scrollbox", MDBookSidebarScrollbox); 71 | -------------------------------------------------------------------------------- /book/tomorrow-night.css: -------------------------------------------------------------------------------- 1 | /* Tomorrow Night Theme */ 2 | /* https://github.com/jmblog/color-themes-for-highlightjs */ 3 | /* Original theme - https://github.com/chriskempson/tomorrow-theme */ 4 | /* https://github.com/jmblog/color-themes-for-highlightjs */ 5 | 6 | /* Tomorrow Comment */ 7 | .hljs-comment { 8 | color: #969896; 9 | } 10 | 11 | /* Tomorrow Red */ 12 | .hljs-variable, 13 | .hljs-attribute, 14 | .hljs-attr, 15 | .hljs-tag, 16 | .hljs-regexp, 17 | .ruby .hljs-constant, 18 | .xml .hljs-tag .hljs-title, 19 | .xml .hljs-pi, 20 | .xml .hljs-doctype, 21 | .html .hljs-doctype, 22 | .css .hljs-id, 23 | .css .hljs-class, 24 | .css .hljs-pseudo { 25 | color: #cc6666; 26 | } 27 | 28 | /* Tomorrow Orange */ 29 | .hljs-number, 30 | .hljs-preprocessor, 31 | .hljs-pragma, 32 | .hljs-built_in, 33 | .hljs-literal, 34 | .hljs-params, 35 | .hljs-constant { 36 | color: #de935f; 37 | } 38 | 39 | /* Tomorrow Yellow */ 40 | .ruby .hljs-class 
.hljs-title, 41 | .css .hljs-rule .hljs-attribute { 42 | color: #f0c674; 43 | } 44 | 45 | /* Tomorrow Green */ 46 | .hljs-string, 47 | .hljs-value, 48 | .hljs-inheritance, 49 | .hljs-header, 50 | .hljs-name, 51 | .ruby .hljs-symbol, 52 | .xml .hljs-cdata { 53 | color: #b5bd68; 54 | } 55 | 56 | /* Tomorrow Aqua */ 57 | .hljs-title, 58 | .hljs-section, 59 | .css .hljs-hexcolor { 60 | color: #8abeb7; 61 | } 62 | 63 | /* Tomorrow Blue */ 64 | .hljs-function, 65 | .python .hljs-decorator, 66 | .python .hljs-title, 67 | .ruby .hljs-function .hljs-title, 68 | .ruby .hljs-title .hljs-keyword, 69 | .perl .hljs-sub, 70 | .javascript .hljs-title, 71 | .coffeescript .hljs-title { 72 | color: #81a2be; 73 | } 74 | 75 | /* Tomorrow Purple */ 76 | .hljs-keyword, 77 | .javascript .hljs-function { 78 | color: #b294bb; 79 | } 80 | 81 | .hljs { 82 | display: block; 83 | overflow-x: auto; 84 | background: #1d1f21; 85 | color: #c5c8c6; 86 | } 87 | 88 | .coffeescript .javascript, 89 | .javascript .xml, 90 | .tex .hljs-formula, 91 | .xml .javascript, 92 | .xml .vbscript, 93 | .xml .css, 94 | .xml .hljs-cdata { 95 | opacity: 0.5; 96 | } 97 | 98 | .hljs-addition { 99 | color: #718c00; 100 | } 101 | 102 | .hljs-deletion { 103 | color: #c82829; 104 | } 105 | -------------------------------------------------------------------------------- /configs/Demo_ARS.yaml: -------------------------------------------------------------------------------- 1 | # real data 2 | dataset: ImageNet1K 3 | real_data_path: ./dataset/ImageNet1K/ 4 | 5 | # synthetic data 6 | ipc: 10 7 | im_size: [224, 224] 8 | 9 | # agent model 10 | model_name: ResNet-18-BN 11 | stu_use_torchvision: true 12 | tea_use_torchvision: true 13 | teacher_dir: ./teacher_models 14 | teacher_model_names: [ResNet-18-BN] 15 | 16 | # synthetic data augmentation 17 | data_aug_func: cutmix 18 | aug_params: 19 | beta: 1.0 20 | use_zca: false 21 | 22 | custom_train_trans: 23 | - name: RandomResizedCrop 24 | args: 25 | size: 224 26 | scale: [0.08, 1.0] 27 | - name: RandomHorizontalFlip 28 | args: 29 | p: 0.5 30 | - name: ToTensor 31 | - name: Normalize 32 | args: 33 | mean: [0.485, 0.456, 0.406] 34 | std: [0.229, 0.224, 0.225] 35 | 36 | custom_val_trans: 37 | - name: Resize 38 | args: 39 | size: 256 40 | - name: CenterCrop 41 | args: 42 | size: 224 43 | - name: ToTensor 44 | - name: Normalize 45 | args: 46 | mean: [0.485, 0.456, 0.406] 47 | std: [0.229, 0.224, 0.225] 48 | 49 | # soft label settings 50 | label_type: soft 51 | soft_label_mode: M 52 | soft_label_criterion: kl 53 | loss_fn_kwargs: 54 | temperature: 1.0 55 | scale_loss: false 56 | 57 | # training specifics 58 | optimizer: adamw 59 | lr_scheduler: cosine 60 | weight_decay: 0.01 61 | momentum: 0.9 62 | num_eval: 5 63 | num_epochs: 300 64 | num_workers: 4 65 | device: cuda 66 | dist: true 67 | batch_size: 1024 68 | random_data_path: ./random_data/my_method/ImageNet1K/IPC10/ 69 | random_data_format: image 70 | 71 | # save path 72 | save_path: ./my_method_ars_scores.csv 73 | -------------------------------------------------------------------------------- /configs/Demo_LRS_Hard_Label.yaml: -------------------------------------------------------------------------------- 1 | 2 | # real data 3 | dataset: CIFAR10 4 | real_data_path: ./dataset/ 5 | 6 | 7 | # synthetic data 8 | ipc: 10 9 | im_size: [32, 32] 10 | 11 | # agent model 12 | model_name: ConvNet-3 13 | use_torchvision: false 14 | 15 | # data augmentation 16 | data_aug_func: "dsa" 17 | aug_params: 18 | flip: 0.5 19 | rotate: 15.0 20 | saturation:
2.0 21 | brightness: 1.0 22 | contrast: 0.5 23 | scale: 1.2 24 | crop: 0.125 25 | cutout: 0.5 26 | use_zca: false 27 | 28 | custom_train_trans: null 29 | custom_val_trans: null 30 | 31 | # training specifics 32 | optimizer: sgd 33 | lr_scheduler: step 34 | step_size: 500 35 | weight_decay: 0.0005 36 | momentum: 0.9 37 | num_eval: 5 38 | num_epochs: 1000 39 | syn_batch_size: 128 40 | real_batch_size: 256 41 | default_lr: 0.01 42 | num_workers: 4 43 | device: cuda 44 | dist: true 45 | eval_full_data: false 46 | random_data_path: ./results/my_method_random_data.pt 47 | random_data_format: tensor 48 | 49 | # save path 50 | save_path: ./results/my_method_hard_label_scores.csv 51 | -------------------------------------------------------------------------------- /configs/Demo_LRS_Soft_Label.yaml: -------------------------------------------------------------------------------- 1 | # real data 2 | dataset: ImageNet1K 3 | real_data_path: ./dataset/ImageNet1K/ 4 | 5 | # synthetic data 6 | ipc: 10 7 | im_size: [224, 224] 8 | 9 | # agent model 10 | model_name: ResNet-18-BN 11 | stu_use_torchvision: true 12 | tea_use_torchvision: true 13 | teacher_dir: ./teacher_models 14 | teacher_model_names: [ResNet-18-BN] 15 | 16 | # synthetic data augmentation 17 | data_aug_func: cutmix 18 | aug_params: 19 | beta: 1.0 20 | use_zca: false 21 | 22 | custom_train_trans: 23 | - name: RandomResizedCrop 24 | args: 25 | size: 224 26 | scale: [0.08, 1.0] 27 | - name: RandomHorizontalFlip 28 | args: 29 | p: 0.5 30 | - name: ToTensor 31 | - name: Normalize 32 | args: 33 | mean: [0.485, 0.456, 0.406] 34 | std: [0.229, 0.224, 0.225] 35 | 36 | custom_val_trans: 37 | - name: Resize 38 | args: 39 | size: 256 40 | - name: CenterCrop 41 | args: 42 | size: 224 43 | - name: ToTensor 44 | - name: Normalize 45 | args: 46 | mean: [0.485, 0.456, 0.406] 47 | std: [0.229, 0.224, 0.225] 48 | 49 | use_aug_for_hard: false 50 | 51 | # soft label settings 52 | soft_label_mode: M 53 | soft_label_criterion: kl 54 | loss_fn_kwargs: 55 | temperature: 1.0 56 | scale_loss: false 57 | 58 | # training specifics 59 | optimizer: adamw 60 | lr_scheduler: cosine 61 | weight_decay: 0.01 62 | momentum: 0.9 63 | num_eval: 5 64 | num_epochs: 300 65 | num_workers: 4 66 | device: cuda 67 | dist: true 68 | eval_full_data: false 69 | syn_batch_size: 1024 70 | real_batch_size: 1024 71 | random_data_path: ./random_data/my_method/ImageNet1K/IPC10/ 72 | random_data_format: image 73 | 74 | # save path 75 | save_path: ./my_method_soft_label_scores.csv 76 | -------------------------------------------------------------------------------- /ddranking/__init__.py: -------------------------------------------------------------------------------- 1 | from .aug import DSA, Mixup, Cutmix, ZCAWhitening 2 | from .config import Config 3 | from .loss import KLDivergenceLoss, SoftCrossEntropyLoss, MSEGTLoss 4 | from .metrics import LabelRobustScoreHard, LabelRobustScoreSoft, AugmentationRobustScore, GeneralEvaluator 5 | from .utils import get_dataset, build_model, get_convnet, get_lenet, get_resnet, get_vgg, get_alexnet 6 | -------------------------------------------------------------------------------- /ddranking/aug/__init__.py: -------------------------------------------------------------------------------- 1 | from .dsa import DSA 2 | from .mixup import Mixup 3 | from .cutmix import Cutmix 4 | from .zca import ZCAWhitening -------------------------------------------------------------------------------- /ddranking/aug/cutmix.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | 6 | class Cutmix: 7 | def __init__(self, params: dict): 8 | self.beta = params["beta"] 9 | 10 | def rand_bbox(self, size, lam): 11 | W = size[2] # size is (N, C, H, W); the W/H naming follows the original CutMix code and is applied consistently below 12 | H = size[3] 13 | cut_rat = np.sqrt(1.0 - lam) 14 | cut_w = int(W * cut_rat) 15 | cut_h = int(H * cut_rat) 16 | 17 | # uniform 18 | cx = np.random.randint(W) 19 | cy = np.random.randint(H) 20 | 21 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 22 | bby1 = np.clip(cy - cut_h // 2, 0, H) 23 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 24 | bby2 = np.clip(cy + cut_h // 2, 0, H) 25 | 26 | return bbx1, bby1, bbx2, bby2 27 | 28 | def cutmix(self, images): 29 | rand_index = torch.randperm(images.size()[0]).to(images.device) 30 | lam = np.random.beta(self.beta, self.beta) 31 | bbx1, bby1, bbx2, bby2 = self.rand_bbox(images.size(), lam) 32 | 33 | images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2] # paste a random patch from a shuffled copy of the batch (modifies images in place) 34 | return images 35 | 36 | def __call__(self, images): 37 | return self.cutmix(images) 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /ddranking/aug/dsa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | class DSA: 7 | 8 | def __init__(self, params: dict, seed: int=-1, aug_mode: str='S'): 9 | self.params = params 10 | self.seed = seed 11 | self.aug_mode = aug_mode 12 | 13 | default_funcs = ['scale', 'rotate', 'flip', 'color', 'crop', 'cutout'] 14 | self.transform_funcs = self.create_transform_funcs(default_funcs) 15 | 16 | def create_transform_funcs(self, func_names): 17 | funcs = [] 18 | for func_name in func_names: 19 | funcs.append(getattr(self, 'rand_' + func_name)) 20 | return funcs 21 | 22 | def set_seed_DiffAug(self): 23 | if self.params["latestseed"] == -1: 24 | return 25 | else: 26 | torch.random.manual_seed(self.params["latestseed"]) 27 | self.params["latestseed"] += 1 28 | 29 | # The following differentiable augmentation strategies are adapted from https://github.com/VICO-UoE/DatasetCondensation 30 | def rand_scale(self, x): 31 | # x>1, max scale 32 | # sx, sy: (0, +oo), 1: original size, 0.5: enlarge 2 times 33 | ratio = self.params["scale"] 34 | self.set_seed_DiffAug() 35 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 36 | self.set_seed_DiffAug() 37 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 38 | theta = [[[sx[i], 0, 0], 39 | [0, sy[i], 0],] for i in range(x.shape[0])] 40 | theta = torch.tensor(theta, dtype=torch.float) 41 | if self.params["siamese"]: # Siamese augmentation: 42 | theta[:] = theta[0] 43 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 44 | x = F.grid_sample(x, grid, align_corners=True) 45 | return x 46 | 47 | def rand_rotate(self, x): # [-180, 180], 90: anticlockwise 90 degree 48 | ratio = self.params["rotate"] 49 | self.set_seed_DiffAug() 50 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi) 51 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0], 52 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])] 53 | theta = torch.tensor(theta, dtype=torch.float) 54 | if self.params["siamese"]: # Siamese augmentation: 55 | theta[:] = theta[0] 56 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device) 57 | x = F.grid_sample(x, grid, align_corners=True) 58 | return x 59 |
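# NOTE: each rand_* transform below draws its randomness via set_seed_DiffAug, which
# re-seeds torch from params["latestseed"] and then increments it, so two passes that
# start from the same latestseed replay identical augmentations. When params["siamese"]
# is true, the first sample's random draw is copied to the whole batch so every image
# receives the same transform.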
60 | def rand_flip(self, x): 61 | prob = self.params["flip"] 62 | self.set_seed_DiffAug() 63 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device) 64 | if self.params["siamese"]: # Siamese augmentation: 65 | randf[:] = randf[0] 66 | return torch.where(randf < prob, x.flip(3), x) 67 | 68 | def rand_brightness(self, x): 69 | ratio = self.params["brightness"] 70 | self.set_seed_DiffAug() 71 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 72 | if self.params["siamese"]: # Siamese augmentation: 73 | randb[:] = randb[0] 74 | x = x + (randb - 0.5)*ratio 75 | return x 76 | 77 | def rand_saturation(self, x): 78 | ratio = self.params["saturation"] 79 | x_mean = x.mean(dim=1, keepdim=True) 80 | self.set_seed_DiffAug() 81 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 82 | if self.params["siamese"]: # Siamese augmentation: 83 | rands[:] = rands[0] 84 | x = (x - x_mean) * (rands * ratio) + x_mean 85 | return x 86 | 87 | def rand_contrast(self, x): 88 | ratio = self.params["contrast"] 89 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 90 | self.set_seed_DiffAug() 91 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 92 | if self.params["siamese"]: # Siamese augmentation: 93 | randc[:] = randc[0] 94 | x = (x - x_mean) * (randc + ratio) + x_mean 95 | return x 96 | 97 | def rand_color(self, x): 98 | return self.rand_contrast(self.rand_saturation(self.rand_brightness(x))) 99 | 100 | def rand_crop(self, x): 101 | # The image is padded on its surrounding and then cropped. 102 | ratio = self.params["crop"] 103 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 104 | self.set_seed_DiffAug() 105 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 106 | self.set_seed_DiffAug() 107 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 108 | if self.params["siamese"]: # Siamese augmentation: 109 | translation_x[:] = translation_x[0] 110 | translation_y[:] = translation_y[0] 111 | grid_batch, grid_x, grid_y = torch.meshgrid( 112 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 113 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 114 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 115 | ) 116 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 117 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 118 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 119 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 120 | return x 121 | 122 | def rand_cutout(self, x): 123 | ratio = self.params["cutout"] 124 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 125 | self.set_seed_DiffAug() 126 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 127 | self.set_seed_DiffAug() 128 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 129 | if self.params["siamese"]: # Siamese augmentation: 130 | offset_x[:] = offset_x[0] 131 | offset_y[:] = offset_y[0] 132 | grid_batch, grid_x, grid_y = torch.meshgrid( 133 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 134 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 135 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 136 | ) 137 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 
1) 138 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 139 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 140 | mask[grid_batch, grid_x, grid_y] = 0 141 | x = x * mask.unsqueeze(1) 142 | return x 143 | 144 | def __call__(self, images): 145 | 146 | if not self.transform_funcs: 147 | return images 148 | 149 | if self.seed == -1: 150 | self.params["siamese"] = False 151 | else: 152 | self.params["siamese"] = True 153 | 154 | self.params["latestseed"] = self.seed 155 | 156 | if self.aug_mode == 'M': # original 157 | for f in self.transform_funcs: 158 | images = f(images) 159 | 160 | elif self.aug_mode == 'S': 161 | self.set_seed_DiffAug() 162 | p = self.transform_funcs[torch.randint(0, len(self.transform_funcs), size=(1,)).item()] 163 | images = p(images) 164 | 165 | images = images.contiguous() 166 | 167 | return images -------------------------------------------------------------------------------- /ddranking/aug/mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import kornia 4 | 5 | 6 | class Mixup: 7 | def __init__(self, params: dict): 8 | self.lambda_ = params["lambda"] 9 | 10 | def mixup(self, images): 11 | rand_index = torch.randperm(images.size()[0]).to(images.device) 12 | lam = np.random.beta(self.lambda_, self.lambda_) 13 | 14 | mixed_images = lam * images + (1 - lam) * images[rand_index] 15 | return mixed_images 16 | 17 | def __call__(self, images): 18 | return self.mixup(images) 19 | -------------------------------------------------------------------------------- /ddranking/aug/zca.py: -------------------------------------------------------------------------------- 1 | import kornia 2 | 3 | 4 | class ZCAWhitening: 5 | def __init__(self, params: dict): 6 | self.transform = kornia.enhance.ZCAWhitening() 7 | 8 | def __call__(self, images): 9 | return self.transform(images, include_fit=True) -------------------------------------------------------------------------------- /ddranking/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .user_config import Config -------------------------------------------------------------------------------- /ddranking/config/user_config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | from typing import Dict, Any 4 | from torchvision import transforms 5 | from torchvision.transforms import v2 6 | 7 | class Config: 8 | """Configuration object to manage individual configurations.""" 9 | def __init__(self, config: Dict[str, Any] = None): 10 | """Initialize with a configuration dictionary.""" 11 | self.config = config or {} 12 | 13 | @classmethod 14 | def from_file(cls, filepath: str): 15 | """Load configuration from a YAML or JSON file.""" 16 | if filepath.endswith(".yaml") or filepath.endswith(".yml"): 17 | with open(filepath, "r") as f: 18 | config = yaml.safe_load(f) 19 | elif filepath.endswith(".json"): 20 | with open(filepath, "r") as f: 21 | config = json.load(f) 22 | else: 23 | raise ValueError("Unsupported file format. 
Use YAML or JSON.") 24 | return cls(config) 25 | 26 | def load_transforms_from_yaml(self, values): 27 | if values is None: 28 | return None 29 | transform_list = [] 30 | for transform in values: 31 | name = transform["name"] 32 | args = transform.get("args", []) 33 | if isinstance(args, dict): 34 | if hasattr(transforms, name): 35 | transform_list.append(getattr(transforms, name)(**args)) 36 | elif hasattr(v2, name): 37 | transform_list.append(getattr(v2, name)(**args)) 38 | else: 39 | raise NotImplementedError 40 | else: 41 | transform_list.append(getattr(transforms, name)(*args)) 42 | 43 | return transforms.Compose(transform_list) 44 | 45 | def get(self, key: str, default: Any = None): 46 | """Get a value from the config.""" 47 | if key == "custom_train_trans": 48 | return self.load_transforms_from_yaml(self.config["custom_train_trans"]) 49 | elif key == "custom_val_trans": 50 | return self.load_transforms_from_yaml(self.config["custom_val_trans"]) 51 | elif key == "im_size": 52 | return tuple(self.config.get("im_size", default)) 53 | 54 | return self.config.get(key, default) 55 | 56 | def update(self, overrides: Dict[str, Any]): 57 | """Update the configuration with overrides.""" 58 | self.config.update(overrides) 59 | 60 | def __repr__(self): 61 | return f"Config({self.config})" -------------------------------------------------------------------------------- /ddranking/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .kl import KLDivergenceLoss 2 | from .sce import SoftCrossEntropyLoss 3 | from .mse_gt import MSEGTLoss -------------------------------------------------------------------------------- /ddranking/loss/kl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class KLDivergenceLoss(nn.Module): 7 | def __init__(self, temperature=1.2, scale_loss=False): 8 | super(KLDivergenceLoss, self).__init__() 9 | self.temperature = temperature 10 | self.scale_loss = scale_loss 11 | 12 | def forward(self, stu_outputs, tea_outputs): 13 | stu_probs = F.log_softmax(stu_outputs / self.temperature, dim=1) 14 | with torch.no_grad(): 15 | tea_probs = F.softmax(tea_outputs / self.temperature, dim=1) 16 | loss = F.kl_div(stu_probs, tea_probs, reduction='batchmean') 17 | if self.scale_loss: 18 | loss = loss * (self.temperature ** 2) 19 | return loss -------------------------------------------------------------------------------- /ddranking/loss/mse_gt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MSEGTLoss(nn.Module): 7 | def __init__(self, mse_weight=1.0, ce_weight=0.0): 8 | super(MSEGTLoss, self).__init__() 9 | self.mse_weight = mse_weight 10 | self.ce_weight = ce_weight 11 | 12 | def forward(self, stu_outputs, tea_outputs, ground_truth): 13 | mse_loss = F.mse_loss(stu_outputs, tea_outputs) 14 | ce_loss = F.cross_entropy(stu_outputs, ground_truth) 15 | loss = self.mse_weight * mse_loss + self.ce_weight * ce_loss 16 | return loss -------------------------------------------------------------------------------- /ddranking/loss/sce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SoftCrossEntropyLoss(nn.Module): 7 | def __init__(self, temperature=1.2, 
scale_loss=False): 8 | super(SoftCrossEntropyLoss, self).__init__() 9 | self.temperature = temperature 10 | self.scale_loss = scale_loss 11 | 12 | def forward(self, stu_outputs, tea_outputs): 13 | input_log_likelihood = -F.log_softmax(stu_outputs / self.temperature, dim=1) 14 | target_log_likelihood = F.softmax(tea_outputs / self.temperature, dim=1) 15 | batch_size = stu_outputs.size(0) 16 | loss = torch.sum(torch.mul(input_log_likelihood, target_log_likelihood)) / batch_size 17 | if self.scale_loss: 18 | loss = loss * (self.temperature ** 2) 19 | return loss -------------------------------------------------------------------------------- /ddranking/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .general import GeneralEvaluator 2 | from .soft_label import LabelRobustScoreSoft 3 | from .hard_label import LabelRobustScoreHard 4 | from .aug_robust import AugmentationRobustScore -------------------------------------------------------------------------------- /ddranking/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import set_seed, save_results, setup_dist, logging, broadcast_string 2 | from .data import get_dataset, get_random_data_tensors, get_random_data_path_from_cifar, get_random_data_path, TensorDataset 3 | from .model import build_model, get_pretrained_model_path, get_convnet, get_lenet, get_resnet, get_vgg, get_alexnet 4 | from .train_and_eval import get_optimizer, get_lr_scheduler, train_one_epoch, validate, REAL_DATA_TRAINING_CONFIG, REAL_DATA_ACC_CACHE 5 | from .meter import MetricLogger, accuracy -------------------------------------------------------------------------------- /ddranking/utils/meter.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import torch 4 | from collections import deque, defaultdict 5 | from .misc import reduce_across_processes 6 | 7 | 8 | class SmoothedValue: 9 | """Track a series of values and provide access to smoothed values over a 10 | window or the global series average. 11 | """ 12 | 13 | def __init__(self, window_size=20, fmt=None): 14 | if fmt is None: 15 | fmt = "{median:.4f} ({global_avg:.4f})" 16 | self.deque = deque(maxlen=window_size) 17 | self.total = 0.0 18 | self.count = 0 19 | self.fmt = fmt 20 | 21 | def update(self, value, n=1): 22 | self.deque.append(value) 23 | self.count += n 24 | self.total += value * n 25 | 26 | def synchronize_between_processes(self): 27 | """ 28 | Warning: does not synchronize the deque! 
29 | """ 30 | t = reduce_across_processes([self.count, self.total]) 31 | t = t.tolist() 32 | self.count = int(t[0]) 33 | self.total = t[1] 34 | 35 | @property 36 | def median(self): 37 | d = torch.tensor(list(self.deque)) 38 | return d.median().item() 39 | 40 | @property 41 | def avg(self): 42 | d = torch.tensor(list(self.deque), dtype=torch.float32) 43 | return d.mean().item() 44 | 45 | @property 46 | def global_avg(self): 47 | return self.total / self.count 48 | 49 | @property 50 | def max(self): 51 | return max(self.deque) 52 | 53 | @property 54 | def value(self): 55 | return self.deque[-1] 56 | 57 | def __str__(self): 58 | return self.fmt.format( 59 | median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value 60 | ) 61 | 62 | 63 | class MetricLogger: 64 | def __init__(self, delimiter="\t"): 65 | self.meters = defaultdict(SmoothedValue) 66 | self.delimiter = delimiter 67 | 68 | def update(self, **kwargs): 69 | for k, v in kwargs.items(): 70 | if isinstance(v, torch.Tensor): 71 | v = v.item() 72 | assert isinstance(v, (float, int)) 73 | self.meters[k].update(v) 74 | 75 | def __getattr__(self, attr): 76 | if attr in self.meters: 77 | return self.meters[attr] 78 | if attr in self.__dict__: 79 | return self.__dict__[attr] 80 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") 81 | 82 | def __str__(self): 83 | loss_str = [] 84 | for name, meter in self.meters.items(): 85 | loss_str.append(f"{name}: {str(meter)}") 86 | return self.delimiter.join(loss_str) 87 | 88 | def synchronize_between_processes(self): 89 | for meter in self.meters.values(): 90 | meter.synchronize_between_processes() 91 | 92 | def add_meter(self, name, meter): 93 | self.meters[name] = meter 94 | 95 | def log_every(self, iterable, print_freq, header=None): 96 | i = 0 97 | if not header: 98 | header = "" 99 | start_time = time.time() 100 | end = time.time() 101 | iter_time = SmoothedValue(fmt="{avg:.4f}") 102 | data_time = SmoothedValue(fmt="{avg:.4f}") 103 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 104 | if torch.cuda.is_available(): 105 | log_msg = self.delimiter.join( 106 | [ 107 | header, 108 | "[{0" + space_fmt + "}/{1}]", 109 | "eta: {eta}", 110 | "{meters}", 111 | "time: {time}", 112 | "data: {data}", 113 | "max mem: {memory:.0f}", 114 | ] 115 | ) 116 | else: 117 | log_msg = self.delimiter.join( 118 | [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] 119 | ) 120 | MB = 1024.0 * 1024.0 121 | for obj in iterable: 122 | data_time.update(time.time() - end) 123 | yield obj 124 | iter_time.update(time.time() - end) 125 | if i % print_freq == 0: 126 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 127 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 128 | if torch.cuda.is_available(): 129 | print( 130 | log_msg.format( 131 | i, 132 | len(iterable), 133 | eta=eta_string, 134 | meters=str(self), 135 | time=str(iter_time), 136 | data=str(data_time), 137 | memory=torch.cuda.max_memory_allocated() / MB, 138 | ) 139 | ) 140 | else: 141 | print( 142 | log_msg.format( 143 | i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) 144 | ) 145 | ) 146 | i += 1 147 | end = time.time() 148 | total_time = time.time() - start_time 149 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 150 | print(f"{header} Total time: {total_time_str}") 151 | 152 | 153 | def accuracy(output, target, topk=(1,)): 154 | """Computes the accuracy over the k 
top predictions for the specified values of k""" 155 | with torch.inference_mode(): 156 | maxk = max(topk) 157 | batch_size = target.size(0) 158 | if target.ndim == 2: 159 | target = target.max(dim=1)[1] 160 | 161 | _, pred = output.topk(maxk, 1, True, True) 162 | pred = pred.t() 163 | correct = pred.eq(target[None]) 164 | 165 | res = [] 166 | for k in topk: 167 | correct_k = correct[:k].flatten().sum(dtype=torch.float32) 168 | res.append(correct_k * (100.0 / batch_size)) 169 | return res 170 | -------------------------------------------------------------------------------- /ddranking/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import numpy as np 5 | import random 6 | import pandas as pd 7 | 8 | 9 | def set_seed(seed=None): 10 | if seed is None: 11 | seed = int(time.time() * 1000) % 1000000 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | np.random.seed(seed) 16 | random.seed(seed) 17 | 18 | 19 | def save_results(results, save_path): 20 | df = pd.DataFrame(results) 21 | df.to_csv(save_path, index=False) 22 | 23 | 24 | def setup_for_distributed(is_master): 25 | """ 26 | This function disables printing when not in master process 27 | """ 28 | import builtins as __builtin__ 29 | 30 | builtin_print = __builtin__.print 31 | 32 | def print(*args, **kwargs): 33 | force = kwargs.pop("force", False) 34 | if is_master or force: 35 | builtin_print(*args, **kwargs) 36 | 37 | __builtin__.print = print 38 | 39 | 40 | def setup_dist(args): 41 | # local_rank = int(os.environ.get("LOCAL_RANK", 0)) 42 | # torch.cuda.set_device(local_rank) # This ensures each process uses the correct GPU 43 | # torch.distributed.init_process_group(backend="nccl" if device == "cuda" else "gloo") 44 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 45 | args.rank = int(os.environ["RANK"]) 46 | args.world_size = int(os.environ["WORLD_SIZE"]) 47 | args.gpu = int(os.environ["LOCAL_RANK"]) 48 | elif "SLURM_PROCID" in os.environ: 49 | args.rank = int(os.environ["SLURM_PROCID"]) 50 | args.gpu = args.rank % torch.cuda.device_count() 51 | elif hasattr(args, "rank"): 52 | pass 53 | else: 54 | print("Not using distributed mode") 55 | args.distributed = False 56 | return 57 | 58 | args.distributed = True 59 | 60 | torch.cuda.set_device(args.gpu) 61 | args.dist_backend = "nccl" 62 | torch.distributed.init_process_group( 63 | backend=args.dist_backend, world_size=args.world_size, rank=args.rank 64 | ) 65 | torch.distributed.barrier() 66 | setup_for_distributed(args.rank == 0) 67 | 68 | 69 | def logging(message): 70 | rank = int(os.environ.get("RANK", 0)) 71 | if rank == 0: 72 | print(message) 73 | 74 | 75 | def broadcast_string(string_data, device, src=0): 76 | rank = torch.distributed.get_rank() 77 | 78 | if rank == src: 79 | encoded_string = string_data.encode('utf-8') 80 | string_len = torch.tensor([len(encoded_string)], dtype=torch.long, device=device) 81 | else: 82 | string_len = torch.tensor([0], dtype=torch.long, device=device) 83 | 84 | torch.distributed.broadcast(string_len, src=src) 85 | 86 | if rank == src: 87 | string_tensor = torch.tensor(list(encoded_string), dtype=torch.uint8, device=device) 88 | else: 89 | string_tensor = torch.zeros(string_len.item(), dtype=torch.uint8, device=device) 90 | 91 | torch.distributed.broadcast(string_tensor, src=src) 92 | received_string = ''.join([chr(byte) for byte in string_tensor]) 93 | return received_string 94 | 95 | def 
is_dist_avail_and_initialized(): 96 | if not torch.distributed.is_available(): 97 | return False 98 | if not torch.distributed.is_initialized(): 99 | return False 100 | return True 101 | 102 | def reduce_across_processes(val): 103 | if not is_dist_avail_and_initialized(): 104 | # nothing to sync, but we still convert to tensor for consistency with the distributed case. 105 | return torch.tensor(val) 106 | 107 | t = torch.tensor(val, device="cuda") 108 | torch.distributed.barrier() 109 | torch.distributed.all_reduce(t) 110 | return t -------------------------------------------------------------------------------- /ddranking/utils/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import timm 4 | import os 5 | from .networks import ConvNet, MLP, LeNet, AlexNet, VGG, ResNet, BasicBlock, Bottleneck 6 | 7 | 8 | def parse_model_name(model_name): 9 | if "-" not in model_name: 10 | return 0, False 11 | try: 12 | depth = int(model_name.split("-")[1]) 13 | if "BN" in model_name and len(model_name.split("-")) > 2 and model_name.split("-")[2] == "BN": 14 | batchnorm = True 15 | else: 16 | batchnorm = False 17 | except: 18 | raise ValueError("Model name must be in the format of [model name]-[depth]-[BN], e.g. ResNet-18-BN") 19 | return depth, batchnorm 20 | 21 | 22 | def get_convnet(model_name, im_size, channel, num_classes, net_depth, net_norm, pretrained=False, model_path=None): 23 | # print(f"Creating {model_name} with depth={net_depth}, norm={net_norm}") 24 | model = ConvNet(channel=channel, num_classes=num_classes, net_width=128, net_depth=net_depth, 25 | net_act='relu', net_norm=net_norm, net_pooling='avgpooling', im_size=im_size) 26 | if pretrained: 27 | model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True)) 28 | return model 29 | 30 | def get_mlp(model_name, im_size, channel, num_classes, pretrained=False, model_path=None): 31 | # print(f"Creating {model_name} with channel={channel}, num_classes={num_classes}") 32 | model = MLP(channel=channel, num_classes=num_classes, res=im_size[0]) 33 | if pretrained: 34 | model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True)) 35 | return model 36 | 37 | def get_lenet(model_name, im_size, channel, num_classes, pretrained=False, model_path=None): 38 | # print(f"Creating {model_name} with channel={channel}, num_classes={num_classes}") 39 | model = LeNet(channel=channel, num_classes=num_classes, res=im_size[0]) 40 | if pretrained: 41 | model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True)) 42 | return model 43 | 44 | def get_alexnet(model_name, im_size, channel, num_classes, use_torchvision=False, pretrained=False, model_path=None): 45 | # print(f"Creating {model_name} with channel={channel}, num_classes={num_classes}") 46 | if use_torchvision: 47 | model = torchvision.models.alexnet(num_classes=num_classes, pretrained=False) 48 | if im_size == (32, 32) or im_size == (64, 64): 49 | model.features[0] = torch.nn.Conv2d(3, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False) 50 | else: 51 | model = AlexNet(channel=channel, num_classes=num_classes, res=im_size[0]) 52 | if pretrained: 53 | model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True)) 54 | return model 55 | 56 | def get_vgg(model_name, im_size, channel, num_classes, depth=11, batchnorm=False, use_torchvision=False, pretrained=False, model_path=None): 57 | # print(f"Creating {model_name} with channel={channel}, 
num_classes={num_classes}") 58 | if use_torchvision: 59 | if depth == 11: 60 | if batchnorm: 61 | model = torchvision.models.vgg11_bn(num_classes=num_classes, pretrained=False) 62 | else: 63 | model = torchvision.models.vgg11(num_classes=num_classes, pretrained=False) 64 | elif depth == 13: 65 | if batchnorm: 66 | model = torchvision.models.vgg13_bn(num_classes=num_classes, pretrained=False) 67 | else: 68 | model = torchvision.models.vgg13(num_classes=num_classes, pretrained=False) 69 | elif depth == 16: 70 | if batchnorm: 71 | model = torchvision.models.vgg16_bn(num_classes=num_classes, pretrained=False) 72 | else: 73 | model = torchvision.models.vgg16(num_classes=num_classes, pretrained=False) 74 | elif depth == 19: 75 | if batchnorm: 76 | model = torchvision.models.vgg19_bn(num_classes=num_classes, pretrained=False) 77 | else: 78 | model = torchvision.models.vgg19(num_classes=num_classes, pretrained=False) 79 | else: 80 | model = VGG(f'VGG{depth}', channel, num_classes, norm='batchnorm' if batchnorm else 'instancenorm', res=im_size[0]) 81 | 82 | if pretrained: 83 | model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True)) 84 | return model 85 | 86 | 87 | def get_resnet(model_name, im_size, channel, num_classes, depth=18, batchnorm=False, use_torchvision=False, pretrained=False, model_path=None): 88 | # print(f"Creating {model_name} with channel={channel}, num_classes={num_classes}") 89 | if use_torchvision: 90 | # print(f"ResNet in torchvision uses batchnorm by default.") 91 | if depth == 18: 92 | model = torchvision.models.resnet18(num_classes=num_classes, pretrained=False) 93 | elif depth == 34: 94 | model = torchvision.models.resnet34(num_classes=num_classes, pretrained=False) 95 | elif depth == 50: 96 | model = torchvision.models.resnet50(num_classes=num_classes, pretrained=False) 97 | if im_size == (64, 64) or im_size == (32, 32): 98 | model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False) 99 | model.maxpool = torch.nn.Identity() 100 | else: 101 | if depth == 18: 102 | model = ResNet(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm' if batchnorm else 'instancenorm', res=im_size[0]) 103 | elif depth == 34: 104 | model = ResNet(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes, norm='batchnorm' if batchnorm else 'instancenorm', res=im_size[0]) 105 | elif depth == 50: 106 | model = ResNet(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes, norm='batchnorm' if batchnorm else 'instancenorm', res=im_size[0]) 107 | 108 | if pretrained: 109 | model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True)) 110 | return model 111 | 112 | 113 | def get_other_models(model_name, num_classes, im_size=(32, 32), pretrained=False): 114 | try: 115 | model = torchvision.models.get_model(model_name, num_classes=num_classes, pretrained=pretrained) 116 | except: 117 | try: 118 | model = timm.create_model(model_name, num_classes=num_classes, pretrained=pretrained) 119 | except: 120 | raise ValueError(f"Model {model_name} not found") 121 | return model 122 | 123 | 124 | def build_model(model_name: str, num_classes: int, im_size: tuple, pretrained: bool=False, model_path: str=None, use_torchvision: bool=False, device: str="cuda"): 125 | assert model_name is not None, "model name must be provided" 126 | depth, batchnorm = parse_model_name(model_name) 127 | if model_name.startswith("ConvNet"): 128 | model = get_convnet(model_name, channel=3, 
num_classes=num_classes, im_size=im_size, net_depth=depth, 129 | net_norm="instancenorm" if not batchnorm else "batchnorm", pretrained=pretrained, model_path=model_path) 130 | elif model_name.startswith("AlexNet"): 131 | model = get_alexnet(model_name, im_size=im_size, channel=3, num_classes=num_classes, pretrained=pretrained, 132 | use_torchvision=use_torchvision, model_path=model_path) 133 | elif model_name.startswith("ResNet"): 134 | model = get_resnet(model_name, im_size=im_size, channel=3, num_classes=num_classes, depth=depth, use_torchvision=use_torchvision, 135 | batchnorm=batchnorm, pretrained=pretrained, model_path=model_path) 136 | elif model_name.startswith("LeNet"): 137 | model = get_lenet(model_name, im_size=im_size, channel=3, num_classes=num_classes, pretrained=pretrained, model_path=model_path) 138 | elif model_name.startswith("MLP"): 139 | model = get_mlp(model_name, im_size=im_size, channel=3, num_classes=num_classes, pretrained=pretrained, model_path=model_path) 140 | elif model_name.startswith("VGG"): 141 | model = get_vgg(model_name, im_size=im_size, channel=3, num_classes=num_classes, depth=depth, batchnorm=batchnorm, 142 | use_torchvision=use_torchvision, pretrained=pretrained, model_path=model_path) 143 | else: 144 | model = get_other_models(model_name, num_classes=num_classes, im_size=im_size, pretrained=pretrained) 145 | 146 | model.to(device) 147 | return model 148 | 149 | 150 | def get_pretrained_model_path(teacher_dir, model_names, dataset): 151 | 152 | return [os.path.join(os.path.join(teacher_dir, f"{dataset}", f"{model_name}", "ckpt_best.pt")) 153 | if os.path.exists(os.path.join(os.path.join(teacher_dir, f"{dataset}", f"{model_name}", "ckpt_best.pt"))) 154 | else None for model_name in model_names] -------------------------------------------------------------------------------- /ddranking/utils/train_and_eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import timm 4 | import math 5 | import warnings 6 | import datetime 7 | from torch.optim import SGD, Adam, AdamW 8 | from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, LambdaLR 9 | from collections import OrderedDict 10 | from .meter import MetricLogger, SmoothedValue, accuracy 11 | from .misc import reduce_across_processes, is_dist_avail_and_initialized 12 | from ..loss import MSEGTLoss 13 | 14 | 15 | REAL_DATA_ACC_CACHE = { 16 | "ImageNet1K-ResNet-18-BN": 56.5, 17 | "TinyImageNet-ResNet-18-BN": 46.4 18 | } 19 | 20 | REAL_DATA_TRAINING_CONFIG = { 21 | "ImageNet1K-ResNet-18-BN": { 22 | "optimizer": "sgd", 23 | "lr_scheduler": "step", 24 | "weight_decay": 0.0001, 25 | "momentum": 0.9, 26 | "num_epochs": 90, 27 | "batch_size": 512, 28 | "lr": 0.1, 29 | "step_size": 30, 30 | "gamma": 0.1 31 | }, 32 | "TinyImageNet-ResNet-18-BN": { 33 | "optimizer": "adamw", 34 | "lr_scheduler": "cosine", 35 | "weight_decay": 0.01, 36 | "lr": 0.01, 37 | "num_epochs": 100, 38 | "batch_size": 512, 39 | "step_size": 0, 40 | "gamma": 0, 41 | "momentum": (0.9, 0.999) 42 | }, 43 | "TinyImageNet-ConvNet-4-BN": { 44 | "optimizer": "sgd", 45 | "lr_scheduler": "step", 46 | "weight_decay": 0.0005, 47 | "momentum": 0.9, 48 | "num_epochs": 100, 49 | "batch_size": 512, 50 | "lr": 0.01, 51 | "step_size": 50, 52 | "gamma": 0.1 53 | }, 54 | "CIFAR10-ConvNet-3": { 55 | "optimizer": "sgd", 56 | "lr_scheduler": "step", 57 | "weight_decay": 0.0005, 58 | "momentum": 0.9, 59 | "num_epochs": 200, 60 | "batch_size": 512, 61 | "lr": 0.01, 62 | "step_size": 100, 63 | 
"gamma": 0.1 64 | }, 65 | "CIFAR10-ConvNet-3-BN": { 66 | "optimizer": "sgd", 67 | "lr_scheduler": "step", 68 | "weight_decay": 0.0005, 69 | "momentum": 0.9, 70 | "num_epochs": 200, 71 | "batch_size": 512, 72 | "lr": 0.01, 73 | "step_size": 100, 74 | "gamma": 0.1 75 | }, 76 | "CIFAR100-ConvNet-3": { 77 | "optimizer": "sgd", 78 | "lr_scheduler": "step", 79 | "weight_decay": 0.0005, 80 | "momentum": 0.9, 81 | "num_epochs": 200, 82 | "batch_size": 512, 83 | "lr": 0.01, 84 | "step_size": 100, 85 | "gamma": 0.1 86 | }, 87 | "CIFAR100-ConvNet-3-BN": { 88 | "optimizer": "sgd", 89 | "lr_scheduler": "step", 90 | "weight_decay": 0.0005, 91 | "momentum": 0.9, 92 | "num_epochs": 200, 93 | "batch_size": 512, 94 | "lr": 0.01, 95 | "step_size": 100, 96 | "gamma": 0.1 97 | } 98 | } 99 | 100 | 101 | def default_augmentation(images): 102 | return images 103 | 104 | def get_optimizer(optimizer_name, model, lr, weight_decay=0.0005, momentum=0.9): 105 | if optimizer_name == 'sgd': 106 | return SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay) 107 | elif optimizer_name == 'adam': 108 | return Adam(model.parameters(), lr=lr, weight_decay=weight_decay, betas=momentum if isinstance(momentum, tuple) else (0.9, 0.999)) 109 | elif optimizer_name == 'adamw': 110 | return AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=momentum if isinstance(momentum, tuple) else (0.9, 0.999)) 111 | else: 112 | raise NotImplementedError(f"Optimizer {optimizer_name} not implemented") 113 | 114 | def get_lr_scheduler(lr_scheduler_name, optimizer, num_epochs=None, step_size=None, gamma=None): 115 | if lr_scheduler_name == 'step': 116 | assert step_size is not None, "step_size must be provided for step scheduler" 117 | return StepLR(optimizer, step_size=step_size, gamma=gamma if gamma is not None else 0.1) 118 | elif lr_scheduler_name == 'cosineannealing': 119 | assert num_epochs is not None, "num_epochs must be provided for cosine scheduler" 120 | return CosineAnnealingLR(optimizer, T_max=num_epochs) 121 | elif lr_scheduler_name == 'cosine': 122 | assert num_epochs is not None, "num_epochs must be provided for lambda cosine scheduler" 123 | return LambdaLR(optimizer, lambda step: 0.5 * (1.0 + math.cos(math.pi * step / num_epochs / 2)) 124 | if step <= num_epochs 125 | else 0, 126 | last_epoch=-1, 127 | ) 128 | else: 129 | raise NotImplementedError(f"LR Scheduler {lr_scheduler_name} not implemented") 130 | 131 | def train_one_epoch( 132 | epoch, 133 | stu_model, 134 | loader, 135 | loss_fn, 136 | optimizer, 137 | soft_label_mode='S', 138 | aug_func=None, 139 | tea_models=None, 140 | lr_scheduler=None, 141 | class_map=None, 142 | grad_accum_steps=1, 143 | log_interval=500, 144 | device='cuda' 145 | ): 146 | 147 | stu_model.train() 148 | if tea_models is not None: 149 | for tea_model in tea_models: 150 | tea_model.eval() 151 | 152 | if is_dist_avail_and_initialized(): 153 | loader.sampler.set_epoch(epoch) 154 | 155 | if aug_func is None: 156 | aug_func = default_augmentation 157 | 158 | metric_logger = MetricLogger(delimiter=" ") 159 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value}")) 160 | metric_logger.add_meter("img/s", SmoothedValue(window_size=10, fmt="{value}")) 161 | 162 | header = f"Epoch: [{epoch}]" 163 | 164 | accumulated_loss = 0.0 165 | accum_step = 0 166 | 167 | for i, (images, targets) in enumerate(metric_logger.log_every(loader, log_interval, header)): 168 | start_time = time.time() 169 | 170 | if class_map is not None: 171 | targets = 
torch.tensor([class_map[targets[i].item()] for i in range(len(targets))], dtype=targets.dtype, device=targets.device) 172 | 173 | images, targets = images.to(device), targets.to(device) 174 | images = aug_func(images) 175 | 176 | raw_targets = targets.clone() 177 | if soft_label_mode == 'M': 178 | tea_outputs = [tea_model(images) for tea_model in tea_models] 179 | tea_output = torch.stack(tea_outputs, dim=0).mean(dim=0) 180 | targets = tea_output 181 | 182 | output = stu_model(images) 183 | 184 | if isinstance(loss_fn, MSEGTLoss): 185 | loss = loss_fn(output, targets, raw_targets) 186 | else: 187 | loss = loss_fn(output, targets) 188 | 189 | loss = loss / grad_accum_steps 190 | accumulated_loss += loss.item() 191 | 192 | loss.backward() 193 | 194 | accum_step += 1 195 | if accum_step == grad_accum_steps: 196 | optimizer.step() 197 | optimizer.zero_grad() 198 | accum_step = 0 199 | 200 | metric_logger.update(loss=accumulated_loss, lr=round(optimizer.param_groups[0]["lr"], 8)) 201 | accumulated_loss = 0.0 202 | 203 | acc1, acc5 = accuracy(output, targets, topk=(1, 5)) 204 | batch_size = images.shape[0] 205 | 206 | if accum_step == 0: 207 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 208 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 209 | metric_logger.meters["img/s"].update(round(batch_size / (time.time() - start_time), 2)) 210 | 211 | if accum_step > 0: 212 | optimizer.step() 213 | optimizer.zero_grad() 214 | 215 | if lr_scheduler is not None: 216 | lr_scheduler.step() 217 | 218 | 219 | def validate( 220 | model, 221 | loader, 222 | device='cuda', 223 | class_map=None, 224 | log_interval=100, 225 | topk=(1, 5) 226 | ): 227 | model.eval() 228 | metric_logger = MetricLogger(delimiter=" ") 229 | header = f"Test" 230 | 231 | num_processed_samples = 0 232 | with torch.inference_mode(): 233 | for image, target in metric_logger.log_every(loader, log_interval, header): 234 | if class_map is not None: 235 | target = torch.tensor([class_map[target[i].item()] for i in range(len(target))], dtype=target.dtype, device=target.device) 236 | image = image.to(device, non_blocking=True) 237 | target = target.to(device, non_blocking=True) 238 | output = model(image) 239 | 240 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 241 | batch_size = image.shape[0] 242 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 243 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 244 | num_processed_samples += batch_size 245 | 246 | num_processed_samples = reduce_across_processes(num_processed_samples) 247 | if ( 248 | hasattr(loader.dataset, "__len__") 249 | and len(loader.dataset) != num_processed_samples 250 | and torch.distributed.get_rank() == 0 251 | ): 252 | warnings.warn( 253 | f"It looks like the dataset has {len(loader.dataset)} samples, but {num_processed_samples} " 254 | "samples were used for the validation, which might bias the results. " 255 | "Try adjusting the batch size and / or the world size. " 256 | "Setting the world size to 1 is always a safe bet." 
257 | ) 258 | 259 | metric_logger.synchronize_between_processes() 260 | 261 | print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}") 262 | return metric_logger.acc1.global_avg -------------------------------------------------------------------------------- /demo_aug.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from ddranking.metrics import AugmentationRobustScore 5 | from ddranking.config import Config 6 | 7 | """ Use config file to specify the arguments (Recommended) """ 8 | config = Config.from_file("./configs/Demo_ARS.yaml") 9 | aug_evaluator = AugmentationRobustScore(config) 10 | 11 | syn_data_dir = "./baselines/SRe2L/ImageNet1K/IPC10/" 12 | print(aug_evaluator.compute_metrics(image_path=syn_data_dir, syn_lr=0.001)) 13 | 14 | 15 | """ Use keyword arguments """ 16 | from torchvision import transforms 17 | device = "cuda" 18 | method_name = "SRe2L" # Specify your method name 19 | ipc = 10 # Specify your IPC 20 | dataset = "ImageNet1K" # Specify your dataset name 21 | syn_data_dir = "./SRe2L/ImageNet1K/IPC10/" # Specify your synthetic data path 22 | data_dir = "./datasets" # Specify your dataset path 23 | model_name = "ResNet-18-BN" # Specify your model name 24 | im_size = (224, 224) # Specify your image size 25 | cutmix_params = { # Specify your data augmentation parameters 26 | "beta": 1.0 27 | } 28 | 29 | syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu') 30 | soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu') 31 | syn_lr = torch.load(os.path.join(syn_data_dir, f"lr.pt"), map_location='cpu') 32 | save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/ars_scores.csv" 33 | 34 | custom_train_trans = transforms.Compose([ 35 | transforms.RandomResizedCrop(224, scale=(0.08, 1.0)), 36 | transforms.RandomHorizontalFlip(), 37 | transforms.ToTensor(), 38 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 39 | ]) 40 | custom_val_trans = transforms.Compose([ 41 | transforms.Resize(256), 42 | transforms.CenterCrop(224), 43 | transforms.ToTensor(), 44 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 45 | ]) 46 | 47 | aug_evaluator = AugmentationRobustScore( 48 | dataset=dataset, 49 | real_data_path=data_dir, 50 | ipc=ipc, 51 | model_name=model_name, 52 | label_type='soft', 53 | soft_label_criterion='kl', # Use KL divergence loss 54 | soft_label_mode='M', # 'M': teacher models generate soft labels during training 55 | loss_fn_kwargs={'temperature': 1.0, 'scale_loss': False}, 56 | optimizer='adamw', # Use AdamW optimizer 57 | lr_scheduler='cosine', # Use cosine learning rate scheduler 58 | weight_decay=0.01, 59 | momentum=0.9, 60 | num_eval=5, 61 | data_aug_func='cutmix', # Use cutmix data augmentation 62 | aug_params=cutmix_params, # Specify cutmix parameters 63 | im_size=im_size, 64 | num_epochs=300, 65 | num_workers=4, 66 | stu_use_torchvision=True, 67 | tea_use_torchvision=True, 68 | random_data_format='tensor', 69 | random_data_path='./random_data', 70 | custom_train_trans=custom_train_trans, 71 | custom_val_trans=custom_val_trans, 72 | batch_size=256, 73 | teacher_dir='./teacher_models', 74 | teacher_model_name=['ResNet-18-BN'], 75 | device=device, 76 | dist=True, 77 | save_path=save_path 78 | ) 79 | print(aug_evaluator.compute_metrics(image_path=syn_data_dir, syn_lr=0.001)) 80 | 
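# Note: with dist=True above, this demo is intended to be launched with torchrun
# for multi-GPU distributed evaluation, e.g. (hypothetical invocation; adjust the
# process count to your setup): torchrun --nproc_per_node=4 demo_aug.py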
-------------------------------------------------------------------------------- /demo_hard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import warnings 4 | from ddranking.metrics import LabelRobustScoreHard 5 | from ddranking.config import Config 6 | warnings.filterwarnings("ignore") 7 | 8 | 9 | """ Use config file to specify the arguments (Recommended) """ 10 | config = Config.from_file("./configs/Demo_LRS_Hard_Label.yaml") 11 | hard_label_evaluator = LabelRobustScoreHard(config) 12 | 13 | syn_data_dir = "./baselines/DM/CIFAR10/IPC10/" 14 | syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu') 15 | syn_lr = 0.01 16 | print(hard_label_evaluator.compute_metrics(image_tensor=syn_images, syn_lr=syn_lr)) 17 | 18 | 19 | """ Use keyword arguments """ 20 | device = "cuda" 21 | method_name = "DM" # Specify your method name 22 | ipc = 10 # Specify your IPC 23 | dataset = "CIFAR10" # Specify your dataset name 24 | data_dir = "./datasets" # Specify your dataset path 25 | syn_data_dir = "./DM/CIFAR10/IPC10/" # Specify your synthetic data path 26 | model_name = "ConvNet-3" # Specify your model name 27 | im_size = (32, 32) # Specify your image size 28 | 29 | dsa_params = { # Specify your data augmentation parameters 30 | "flip": 0.5, 31 | "rotate": 15.0, 32 | "saturation": 2.0, 33 | "brightness": 1.0, 34 | "contrast": 0.5, 35 | "scale": 1.2, 36 | "crop": 0.125, 37 | "cutout": 0.5, 38 | } 39 | 40 | syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu') 41 | save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv" 42 | hard_label_evaluator = LabelRobustScoreHard( 43 | dataset=dataset, 44 | real_data_path=data_dir, 45 | ipc=ipc, 46 | model_name=model_name, 47 | optimizer='sgd', # Use SGD optimizer 48 | lr_scheduler='step', # Use StepLR learning rate scheduler 49 | weight_decay=0.0005, 50 | momentum=0.9, 51 | use_zca=False, 52 | num_eval=5, 53 | data_aug_func='dsa', # Use DSA data augmentation 54 | aug_params=dsa_params, # Specify DSA parameters 55 | im_size=im_size, 56 | num_epochs=1000, 57 | num_workers=4, 58 | use_torchvision=False, 59 | syn_batch_size=256, 60 | real_batch_size=256, 61 | custom_train_trans=None, 62 | custom_val_trans=None, 63 | device=device, 64 | dist=True, 65 | save_path=save_path, 66 | random_data_format='tensor', 67 | random_data_path='./random_data', 68 | eval_full_data=True, 69 | ) 70 | print(hard_label_evaluator.compute_metrics(image_tensor=syn_images, syn_lr=0.01)) 71 | -------------------------------------------------------------------------------- /demo_soft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import warnings 4 | from ddranking.metrics import LabelRobustScoreSoft 5 | from ddranking.config import Config 6 | warnings.filterwarnings("ignore") 7 | 8 | 9 | """ Use config file to specify the arguments (Recommended) """ 10 | config = Config.from_file("./configs/Demo_LRS_Soft_Label.yaml") 11 | soft_label_evaluator = LabelRobustScoreSoft(config) 12 | 13 | syn_data_dir = "./baselines/DATM/CIFAR10/IPC10/" 14 | syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu') 15 | soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu') 16 | syn_lr = torch.load(os.path.join(syn_data_dir, f"lr.pt"), map_location='cpu') 17 | print(soft_label_evaluator.compute_metrics(image_tensor=syn_images, soft_labels=soft_labels, 
syn_lr=syn_lr)) 18 | 19 | 20 | """ Use keyword arguments """ 21 | device = "cuda" 22 | method_name = "DATM" # Specify your method name 23 | ipc = 10 # Specify your IPC 24 | dataset = "CIFAR10" # Specify your dataset name 25 | syn_data_dir = "./DATM/CIFAR10/IPC10/" # Specify your synthetic data path 26 | data_dir = "./datasets" # Specify your dataset path 27 | model_name = "ConvNet-3" # Specify your model name 28 | im_size = (32, 32) # Specify your image size 29 | dsa_params = { # Specify your data augmentation parameters 30 | "flip": 0.5, 31 | "rotate": 15.0, 32 | "saturation": 2.0, 33 | "brightness": 1.0, 34 | "contrast": 0.5, 35 | "scale": 1.2, 36 | "crop": 0.125, 37 | "cutout": 0.5 38 | } 39 | 40 | syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu') 41 | soft_labels = torch.load(os.path.join(syn_data_dir, f"labels.pt"), map_location='cpu') 42 | syn_lr = torch.load(os.path.join(syn_data_dir, f"lr.pt"), map_location='cpu') 43 | save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv" 44 | soft_label_evaluator = LabelRobustScoreSoft( 45 | dataset=dataset, 46 | real_data_path=data_dir, 47 | ipc=ipc, 48 | model_name=model_name, 49 | soft_label_criterion='sce', # Use Soft Cross Entropy Loss 50 | soft_label_mode='S', # Use one-to-one image to soft label mapping 51 | loss_fn_kwargs={'temperature': 1.0, 'scale_loss': False}, 52 | optimizer='sgd', # Use SGD optimizer 53 | lr_scheduler='step', # Use StepLR learning rate scheduler 54 | step_size=500, 55 | weight_decay=0.0005, 56 | momentum=0.9, 57 | use_zca=True, # Use ZCA whitening (please disable it if you didn't use it to distill synthetic data) 58 | num_eval=5, 59 | data_aug_func='dsa', # Use DSA data augmentation 60 | aug_params=dsa_params, # Specify dsa parameters 61 | im_size=im_size, 62 | num_epochs=1000, 63 | num_workers=4, 64 | eval_full_data=True, 65 | stu_use_torchvision=False, 66 | tea_use_torchvision=False, 67 | random_data_format='tensor', 68 | random_data_path='./random_data', 69 | custom_train_trans=None, 70 | custom_val_trans=None, 71 | syn_batch_size=256, 72 | real_batch_size=256, 73 | teacher_dir='./teacher_models', 74 | teacher_model_name=['ConvNet-3'], 75 | device=device, 76 | dist=True, 77 | save_path=save_path 78 | ) 79 | print(soft_label_evaluator.compute_metrics(image_tensor=syn_images, soft_labels=soft_labels, syn_lr=syn_lr)) 80 | -------------------------------------------------------------------------------- /dist/ddranking-0.2.0-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/dist/ddranking-0.2.0-py3-none-any.whl -------------------------------------------------------------------------------- /dist/ddranking-0.2.0.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/dist/ddranking-0.2.0.tar.gz -------------------------------------------------------------------------------- /doc/SUMMARY.md: -------------------------------------------------------------------------------- 1 | # Summary 2 | 3 | [Introduction](introduction.md) 4 | [Contributing](contributing.md) 5 | 6 | # Getting Started 7 | 8 | - [Installation](getting-started/installation.md) 9 | - [Quick Start](getting-started/quick-start.md) 10 | 11 | # Reference Guide 12 | 13 | - [Metrics](metrics/overview.md) 14 | - 
[LabelRobustScoreHard](metrics/lrs-hard-label.md) 15 | - [LabelRobustScoreSoft](metrics/lrs-soft-label.md) 16 | - [AugmentationRobustScore](metrics/ars.md) 17 | - [GeneralEvaluator](metrics/general.md) 18 | 19 | - [Augmentations](augmentations/overview.md) 20 | - [DSA](augmentations/dsa.md) 21 | - [CutMix](augmentations/cutmix.md) 22 | - [Mixup](augmentations/mixup.md) 23 | 24 | - [Models](models/overview.md) 25 | - [ConvNet](models/convnet.md) 26 | - [AlexNet](models/alexnet.md) 27 | - [ResNet](models/resnet.md) 28 | - [LeNet](models/lenet.md) 29 | - [VGG](models/vgg.md) 30 | - [MLP](models/mlp.md) 31 | 32 | - [Datasets](datasets/overview.md) 33 | 34 | - [Config](config/overview.md) -------------------------------------------------------------------------------- /doc/augmentations/cutmix.md: -------------------------------------------------------------------------------- 1 | ## Cutmix 2 | 3 | Cutmix is a data augmentation technique that creates new samples by combining patches from two images while blending their labels proportionally to the area of the patches. We follow the implementation of cutmix in [SRe2L](https://github.com/VILA-Lab/SRe2L/tree/main/SRe2L). 4 | 5 | 
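Below is a minimal standalone sketch (an editorial illustration, not part of the original docs) of applying the `Cutmix` class documented below to an image batch; the batch here is random data for shape illustration only.

```python
import torch
from ddranking.aug import Cutmix

aug_func = Cutmix(params={'beta': 1.0})   # beta distribution parameter
images = torch.randn(8, 3, 32, 32)        # hypothetical batch of images
images = aug_func(images)                 # random pairs exchange a rectangular patch
```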
6 | 7 | CLASS 8 | ddranking.aug.Cutmix(params: dict) 9 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/aug/cutmix.py) 10 | 11 |
12 | 13 | ### Parameters 14 | 15 | - **params**(dict): Parameters for the cutmix augmentation. We require the parameters to be in the format of `{'param_name': param_value}`. For cutmix, only `beta` (beta distribution parameter) needs to be specified, e.g. `{'beta': 1.0}`. 16 | 17 | ### Example 18 | 19 | ```python 20 | # When initializing an evaluator with cutmix augmentation, a Cutmix object will be constructed. 21 | >>> self.aug_func = Cutmix(params={'beta': 1.0}) 22 | 23 | # During training, the cutmix object will be used to augment the data. 24 | >>> images = aug_func(images) 25 | ``` -------------------------------------------------------------------------------- /doc/augmentations/dsa.md: -------------------------------------------------------------------------------- 1 | ## Differentiable Siamese Augmentation (DSA) 2 | 3 | DSA is a differentiable data augmentation strategy, first used in the dataset distillation task by [DSA](https://github.com/VICO-UoE/DatasetCondensation). 4 | Our implementation of DSA is adapted from [DSA](https://github.com/VICO-UoE/DatasetCondensation). It supports the following differentiable augmentations (a minimal usage sketch follows the list): 5 | 6 | - Random Flip 7 | - Random Rotation 8 | - Random Saturation 9 | - Random Brightness 10 | - Random Contrast 11 | - Random Scale 12 | - Random Crop 13 | - Random Cutout 14 | 15 | 
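As a quick illustration (a minimal sketch under the parameter format described in the Parameters section, not part of the original docs), the `DSA` class can be applied directly to an image batch:

```python
import torch
from ddranking.aug import DSA

dsa_params = {'flip': 0.5, 'rotate': 15.0, 'scale': 1.2, 'crop': 0.125,
              'cutout': 0.5, 'brightness': 1.0, 'contrast': 0.5, 'saturation': 2.0}
aug_func = DSA(params=dsa_params, seed=-1, aug_mode='S')  # 'S': one random transform per batch
images = torch.randn(8, 3, 32, 32)                        # hypothetical batch of images
images = aug_func(images)
```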
16 | 17 | CLASS 18 | ddranking.aug.DSA(params: dict, seed: int, aug_mode: str) 19 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/aug/dsa.py) 20 | 21 |
22 | 23 | ### Parameters 24 | 25 | - **params**(dict): Parameters for the DSA augmentations. We require the parameters to be in the format of `{'param_name': param_value}`. For example, `{'flip': 0.5, 'rotate': 15.0, 'scale': 1.2, 'crop': 0.125, 'cutout': 0.5, 'brightness': 1.0, 'contrast': 0.5, 'saturation': 2.0}`. 26 | - **seed**(int): Random seed. Default is `-1`. 27 | - **aug_mode**(str): `S` for randomly selecting one augmentation for each batch. `M` for applying all augmentations for each batch. 28 | 29 | ### Example 30 | 31 | ```python 32 | # When initializing an evaluator with DSA augmentation, a DSA object will be constructed. 33 | >>> self.aug_func = DSA(params={'flip': 0.5, 'rotate': 15.0, 'scale': 1.2, 'crop': 0.125, 'cutout': 0.5, 'brightness': 1.0, 'contrast': 0.5, 'saturation': 2.0}, seed=-1, aug_mode='S') 34 | 35 | # During training, the DSA object will be used to augment the data. 36 | >>> images = aug_func(images) 37 | ``` 38 | -------------------------------------------------------------------------------- /doc/augmentations/mixup.md: -------------------------------------------------------------------------------- 1 | ## Mixup 2 | 3 | Mixup is a data augmentation technique that generates new training samples by linearly interpolating pairs of images. We follow the implementation of mixup in [SRe2L](https://github.com/VILA-Lab/SRe2L/tree/main/SRe2L). 4 | 5 | 
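Concretely (a sketch distilled from the implementation in `ddranking/aug/mixup.py`), mixup draws a coefficient `lam` from a Beta distribution parameterized by `lambda` and blends each image with a randomly permuted partner:

```python
import torch
import numpy as np

lambda_ = 0.8                           # the 'lambda' parameter described below
images = torch.randn(8, 3, 32, 32)      # hypothetical batch of images
lam = np.random.beta(lambda_, lambda_)
rand_index = torch.randperm(images.size(0))
mixed_images = lam * images + (1 - lam) * images[rand_index]
```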
6 | 7 | CLASS 8 | ddranking.aug.Mixup(params: dict) 9 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/aug/mixup.py) 10 | 11 |
12 | 13 | ### Parameters 14 | 15 | - **params**(dict): Parameters for the mixup augmentation. We require the parameters to be in the format of `{'param_name': param_value}`. For mixup, only `lambda` (mixup strength) needs to be specified, e.g. `{'lambda': 0.8}`. 16 | 17 | ### Example 18 | 19 | ```python 20 | # When initializing an evaluator with mixup augmentation, a Mixup object will be constructed. 21 | >>> self.aug_func = Mixup(params={'lambda': 0.8}) 22 | 23 | # During training, the mixup object will be used to augment the data. 24 | >>> images = aug_func(images) 25 | ``` -------------------------------------------------------------------------------- /doc/augmentations/overview.md: -------------------------------------------------------------------------------- 1 | # Augmentations 2 | 3 | DD-Ranking supports commonly used data augmentations in existing methods. A list of augmentations is provided below: 4 | 5 | - [Torchvision transforms](https://pytorch.org/vision/stable/transforms.html) 6 | - [DSA](dsa.md) 7 | - [Mixup](mixup.md) 8 | - [Cutmix](cutmix.md) 9 | 10 | In DD-Ranking, data augmentations are specified when initializing an evaluator. 11 | The following arguments are related to data augmentations: 12 | 13 | - **data_aug_func**(str): The name of the data augmentation function used during training. Currently, we support `dsa`, `mixup`, `cutmix`. 14 | - **aug_params**(dict): The parameters for the data augmentation function. 15 | - **custom_train_trans**(torchvision.transforms.Compose): The custom train transform used to load the synthetic data when it's in '.jpg' or '.png' format. 16 | - **custom_val_trans**(torchvision.transforms.Compose): The custom val transform used to load the test dataset. 17 | - **use_zca**(bool): Whether to use ZCA whitening for the data augmentation. This is only applicable to methods that use ZCA whitening during distillation. 18 | 19 | ```python 20 | # When initializing an evaluator, the data augmentation function is specified. 21 | >>> evaluator = SoftLabelEvaluator( 22 | ... 23 | data_aug_func=..., # Specify the data augmentation function 24 | aug_params=..., # Specify the parameters for the data augmentation function 25 | custom_train_trans=..., # Specify the custom train transform 26 | custom_val_trans=..., # Specify the custom val transform 27 | use_zca=..., # Specify whether to use ZCA whitening 28 | ... 29 | ) 30 | ``` -------------------------------------------------------------------------------- /doc/augmentations/torchvision.md: -------------------------------------------------------------------------------- 1 | ## Torchvision Transforms 2 | 3 | We notice that some methods use jpg or png format images instead of image tensors during evaluation, and apply an additional torchvision-based transformation to preprocess these images. Also, they may apply different augmentations to the test dataset. Thus, we support the torchvision-based transformations in DD-Ranking for both synthetic and real data. 4 | 5 | We require the torchvision-based transformations to be a `torchvision.transforms.Compose` object. If you have customized transformations, please make sure they have a `__call__` method. For the list of torchvision transformations, please refer to [torchvision-transforms](https://pytorch.org/vision/stable/transforms.html). 
6 | 7 | ### Example 8 | 9 | ```python 10 | # Define a custom transformation 11 | class MyTransform: 12 | def __init__(self): 13 | pass 14 | 15 | def __call__(self, x): 16 | return x 17 | 18 | custom_train_trans = torchvision.transforms.Compose([ 19 | MyTransform(), 20 | torchvision.transforms.RandomResizedCrop(32), 21 | torchvision.transforms.ToTensor(), 22 | torchvision.transforms.Normalize(mean, std) 23 | ]) 24 | custom_val_trans = torchvision.transforms.Compose([ 25 | torchvision.transforms.RandomCrop(32, padding=4), 26 | torchvision.transforms.ToTensor(), 27 | torchvision.transforms.Normalize(mean, std) 28 | ]) 29 | 30 | evaluator = DD_Ranking( 31 | ... 32 | custom_train_trans=custom_train_trans, 33 | custom_val_trans=custom_val_trans, 34 | ... 35 | ) 36 | ``` 37 | -------------------------------------------------------------------------------- /doc/config/overview.md: -------------------------------------------------------------------------------- 1 | # Config 2 | 3 | To ease the usage of DD-Ranking, we allow users to specify the parameters of the evaluator in a config file. The config file is a YAML file that contains the parameters of the evaluator. We illustrate the config file with the following examples. 4 | 5 | ## LRS 6 | 7 | ```yaml 8 | dataset: CIFAR100 # dataset name 9 | real_data_path: ./dataset/ # path to the real dataset 10 | ipc: 10 # image per class 11 | im_size: [32, 32] # image size 12 | model_name: ResNet-18-BN # model name 13 | stu_use_torchvision: true # whether to use torchvision to load student model 14 | 15 | tea_use_torchvision: true # whether to use torchvision to load teacher model 16 | 17 | teacher_dir: ./teacher_models # path to the pretrained teacher model 18 | teacher_model_names: [ResNet-18-BN] # the list of teacher models being used for evaluation 19 | 20 | data_aug_func: mixup # data augmentation function 21 | aug_params: 22 | lambda: 0.8 # data augmentation parameter; please follow this format for other parameters 23 | 24 | use_zca: false # whether to use ZCA whitening 25 | use_aug_for_hard: false # whether to use data augmentation for hard label evaluation 26 | 27 | custom_train_trans: # custom torchvision-based transformations to process training data; please follow this format for your own transformations 28 | - name: RandomCrop 29 | args: 30 | size: 32 31 | padding: 4 32 | - name: RandomHorizontalFlip 33 | args: 34 | p: 0.5 35 | - name: ToTensor 36 | - name: Normalize 37 | args: 38 | mean: [0.4914, 0.4822, 0.4465] 39 | std: [0.2023, 0.1994, 0.2010] 40 | 41 | custom_val_trans: null # custom torchvision-based transformations to process validation data; please follow the format above for your own transformations 42 | 43 | soft_label_mode: M # soft label mode 44 | soft_label_criterion: kl # soft label criterion 45 | loss_fn_kwargs: 46 | temperature: 30.0 # temperature for soft label 47 | scale_loss: false # whether to scale the loss 48 | 49 | optimizer: adamw # optimizer 50 | lr_scheduler: cosine # learning rate scheduler 51 | weight_decay: 0.01 # weight decay 52 | momentum: 0.9 # momentum 53 | num_eval: 5 # number of evaluations 54 | eval_full_data: false # whether to compute the test accuracy on the full dataset 55 | num_epochs: 400 # number of training epochs 56 | num_workers: 4 # number of workers 57 | device: cuda # device 58 | dist: true # whether to use distributed training 59 | syn_batch_size: 256 # batch size for synthetic data 60 | real_batch_size: 256 # batch size for real data 61 | save_path: ./results.csv # path to save the results 62 | 63 
63 | random_data_format: tensor # format of the random data, tensor or image 64 | random_data_path: ./random_data # path to save the random data 65 | 66 | ``` 67 | 68 | To use a config file, you can follow the example below. 69 | 70 | ```python 71 | from ddranking.metrics import LabelRobustScoreSoft 72 | from ddranking.config import Config 73 | 74 | config = Config(config_path='./config.yaml') 75 | evaluator = LabelRobustScoreSoft(config) 76 | ``` 77 | 78 | 79 | ## ARS 80 | 81 | ```yaml 82 | 83 | ``` 84 | 85 | -------------------------------------------------------------------------------- /doc/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Welcome! We are glad that you are willing to contribute to the field of dataset distillation. 4 | 5 | - **New Baselines**: If you would like to report new baselines, please submit them by creating a pull request that includes the name of the baseline, a code link, and, if available, a paper link and the score obtained with this tool. 6 | 7 | - **New Components**: If you would like to integrate new components, such as new model architectures, new data augmentation methods, and new soft label strategies, please submit them by creating a pull request. 8 | 9 | - **Issues**: If you encounter any problems, you are encouraged to report them directly in Issues. 10 | 11 | - **Appeal**: If you want to appeal the score of your method, please submit an issue with your code and a detailed README file describing how to reproduce your results. We have tried our best to replicate all methods in the leaderboard based on their papers and open-source code. We are sorry if we missed some details and will be grateful if you can help us improve the leaderboard. 12 | -------------------------------------------------------------------------------- /doc/datasets/overview.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | DD-Ranking provides a set of datasets commonly used in existing dataset distillation methods. Users can flexibly use these datasets for evaluation. The interface to load datasets is as follows: 4 | 5 |
6 | 7 | ddranking.utils.get_dataset(dataset: str, data_path: str, im_size: tuple, use_zca: bool, custom_val_trans: Optional[Callable], device: str) 8 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/data.py) 9 |
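For orientation, a minimal call could look like the following sketch. The transform and variable names are illustrative, and the return structure is not documented on this page, so the result is simply bound to one variable; please check the linked source for the exact return values. Parameter details follow below.

```python
import torchvision.transforms as T
from ddranking.utils import get_dataset

# Optional custom validation transform; the CIFAR statistics come from the table below.
val_trans = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
])

# Load CIFAR100 at 32x32 without ZCA whitening. `dataset_info` is an
# illustrative name for whatever this interface returns.
dataset_info = get_dataset(
    dataset="CIFAR100",
    data_path="./dataset/",
    im_size=(32, 32),
    use_zca=False,
    custom_val_trans=val_trans,
    device="cuda",
)
```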
10 | 11 | ### Parameters 12 | 13 | - **dataset**(str): Name of the dataset. 14 | - **data_path**(str): Path to the dataset. 15 | - **im_size**(tuple): Image size. 16 | - **use_zca**(bool): Whether to use ZCA whitening. When set to True, the dataset will **not be** normalized using the mean and standard deviation of the training set. 17 | - **custom_train_trans**(Optional[Callable]): Custom transformation on the training set. 18 | - **custom_val_trans**(Optional[Callable]): Custom transformation on the validation set. 19 | - **device**(str): Device for performing ZCA whitening. 20 | 21 | Currently, we support the following datasets with default settings. We will keep updating this section with more datasets. 22 | 23 | - **CIFAR10** 24 | - **channels**: `3` 25 | - **im_size**: `(32, 32)` 26 | - **num_classes**: `10` 27 | - **mean**: `[0.4914, 0.4822, 0.4465]` 28 | - **std**: `[0.2023, 0.1994, 0.2010]` 29 | - **CIFAR100** 30 | - **channels**: `3` 31 | - **im_size**: `(32, 32)` 32 | - **num_classes**: `100` 33 | - **mean**: `[0.4914, 0.4822, 0.4465]` 34 | - **std**: `[0.2023, 0.1994, 0.2010]` 35 | - **TinyImageNet** 36 | - **channels**: `3` 37 | - **im_size**: `(64, 64)` 38 | - **num_classes**: `200` 39 | - **mean**: `[0.485, 0.456, 0.406]` 40 | - **std**: `[0.229, 0.224, 0.225]` 41 | - **ImageNet1K** 42 | - **channels**: `3` 43 | - **im_size**: `(224, 224)` 44 | - **num_classes**: `1000` 45 | - **mean**: `[0.485, 0.456, 0.406]` 46 | - **std**: `[0.229, 0.224, 0.225]` 47 | -------------------------------------------------------------------------------- /doc/getting-started/installation.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | From pip 4 | 5 | ```bash 6 | pip install ddranking 7 | ``` 8 | 9 | From source 10 | 11 | ```bash 12 | python setup.py install 13 | ``` -------------------------------------------------------------------------------- /doc/getting-started/quick-start.md: -------------------------------------------------------------------------------- 1 | ## Quick Start 2 | 3 | Below is a step-by-step guide on how to use our `ddranking` package. This demo is for the label-robust score (LRS) on soft labels (source code can be found in `demo_soft.py`). You can find the demo for LRS on hard labels in `demo_hard.py` and the demo for the augmentation-robust score (ARS) in `demo_aug.py`. 4 | DD-Ranking supports multi-GPU distributed evaluation. You can simply use `torchrun` to launch the evaluation. 5 | 6 | **Step 1:** Initialize a soft-label metric evaluator object. Config files are recommended for users to specify hyper-parameters. Sample config files are provided [here](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/tree/main/configs). 7 | 8 | ```python 9 | >>> from ddranking.metrics import LabelRobustScoreSoft 10 | >>> from ddranking.config import Config 11 | 12 | >>> config = Config.from_file("./configs/Demo_LRS_Soft_Label.yaml") 13 | >>> lrs_soft_metric = LabelRobustScoreSoft(config) 14 | ``` 15 |
16 | 17 | You can also pass keyword arguments. 18 | 19 | ```python 20 | device = "cuda" 21 | method_name = "DATM" # Specify your method name 22 | ipc = 10 # Specify your IPC 23 | dataset = "CIFAR100" # Specify your dataset name 24 | syn_data_dir = "./data/CIFAR100/IPC10/" # Specify your synthetic data path 25 | real_data_dir = "./datasets" # Specify your dataset path 26 | model_name = "ConvNet-3" # Specify your model name 27 | teacher_dir = "./teacher_models" # Specify your path to teacher model checkpoints 28 | teacher_model_names = ["ConvNet-3"] # Specify your teacher model names 29 | im_size = (32, 32) # Specify your image size 30 | dsa_params = { # Specify your data augmentation parameters 31 | "prob_flip": 0.5, 32 | "ratio_rotate": 15.0, 33 | "saturation": 2.0, 34 | "brightness": 1.0, 35 | "contrast": 0.5, 36 | "ratio_scale": 1.2, 37 | "ratio_crop_pad": 0.125, 38 | "ratio_cutout": 0.5 39 | } 40 | random_data_format = "tensor" # Specify your random data format (tensor or image) 41 | random_data_path = "./random_data" # Specify your random data path 42 | save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv" 43 | 44 | """We only list arguments that usually need specifying.""" 45 | lrs_soft_metric = LabelRobustScoreSoft( 46 | dataset=dataset, 47 | real_data_path=real_data_dir, 48 | ipc=ipc, 49 | model_name=model_name, 50 | soft_label_criterion='sce', # Use Soft Cross Entropy Loss 51 | soft_label_mode='S', # Use one-to-one image to soft label mapping 52 | loss_fn_kwargs={'temperature': 1.0, 'scale_loss': False}, 53 | data_aug_func='dsa', # Use DSA data augmentation 54 | aug_params=dsa_params, # Specify dsa parameters 55 | im_size=im_size, 56 | random_data_format=random_data_format, 57 | random_data_path=random_data_path, 58 | stu_use_torchvision=False, 59 | tea_use_torchvision=False, 60 | teacher_dir=teacher_dir, 61 | teacher_model_names=teacher_model_names, 62 | num_eval=5, 63 | device=device, 64 | dist=True, 65 | save_path=save_path 66 | ) 67 | ``` 68 |
69 | 70 | For a detailed explanation of the hyper-parameters, please refer to our documentation. 71 | 72 | **Step 2:** Load your synthetic data, labels (if any), and learning rate (if any). 73 | 74 | ```python 75 | >>> syn_images = torch.load('/your/path/to/syn/images.pt') 76 | # You must specify your soft labels if your soft label mode is 'S' 77 | >>> soft_labels = torch.load('/your/path/to/syn/labels.pt') 78 | >>> syn_lr = torch.load('/your/path/to/syn/lr.pt') 79 | ``` 80 | 81 | **Step 3:** Compute the metric. 82 | 83 | ```python 84 | >>> lrs_soft_metric.compute_metrics(image_tensor=syn_images, soft_labels=soft_labels, syn_lr=syn_lr) 85 | # alternatively, you can specify the image folder path to compute the metric 86 | >>> lrs_soft_metric.compute_metrics(image_path='./your/path/to/syn/images', soft_labels=soft_labels, syn_lr=syn_lr) 87 | ``` 88 | 89 | The following results will be printed and saved to `save_path`: 90 | - `HLR mean`: The mean of hard label recovery over `num_eval` runs. 91 | - `HLR std`: The standard deviation of hard label recovery over `num_eval` runs. 92 | - `IOR mean`: The mean of improvement over random over `num_eval` runs. 93 | - `IOR std`: The standard deviation of improvement over random over `num_eval` runs. 94 | - `LRS mean`: The mean of the Label-Robust Score over `num_eval` runs. 95 | - `LRS std`: The standard deviation of the Label-Robust Score over `num_eval` runs. 96 | 97 | -------------------------------------------------------------------------------- /doc/introduction.md: -------------------------------------------------------------------------------- 1 | 2 | logo 3 | 4 | 5 | 6 | 7 | [![GitHub stars](https://img.shields.io/github/stars/NUS-HPC-AI-Lab/DD-Ranking?style=flat&logo=github)](https://github.com/NUS-HPC-AI-Lab/DD-Ranking) 8 | [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-Leaderboard-yellow?style=flat)](https://huggingface.co/spaces/logits/DD-Ranking) 9 | [![Twitter](https://img.shields.io/badge/Twitter-Follow-blue?style=flat&logo=twitter)](https://x.com/Richard91316073/status/1890296645486801230) 10 | 11 | 12 | 13 | Welcome to **DD-Ranking** (DD, *i.e.*, Dataset Distillation), an integrated and easy-to-use evaluation benchmark for dataset distillation! It aims to provide a fair evaluation scheme for DD methods that can decouple the impacts from knowledge distillation and data augmentation to reflect the real informativeness of the distilled data. 14 | 15 | ## Motivation 16 | Dataset Distillation (DD) aims to condense a large dataset into a much smaller one, which allows a model to achieve comparable performance after training on it. DD has gained extensive attention since it was proposed. With some foundational methods such as DC, DM, and MTT, various works have further pushed this area forward with their novel designs. 17 | 18 | ![history](static/history.png) 19 | 20 | Notably, more and more methods are transitioning from "hard labels" to "soft labels" in dataset distillation, especially during evaluation. **Hard labels** are categorical, in the same format as the labels of the real dataset. **Soft labels** are outputs of a pre-trained teacher model. 21 | Recently, Deng et al. pointed out that "a label is worth a thousand images". They showed analytically that soft labels are extremely useful for accuracy improvement.
22 | 23 | However, since the essence of soft labels is **knowledge distillation**, we find that when applying the same evaluation method to randomly selected data, the test accuracy also improves significantly (see the figure above). 24 | 25 | This makes us wonder: **Can the test accuracy of the model trained on distilled data reflect the real informativeness of the distilled data?** 26 | 27 | We summarize the evaluation configurations of existing works in the following table, with different colors highlighting different values for each configuration. 28 | ![configurations](./static/configurations.png) 29 | As can be easily seen, the evaluation configurations are diverse, making it unfair to use test accuracy alone to demonstrate a method's performance. 30 | Among these inconsistencies, two critical factors significantly undermine the fairness of current evaluation protocols: label representation (including the corresponding loss function) and data augmentation techniques. 31 | 32 | Motivated by this, we propose DD-Ranking, a new benchmark for DD evaluation. DD-Ranking provides a fair evaluation scheme for DD methods that can decouple the impacts from knowledge distillation and data augmentation to reflect the real informativeness of the distilled data. 33 | 34 | ## Features 35 | 36 | - **Fair Evaluation**: DD-Ranking provides a fair evaluation scheme for DD methods that can decouple the impacts from knowledge distillation and data augmentation to reflect the real informativeness of the distilled data. 37 | - **Easy-to-use**: DD-Ranking provides a unified interface for dataset distillation evaluation. 38 | - **Extensible**: DD-Ranking supports various datasets and models. 39 | - **Customizable**: DD-Ranking supports various data augmentations and soft label strategies. 40 | 41 | ## DD-Ranking Benchmark 42 | 43 | Revisiting the original goal of dataset distillation: 44 | > The idea is to synthesize a small number of data points that do not need to come from the correct data distribution, but will, when given to the learning algorithm as training data, approximate the model trained on the original data. (Wang et al., 2020) 45 | > 46 | 47 | ### Label-Robust Score (LRS) 48 | For the label representation, we introduce the Label-Robust Score (LRS) to evaluate the informativeness of the synthesized data using the following two aspects: 49 | 1. The degree to which the real dataset is recovered under hard labels (hard label recovery): \\( \text{HLR}=\text{Acc.}_{\text{real-hard}}-\text{Acc.}_{\text{syn-hard}} \\). 50 | 51 | 2. The improvement over random selection when using personalized evaluation methods (improvement over random): \\( \text{IOR}=\text{Acc.}_{\text{syn-any}}-\text{Acc.}_{\text{rdm-any}} \\). 52 | \\(\text{Acc.}\\) is the accuracy of models trained on different samples. The sample subscripts are as follows: 53 | - \\(\text{real-hard}\\): Real dataset with hard labels; 54 | - \\(\text{syn-hard}\\): Synthetic dataset with hard labels; 55 | - \\(\text{syn-any}\\): Synthetic dataset with personalized evaluation methods (hard or soft labels); 56 | - \\(\text{rdm-any}\\): Randomly selected dataset (under the same compression ratio) with the same personalized evaluation methods.
57 | 58 | LRS is defined as a weighted sum of \\(\text{IOR}\\) and \\(-\text{HLR}\\) to rank different methods: 59 | \\[ 60 | \alpha = w\text{IOR}-(1-w)\text{HLR}, \quad w \in [0, 1] 61 | \\] 62 | Then, the LRS is normalized to \\([0, 1]\\) as follows: 63 | \\[ 64 | \text{LRS} = 100\% \times (e^{\alpha}-e^{-1}) / (e - e^{-1}) 65 | \\] 66 | 67 | By default, we set \\(w = 0.5\\) on the leaderboard, meaning that both \\(\text{IOR}\\) and \\(\text{HLR}\\) are equally important. Users can adjust the weights to emphasize one aspect on the leaderboard. 68 | 69 | ### Augmentation-Robust Score (ARS) 70 | To disentangle data augmentation's impact, we introduce the augmentation-robust score (ARS), which continues to leverage the relative improvement over randomly selected data. Specifically, we first evaluate synthetic data and a randomly selected subset under the same setting to obtain \\(\text{Acc.}_{\text{syn-aug}}\\) and \\(\text{Acc.}_{\text{rdm-aug}}\\) (same as IOR). Next, we evaluate both synthetic data and random data again without the data augmentation, and the results are denoted as \\(\text{Acc.}_{\text{syn-naug}}\\) and \\(\text{Acc.}_{\text{rdm-naug}}\\). 71 | Both differences, \\(\text{Acc.}_{\text{syn-aug}} - \text{Acc.}_{\text{rdm-aug}}\\) and \\(\text{Acc.}_{\text{syn-naug}} - \text{Acc.}_{\text{rdm-naug}}\\), are positively correlated with the real informativeness of the distilled dataset. 72 | 73 | ARS is a weighted sum of the two differences: 74 | \\[ 75 | \beta = \gamma(\text{Acc.}_{\text{syn-aug}} - \text{Acc.}_{\text{rdm-aug}}) + (1 - \gamma)(\text{Acc.}_{\text{syn-naug}} - \text{Acc.}_{\text{rdm-naug}}) 76 | \\] 77 | and normalized in the same way as the LRS: \\( \text{ARS} = 100\% \times (e^{\beta}-e^{-1}) / (e - e^{-1}) \\). 78 | -------------------------------------------------------------------------------- /doc/metrics/ars.md: -------------------------------------------------------------------------------- 1 | ## AugmentationRobustScore 2 | 3 |
4 | 5 | CLASS 6 | dd_ranking.metrics.AugmentationRobustScore(config: Optional[Config] = None, 7 | dataset: str = 'ImageNet1K', 8 | real_data_path: str = './dataset/', 9 | ipc: int = 10, 10 | model_name: str = 'ResNet-18-BN', 11 | label_type: str = 'soft', 12 | soft_label_mode: str='S', 13 | soft_label_criterion: str='kl', 14 | loss_fn_kwargs: dict=None, 15 | data_aug_func: str='cutmix', 16 | aug_params: dict={'cutmix_p': 1.0}, 17 | optimizer: str='sgd', 18 | lr_scheduler: str='step', 19 | weight_decay: float=0.0005, 20 | momentum: float=0.9, 21 | step_size: int=None, 22 | num_eval: int=5, 23 | im_size: tuple=(224, 224), 24 | num_epochs: int=300, 25 | use_zca: bool=False, 26 | random_data_format: str='image', 27 | random_data_path: str=None, 28 | batch_size: int=256, 29 | save_path: str=None, 30 | stu_use_torchvision: bool=False, 31 | tea_use_torchvision: bool=False, 32 | num_workers: int=4, 33 | teacher_dir: str='./teacher_models', 34 | teacher_model_names: list=None, 35 | custom_train_trans: Optional[Callable]=None, 36 | custom_val_trans: Optional[Callable]=None, 37 | device: str="cuda", 38 | dist: bool=False 39 | ) 40 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/metrics/aug_robust.py) 41 |
42 | 43 | A class for evaluating the robustness of a dataset distillation method to data augmentation. It supports both soft and hard labels (see `label_type`). User is able to modify the attributes as needed. 44 | 45 | ### Parameters 46 | 47 | - **config**(Optional[Config]): Config object for specifying all attributes. See [config](../config/overview.md) for more details. 48 | - **dataset**(str): Name of the real dataset. 49 | - **real_data_path**(str): Path to the real dataset. 50 | - **ipc**(int): Images per class. 51 | - **model_name**(str): Name of the surrogate model. See [models](../models/overview.md) for more details. 52 | - **label_type**(str): Type of label representation. `soft` for soft labels, `hard` for hard labels. 53 | - **soft_label_mode**(str): Number of soft labels per image. `S` for single soft label, `M` for multiple soft labels. 54 | - **soft_label_criterion**(str): Loss function for using soft labels. Currently supports `kl` for KL divergence, `sce` for soft cross-entropy, and `mse_gt` for MSEGT loss introduced in EDC. 55 | - **loss_fn_kwargs**(dict): Keyword arguments for the loss function, e.g. `temperature` and `scale_loss` for KL and SCE loss, and `mse_weight` and `ce_weight` for MSEGT loss. 56 | - **data_aug_func**(str): Data augmentation function used during training. Currently supports `dsa`, `cutmix`, `mixup`. See [augmentations](../augmentations/overview.md) for more details. 57 | - **aug_params**(dict): Parameters for the data augmentation function. 58 | - **optimizer**(str): Name of the optimizer. Currently supports torch-based optimizers - `sgd`, `adam`, and `adamw`. 59 | - **lr_scheduler**(str): Name of the learning rate scheduler. Currently supports torch-based schedulers - `step`, `cosine`, `lambda_step`, and `cosineannealing`. 60 | - **weight_decay**(float): Weight decay for the optimizer. 61 | - **momentum**(float): Momentum for the optimizer. 62 | - **step_size**(int): Step size for the learning rate scheduler. 63 | - **use_zca**(bool): Whether to use ZCA whitening. 64 | - **num_eval**(int): Number of evaluations to perform. 65 | - **im_size**(tuple): Size of the images. 66 | - **num_epochs**(int): Number of epochs to train. 67 | - **batch_size**(int): Batch size for the model training. 68 | - **stu_use_torchvision**(bool): Whether to use torchvision to initialize the student model. 69 | - **tea_use_torchvision**(bool): Whether to use torchvision to initialize the teacher model. 70 | - **teacher_dir**(str): Path to the teacher model. 71 | - **teacher_model_names**(list): List of teacher model names. 72 | - **random_data_format**(str): Format of the random data, `tensor` or `image`. 73 | - **random_data_path**(str): Path to save the random data. 74 | - **num_workers**(int): Number of workers for data loading. 75 | - **save_path**(Optional[str]): Path to save the results. 76 | - **custom_train_trans**(Optional[Callable]): Custom transformation function when loading synthetic data. Only support torchvision transformations. See [torchvision-based transformations](../augmentations/torchvision.md) for more details. 77 | - **custom_val_trans**(Optional[Callable]): Custom transformation function when loading test dataset. Only support torchvision transformations. See [torchvision-based transformations](../augmentations/torchvision.md) for more details. 78 | - **device**(str): Device to use for evaluation, `cuda` or `cpu`. 79 | - **dist**(bool): Whether to use distributed training. 80 | 81 | ### Methods 82 |
83 | 84 | compute_metrics(image_tensor: Tensor = None, image_path: str = None, soft_labels: Tensor = None, syn_lr: float = None, ars_lambda: float = 0.5) 85 |
86 | 87 |
88 | This method computes the ARS score for the given image and soft labels (if provided). In each evaluation round, we set a different random seed and perform the following steps: 89 | 90 | 1. Compute the test accuracy of the surrogate model on the synthetic dataset without data augmentation. 91 | 2. Compute the test accuracy of the surrogate model on the synthetic dataset with data augmentation. 92 | 3. Compute the test accuracy of the surrogate model on the randomly selected dataset without data augmentation. We perform learning rate tuning for the best performance. 93 | 4. Compute the test accuracy of the surrogate model on the randomly selected dataset with data augmentation. We perform learning rate tuning for the best performance. 94 | 5. Compute the ARS score. 95 | 96 | The final scores are the average of the scores from `num_eval` rounds. 97 | 98 | #### Parameters 99 | 100 | - **image_tensor**(Tensor): Image tensor. Must specify when `image_path` is not provided. We require the shape to be `(N x IPC, C, H, W)` where `N` is the number of classes. 101 | - **image_path**(str): Path to the image. Must specify when `image_tensor` is not provided. 102 | - **soft_labels**(Tensor): Soft label tensor. Must specify when `soft_label_mode` is `S`. The first dimension must be the same as `image_tensor`. 103 | - **syn_lr**(float): Learning rate for the synthetic dataset. If not specified, the learning rate will be tuned automatically. 104 | - **ars_lambda**(float): Weighting parameter for the ARS. 105 | 106 | #### Returns 107 | 108 | A dictionary with the following keys: 109 | 110 | - **with_aug_mean**: Mean of test accuracy scores with data augmentation from `num_eval` rounds. 111 | - **with_aug_std**: Standard deviation of test accuracy scores with data augmentation from `num_eval` rounds. 112 | - **without_aug_mean**: Mean of test accuracy scores without data augmentation from `num_eval` rounds. 113 | - **without_aug_std**: Standard deviation of test accuracy scores without data augmentation from `num_eval` rounds. 114 | - **augmentation_robust_score_mean**: Mean of ARS scores from `num_eval` rounds. 115 | - **augmentation_robust_score_std**: Standard deviation of ARS scores from `num_eval` rounds. 116 | 117 |
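To make the aggregation in step 5 concrete, the score can be sketched as below. This is an illustration of the ARS formula from the introduction, not the library's internal code; in particular, the exponential normalization is an assumption carried over from the LRS normalization, which the introduction says ARS follows.

```python
import math

def ars_sketch(acc_syn_aug, acc_rdm_aug, acc_syn_naug, acc_rdm_naug, ars_lambda=0.5):
    # Weighted sum of the two synthetic-vs-random gaps (accuracies as fractions).
    beta = (ars_lambda * (acc_syn_aug - acc_rdm_aug)
            + (1 - ars_lambda) * (acc_syn_naug - acc_rdm_naug))
    # Normalization assumed to mirror the LRS formula.
    return 100 * (math.exp(beta) - math.exp(-1)) / (math.e - math.exp(-1))
```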
118 | 119 | ### Examples 120 | 121 | with config file: 122 | ```python 123 | >>> config = Config('/path/to/config.yaml') 124 | >>> evaluator = AugmentationRobustScore(config=config) 125 | # load image and soft labels 126 | >>> image_tensor, soft_labels = ... 127 | # compute metrics 128 | >>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels) 129 | # alternatively, provide image path 130 | >>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels) 131 | ``` 132 | 133 | with keyword arguments: 134 | ```python 135 | >>> evaluator = AugmentationRobustScore( 136 | ... dataset='ImageNet1K', 137 | ... real_data_path='./dataset/', 138 | ... ipc=10, 139 | ... model_name='ResNet-18-BN', 140 | ... label_type='soft', 141 | ... soft_label_mode='M', 142 | ... soft_label_criterion='kl', 143 | ... loss_fn_kwargs={ 144 | ... "temperature": 30.0, 145 | ... "scale_loss": False, 146 | ... }, 147 | ... data_aug_func='mixup', 148 | ... aug_params={ 149 | ... "mixup_p": 0.8, 150 | ... }, 151 | ... optimizer='adamw', 152 | ... lr_scheduler='cosine', 153 | ... num_epochs=300, 154 | ... weight_decay=0.0005, 155 | ... momentum=0.9, 156 | ... use_zca=False, 157 | ... stu_use_torchvision=True, 158 | ... tea_use_torchvision=True, 159 | ... num_workers=4, 160 | ... save_path='./results', 161 | ... random_data_format='image', 162 | ... random_data_path='./random_data', 163 | ... teacher_dir='./teacher_models', 164 | ... teacher_model_names=['ResNet-18-BN'], 165 | ... num_eval=5, 166 | ... device='cuda' 167 | ... ) 168 | # load image and soft labels 169 | >>> image_tensor, soft_labels = ... 170 | # compute metrics 171 | >>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels) 172 | # alternatively, provide image path 173 | >>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels) 174 | ``` -------------------------------------------------------------------------------- /doc/metrics/general.md: -------------------------------------------------------------------------------- 1 | ## GeneralEvaluator 2 | 3 |
4 | 5 | CLASS 6 | dd_ranking.metrics.GeneralEvaluator(config: Optional[Config] = None, 7 | dataset: str = 'CIFAR10', 8 | real_data_path: str = './dataset/', 9 | ipc: int = 10, 10 | model_name: str = 'ConvNet-3', 11 | soft_label_mode: str='S', 12 | soft_label_criterion: str='kl', 13 | temperature: float=1.0, 14 | data_aug_func: str='cutmix', 15 | aug_params: dict={'cutmix_p': 1.0}, 16 | optimizer: str='sgd', 17 | lr_scheduler: str='step', 18 | weight_decay: float=0.0005, 19 | momentum: float=0.9, 20 | num_eval: int=5, 21 | im_size: tuple=(32, 32), 22 | num_epochs: int=300, 23 | use_zca: bool=False, 24 | real_batch_size: int=256, 25 | syn_batch_size: int=256, 26 | default_lr: float=0.01, 27 | save_path: str=None, 28 | stu_use_torchvision: bool=False, 29 | tea_use_torchvision: bool=False, 30 | num_workers: int=4, 31 | teacher_dir: str='./teacher_models', 32 | custom_train_trans: Optional[Callable]=None, 33 | custom_val_trans: Optional[Callable]=None, 34 | device: str="cuda" 35 | ) 36 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/metrics/general.py) 37 |
38 | 39 | A class for evaluating the traditional test accuracy of a surrogate model on the synthetic dataset under various settings (label type, data augmentation, etc.). 40 | 41 | ### Parameters 42 | Same as [LabelRobustScoreSoft](lrs-soft-label.md). 43 | 44 | ### Methods 45 |
46 | 47 | compute_metrics(image_tensor: Tensor = None, image_path: str = None, labels: Tensor = None, syn_lr: float = None) 48 |
49 | 50 |
51 | This method computes the test accuracy of the surrogate model on the synthetic dataset under various settings (label type, data augmentation, etc.). 52 | 53 | #### Parameters 54 | 55 | - **image_tensor**(Tensor): Image tensor. Must specify when `image_path` is not provided. We require the shape to be `(N x IPC, C, H, W)` where `N` is the number of classes. 56 | - **image_path**(str): Path to the image. Must specify when `image_tensor` is not provided. 57 | - **labels**(Tensor): Label tensor. It can be either hard labels or soft labels. When `soft_label_mode=S`, the label tensor must be provided. 58 | - **syn_lr**(float): Learning rate for the synthetic dataset. If not specified, the learning rate will be tuned automatically. 59 | 60 | #### Returns 61 | A dictionary with the following keys: 62 | - **acc_mean**: Mean of test accuracy from `num_eval` rounds. 63 | - **acc_std**: Standard deviation of test accuracy from `num_eval` rounds. 64 |
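To make the `(N x IPC, C, H, W)` shape requirement concrete, here is a sketch with dummy CIFAR10-sized inputs. The values are random and purely illustrative, and the class-ordered label layout is an assumption for illustration.

```python
import torch

num_classes, ipc = 10, 10                                   # CIFAR10 with 10 images per class
image_tensor = torch.randn(num_classes * ipc, 3, 32, 32)    # shape (N x IPC, C, H, W)
labels = torch.arange(num_classes).repeat_interleave(ipc)   # hard labels, grouped by class
```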
65 | 66 | ### Examples 67 | 68 | with config file: 69 | ```python 70 | >>> config = Config('/path/to/config.yaml') 71 | >>> evaluator = GeneralEvaluator(config=config) 72 | # load image and labels 73 | >>> image_tensor, labels = ... 74 | # compute metrics 75 | >>> evaluator.compute_metrics(image_tensor=image_tensor, labels=labels) 76 | # alternatively, provide image path 77 | >>> evaluator.compute_metrics(image_path='path/to/image/folder/', labels=labels) 78 | ``` 79 | 80 | with keyword arguments: 81 | ```python 82 | >>> evaluator = GeneralEvaluator( 83 | ... dataset='CIFAR10', 84 | ... model_name='ConvNet-3', 85 | ... soft_label_mode='S', 86 | ... soft_label_criterion='sce', 87 | ... temperature=1.0, 88 | ... data_aug_func='cutmix', 89 | ... aug_params={ 90 | ... "cutmix_p": 1.0, 91 | ... }, 92 | ... optimizer='sgd', 93 | ... lr_scheduler='step', 94 | ... weight_decay=0.0005, 95 | ... momentum=0.9, 96 | ... stu_use_torchvision=False, 97 | ... tea_use_torchvision=False, 98 | ... num_eval=5, 99 | ... device='cuda' 100 | ... ) 101 | # load image and labels 102 | >>> image_tensor, labels = ... 103 | # compute metrics 104 | >>> evaluator.compute_metrics(image_tensor=image_tensor, labels=labels) 105 | # alternatively, provide image path 106 | >>> evaluator.compute_metrics(image_path='path/to/image/folder/', labels=labels) 107 | ``` -------------------------------------------------------------------------------- /doc/metrics/lrs-hard-label.md: -------------------------------------------------------------------------------- 1 | ## LabelRobustScoreHard 2 | 3 |
4 | 5 | CLASS 6 | dd_ranking.metrics.LabelRobustScoreHard(config=None, 7 | dataset: str = 'CIFAR10', 8 | real_data_path: str = './dataset/', 9 | ipc: int = 10, 10 | model_name: str = 'ConvNet-3', 11 | data_aug_func: str = 'cutmix', 12 | aug_params: dict = {'cutmix_p': 1.0}, 13 | optimizer: str = 'sgd', 14 | lr_scheduler: str = 'step', 15 | step_size: int = None, 16 | weight_decay: float = 0.0005, 17 | momentum: float = 0.9, 18 | use_zca: bool = False, 19 | num_eval: int = 5, 20 | im_size: tuple = (32, 32), 21 | num_epochs: int = 300, 22 | real_batch_size: int = 256, 23 | syn_batch_size: int = 256, 24 | use_torchvision: bool = False, 25 | eval_full_data: bool = False, 26 | random_data_format: str = 'tensor', 27 | random_data_path: str = './dataset/', 28 | num_workers: int = 4, 29 | save_path: Optional[str] = None, 30 | custom_train_trans: Optional[Callable] = None, 31 | custom_val_trans: Optional[Callable] = None, 32 | device: str = "cuda", 33 | dist: bool = False 34 | ) 35 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/metrics/hard_label.py) 36 |
37 | 38 | A class for evaluating the performance of a dataset distillation method with hard labels. User is able to modify the attributes as needed. 39 | 40 | ### Parameters 41 | 42 | - **config**(Optional[Config]): Config object for specifying all attributes. See [config](../config/overview.md) for more details. 43 | - **dataset**(str): Name of the real dataset. 44 | - **real_data_path**(str): Path to the real dataset. 45 | - **ipc**(int): Images per class. 46 | - **model_name**(str): Name of the surrogate model. See [models](../models/overview.md) for more details. 47 | - **data_aug_func**(str): Data augmentation function used during training. Currently supports `dsa`, `cutmix`, `mixup`. See [augmentations](../augmentations/overview.md) for more details. 48 | - **aug_params**(dict): Parameters for the data augmentation function. 49 | - **optimizer**(str): Name of the optimizer. Currently supports torch-based optimizers - `sgd`, `adam`, and `adamw`. 50 | - **lr_scheduler**(str): Name of the learning rate scheduler. Currently supports torch-based schedulers - `step`, `cosine`, `lambda_step`, and `cosineannealing`. 51 | - **weight_decay**(float): Weight decay for the optimizer. 52 | - **momentum**(float): Momentum for the optimizer. 53 | - **step_size**(int): Step size for the learning rate scheduler. 54 | - **use_zca**(bool): Whether to use ZCA whitening. 55 | - **num_eval**(int): Number of evaluations to perform. 56 | - **im_size**(tuple): Size of the images. 57 | - **num_epochs**(int): Number of epochs to train. 58 | - **real_batch_size**(int): Batch size for the real dataset. 59 | - **syn_batch_size**(int): Batch size for the synthetic dataset. 60 | - **use_torchvision**(bool): Whether to use torchvision to initialize the model. 61 | - **eval_full_data**(bool): Whether to evaluate on the full dataset. 62 | - **random_data_format**(str): Format of the randomly selected dataset. Currently supports `tensor` and `image`. 63 | - **random_data_path**(str): Path to the randomly selected dataset. 64 | - **num_workers**(int): Number of workers for data loading. 65 | - **save_path**(Optional[str]): Path to save the results. 66 | - **custom_train_trans**(Optional[Callable]): Custom transformation function when loading synthetic data. Only support torchvision transformations. See [torchvision-based transformations](../augmentations/torchvision.md) for more details. 67 | - **custom_val_trans**(Optional[Callable]): Custom transformation function when loading test dataset. Only support torchvision transformations. See [torchvision-based transformations](../augmentations/torchvision.md) for more details. 68 | - **device**(str): Device to use for evaluation, `cuda` or `cpu`. 69 | - **dist**(bool): Whether to use distributed evaluation. 70 | 71 | ### Methods 72 |
73 | 74 | compute_metrics(image_tensor: Tensor = None, image_path: str = None, hard_labels: Tensor = None, syn_lr: float = None, lrs_lambda: float = 0.5) 75 |
76 | 77 | This method computes the HLR, IOR, and LRS for the given image and hard labels (if provided). In each evaluation round, we set a different random seed and perform the following steps: 78 | 79 | 1. Compute the test accuracy of the surrogate model on the synthetic dataset under hard labels. We tune the learning rate for the best performance if `syn_lr` is not provided. 80 | 2. Compute the test accuracy of the surrogate model on the real dataset under the same setting as step 1. 81 | 3. Compute the test accuracy of the surrogate model on the randomly selected dataset under the same setting as step 1. 82 | 4. Compute the HLR and IOR scores. 83 | 5. Compute the LRS. 84 | 85 | The final scores are the average of the scores from `num_eval` rounds. 86 | 87 | #### Parameters 88 | 89 | - **image_tensor**(Tensor): Image tensor. Must specify when `image_path` is not provided. We require the shape to be `(N x IPC, C, H, W)` where `N` is the number of classes. 90 | - **image_path**(str): Path to the image. Must specify when `image_tensor` is not provided. 91 | - **hard_labels**(Tensor): Hard label tensor. The first dimension must be the same as `image_tensor`. 92 | - **syn_lr**(float): Learning rate for the synthetic dataset. If not specified, the learning rate will be tuned automatically. 93 | - **lrs_lambda**(float): Weighting parameter for the LRS. 94 | 95 | #### Returns 96 | 97 | A dictionary with the following keys: 98 | 99 | - **hard_label_recovery_mean**: Mean of HLR scores from `num_eval` rounds. 100 | - **hard_label_recovery_std**: Standard deviation of HLR scores from `num_eval` rounds. 101 | - **improvement_over_random_mean**: Mean of improvement over random scores from `num_eval` rounds. 102 | - **improvement_over_random_std**: Standard deviation of improvement over random scores from `num_eval` rounds. 103 | - **label_robust_score_mean**: Mean of LRS scores from `num_eval` rounds. 104 | - **label_robust_score_std**: Standard deviation of LRS scores from `num_eval` rounds. 105 | 106 | **Examples:** 107 | 108 | with config file: 109 | ```python 110 | >>> config = Config('/path/to/config.yaml') 111 | >>> evaluator = LabelRobustScoreHard(config=config) 112 | # load the image and hard labels 113 | >>> image_tensor, hard_labels = ... 114 | # compute the metrics 115 | >>> evaluator.compute_metrics(image_tensor=image_tensor, hard_labels=hard_labels) 116 | # alternatively, you can provide the image path 117 | >>> evaluator.compute_metrics(image_path='path/to/image/folder/', hard_labels=hard_labels) 118 | ``` 119 | 120 | with keyword arguments: 121 | ```python 122 | >>> evaluator = LabelRobustScoreHard( 123 | ... dataset='CIFAR10', 124 | ... real_data_path='./dataset/', 125 | ... ipc=10, 126 | ... model_name='ConvNet-3', 127 | ... data_aug_func='dsa', 128 | ... aug_params={ 129 | ... "prob_flip": 0.5, 130 | ... "ratio_rotate": 15.0, 131 | ... "saturation": 2.0, 132 | ... "brightness": 1.0, 133 | ... "contrast": 0.5, 134 | ... "ratio_scale": 1.2, 135 | ... "ratio_crop_pad": 0.125, 136 | ... "ratio_cutout": 0.5 137 | ... }, 138 | ... optimizer='sgd', 139 | ... lr_scheduler='step', 140 | ... weight_decay=0.0005, 141 | ... momentum=0.9, 142 | ... step_size=500, 143 | ... num_epochs=1000, 144 | ... real_batch_size=256, 145 | ... syn_batch_size=256, 146 | ... use_torchvision=False, 147 | ... eval_full_data=True, 148 | ... random_data_format='tensor', 149 | ... random_data_path='./random_data/', 150 | ... num_workers=4, 151 | ... save_path='./results/', 152 | ... use_zca=False, 153 | ... 
num_eval=5, 154 | ... device='cuda', 155 | ... dist=True 156 | ... ) 157 | # load the image and hard labels 158 | >>> image_tensor, hard_labels = ... 159 | # compute the metrics 160 | >>> evaluator.compute_metrics(image_tensor=image_tensor, hard_labels=hard_labels) 161 | # alternatively, you can provide the image path 162 | >>> evaluator.compute_metrics(image_path='path/to/image/folder/', hard_labels=hard_labels) 163 | ``` 164 | -------------------------------------------------------------------------------- /doc/metrics/lrs-soft-label.md: -------------------------------------------------------------------------------- 1 | ## LabelRobustScoreSoft 2 | 3 |
4 | 5 | CLASS 6 | dd_ranking.metrics.LabelRobustScoreSoft(config: Optional[Config] = None, 7 | dataset: str = 'CIFAR10', 8 | real_data_path: str = './dataset/', 9 | ipc: int = 10, 10 | model_name: str = 'ConvNet-3', 11 | soft_label_mode: str='S', 12 | soft_label_criterion: str='kl', 13 | loss_fn_kwargs: dict=None, 14 | data_aug_func: str='cutmix', 15 | aug_params: dict={'cutmix_p': 1.0}, 16 | optimizer: str='sgd', 17 | lr_scheduler: str='step', 18 | weight_decay: float=0.0005, 19 | momentum: float=0.9, 20 | step_size: int=None, 21 | num_eval: int=5, 22 | im_size: tuple=(32, 32), 23 | num_epochs: int=300, 24 | use_zca: bool=False, 25 | use_aug_for_hard: bool=False, 26 | random_data_format: str='tensor', 27 | random_data_path: str=None, 28 | real_batch_size: int=256, 29 | syn_batch_size: int=256, 30 | save_path: str=None, 31 | eval_full_data: bool=False, 32 | stu_use_torchvision: bool=False, 33 | tea_use_torchvision: bool=False, 34 | num_workers: int=4, 35 | teacher_dir: str='./teacher_models', 36 | teacher_model_names: list=None, 37 | custom_train_trans: Optional[Callable]=None, 38 | custom_val_trans: Optional[Callable]=None, 39 | device: str="cuda", 40 | dist: bool=False 41 | ) 42 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/metrics/soft_label.py) 43 |
44 | 45 | A class for evaluating the performance of a dataset distillation method with soft labels. User is able to modify the attributes as needed. 46 | 47 | ### Parameters 48 | 49 | - **config**(Optional[Config]): Config object for specifying all attributes. See [config](../config/overview.md) for more details. 50 | - **dataset**(str): Name of the real dataset. 51 | - **real_data_path**(str): Path to the real dataset. 52 | - **ipc**(int): Images per class. 53 | - **model_name**(str): Name of the surrogate model. See [models](../models/overview.md) for more details. 54 | - **soft_label_mode**(str): Number of soft labels per image. `S` for single soft label, `M` for multiple soft labels. 55 | - **soft_label_criterion**(str): Loss function for using soft labels. Currently supports `kl` for KL divergence, `sce` for soft cross-entropy, and `mse_gt` for MSEGT loss introduced in EDC. 56 | - **loss_fn_kwargs**(dict): Keyword arguments for the loss function. `temperature` and `scale_loss` for KL and SCE loss, and `mse_weight` and `ce_weight` for MSE and CE loss. 57 | - **data_aug_func**(str): Data augmentation function used during training. Currently supports `dsa`, `cutmix`, `mixup`. See [augmentations](../augmentations/overview.md) for more details. 58 | - **aug_params**(dict): Parameters for the data augmentation function. 59 | - **use_aug_for_hard**(bool): Whether to use the data augmentation specified in `data_aug_func` for hard label evaluation. 60 | - **optimizer**(str): Name of the optimizer. Currently supports torch-based optimizers - `sgd`, `adam`, and `adamw`. 61 | - **lr_scheduler**(str): Name of the learning rate scheduler. Currently supports torch-based schedulers - `step`, `cosine`, `lambda_step`, and `cosineannealing`. 62 | - **weight_decay**(float): Weight decay for the optimizer. 63 | - **momentum**(float): Momentum for the optimizer. 64 | - **step_size**(int): Step size for the learning rate scheduler. 65 | - **use_zca**(bool): Whether to use ZCA whitening. 66 | - **num_eval**(int): Number of evaluations to perform. 67 | - **im_size**(tuple): Size of the images. 68 | - **num_epochs**(int): Number of epochs to train. 69 | - **real_batch_size**(int): Batch size for the real dataset. 70 | - **syn_batch_size**(int): Batch size for the synthetic dataset. 71 | - **stu_use_torchvision**(bool): Whether to use torchvision to initialize the student model. 72 | - **tea_use_torchvision**(bool): Whether to use torchvision to initialize the teacher model. 73 | - **teacher_dir**(str): Path to the teacher model. 74 | - **teacher_model_names**(list): List of teacher model names. 75 | - **random_data_format**(str): Format of the random data, `tensor` or `image`. 76 | - **random_data_path**(str): Path to save the random data. 77 | - **eval_full_data**(bool): Whether to compute the test accuracy on the full dataset (might be time-consuming on large datasets such as ImageNet1K, so we have provided a full dataset performance cache). 78 | - **num_workers**(int): Number of workers for data loading. 79 | - **save_path**(Optional[str]): Path to save the results. 80 | - **custom_train_trans**(Optional[Callable]): Custom transformation function when loading synthetic data. Only support torchvision transformations. See [torchvision-based transformations](../augmentations/torchvision.md) for more details. 81 | - **custom_val_trans**(Optional[Callable]): Custom transformation function when loading test dataset. Only support torchvision transformations. 
See [torchvision-based transformations](../augmentations/torchvision.md) for more details. 82 | - **device**(str): Device to use for evaluation, `cuda` or `cpu`. 83 | - **dist**(bool): Whether to use distributed training. 84 | 85 | ### Methods 86 |
87 | 88 | compute_metrics(image_tensor: Tensor = None, image_path: str = None, soft_labels: Tensor = None, syn_lr: float = None, lrs_lambda: float = 0.5) 89 |
90 | 91 |
92 | This method computes the HLR, IOR, and LRS for the given image and soft labels (if provided). In each evaluation round, we set a different random seed and perform the following steps: 93 | 94 | 1. Compute the test accuracy of the surrogate model on the synthetic dataset under hard labels. We perform learning rate tuning for the best performance. 95 | 2. Compute the test accuracy of the surrogate model on the real dataset under the same setting as step 1. 96 | 3. Compute the test accuracy of the surrogate model on the synthetic dataset under soft labels. 97 | 4. Compute the test accuracy of the surrogate model on the randomly selected dataset under the same setting as step 3. 98 | 5. Compute the HLR and IOR scores. 99 | 6. Compute the LRS. 100 | 101 | The final scores are the average of the scores from `num_eval` rounds. 102 | 103 | #### Parameters 104 | 105 | - **image_tensor**(Tensor): Image tensor. Must specify when `image_path` is not provided. We require the shape to be `(N x IPC, C, H, W)` where `N` is the number of classes. 106 | - **image_path**(str): Path to the image. Must specify when `image_tensor` is not provided. 107 | - **soft_labels**(Tensor): Soft label tensor. Must specify when `soft_label_mode` is `S`. The first dimension must be the same as `image_tensor`. 108 | - **syn_lr**(float): Learning rate for the synthetic dataset. If not specified, the learning rate will be tuned automatically. 109 | - **lrs_lambda**(float): Weighting parameter for the LRS. 110 | 111 | #### Returns 112 | 113 | A dictionary with the following keys: 114 | 115 | - **hard_label_recovery_mean**: Mean of HLR scores from `num_eval` rounds. 116 | - **hard_label_recovery_std**: Standard deviation of HLR scores from `num_eval` rounds. 117 | - **improvement_over_random_mean**: Mean of improvement over random scores from `num_eval` rounds. 118 | - **improvement_over_random_std**: Standard deviation of improvement over random scores from `num_eval` rounds. 119 | - **label_robust_score_mean**: Mean of LRS from `num_eval` rounds. 120 | - **label_robust_score_std**: Standard deviation of LRS from `num_eval` rounds. 121 | 122 |
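For intuition, the way the returned scores relate to the measured accuracies can be sketched as follows. This mirrors the HLR/IOR/LRS formulas from the introduction and is an illustration only, not the library's internal implementation; accuracies are fractions in [0, 1].

```python
import math

def lrs_sketch(acc_real_hard, acc_syn_hard, acc_syn_soft, acc_rdm_soft, lrs_lambda=0.5):
    hlr = acc_real_hard - acc_syn_hard   # hard label recovery (lower is better)
    ior = acc_syn_soft - acc_rdm_soft    # improvement over random (higher is better)
    alpha = lrs_lambda * ior - (1 - lrs_lambda) * hlr
    lrs = 100 * (math.exp(alpha) - math.exp(-1)) / (math.e - math.exp(-1))
    return hlr, ior, lrs
```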
123 | 124 | ### Examples 125 | 126 | with config file: 127 | ```python 128 | >>> config = Config('/path/to/config.yaml') 129 | >>> evaluator = LabelRobustScoreSoft(config=config) 130 | # load image and soft labels 131 | >>> image_tensor, soft_labels = ... 132 | # compute metrics 133 | >>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels) 134 | # alternatively, provide image path 135 | >>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels) 136 | ``` 137 | 138 | with keyword arguments: 139 | ```python 140 | >>> evaluator = LabelRobustScoreSoft( 141 | ... dataset='TinyImageNet', 142 | ... real_data_path='./dataset/', 143 | ... ipc=10, 144 | ... model_name='ResNet-18-BN', 145 | ... soft_label_mode='M', 146 | ... soft_label_criterion='kl', 147 | ... loss_fn_kwargs={ 148 | ... "temperature": 30.0, 149 | ... "scale_loss": False, 150 | ... }, 151 | ... data_aug_func='mixup', 152 | ... aug_params={ 153 | ... "mixup_p": 0.8, 154 | ... }, 155 | ... optimizer='sgd', 156 | ... lr_scheduler='step', 157 | ... num_epochs=300, 158 | ... step_size=100, 159 | ... weight_decay=0.0005, 160 | ... momentum=0.9, 161 | ... use_zca=False, 162 | ... use_aug_for_hard=False, 163 | ... stu_use_torchvision=True, 164 | ... tea_use_torchvision=True, 165 | ... num_workers=4, 166 | ... save_path='./results', 167 | ... eval_full_data=False, 168 | ... random_data_format='image', 169 | ... random_data_path='./random_data', 170 | ... num_eval=5, 171 | ... device='cuda' 172 | ... ) 173 | # load image and soft labels 174 | >>> image_tensor, soft_labels = ... 175 | # compute metrics 176 | >>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels) 177 | # alternatively, provide image path 178 | >>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels) 179 | ``` -------------------------------------------------------------------------------- /doc/metrics/overview.md: -------------------------------------------------------------------------------- 1 | # DD-Ranking Metrics 2 | 3 | DD-Ranking provides a set of metrics to evaluate the real informativeness of datasets distilled by different methods. The unfairness of existing evaluations mainly stems from two factors: label representation and data augmentation. We design the label-robust score (LRS) and the augmentation-robust score (ARS) to disentangle the impact of label representation and data augmentation on the evaluation, respectively. 4 | 5 | ## Evaluation Classes 6 | * [LabelRobustScoreHard](lrs-hard-label.md) computes HLR, IOR, and LRS for methods using hard labels. 7 | * [LabelRobustScoreSoft](lrs-soft-label.md) computes HLR, IOR, and LRS for methods using soft labels. 8 | * [AugmentationRobustScore](ars.md) computes the ARS for methods using soft labels. 9 | * [GeneralEvaluator](general.md) computes the traditional test accuracy for existing methods. 10 | -------------------------------------------------------------------------------- /doc/models/alexnet.md: -------------------------------------------------------------------------------- 1 | ## AlexNet 2 | 3 | Our [implementation](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/networks.py) of AlexNet is based on [DC](https://github.com/VICO-UoE/DatasetCondensation). 4 | 5 | We provide the following interface to initialize an AlexNet model: 6 | 7 |
8 | 9 | ddranking.utils.get_alexnet(model_name: str, im_size: tuple, channel: int, num_classes: int, pretrained: bool, model_path: str) 10 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/model.py) 11 |
12 | 13 | ### Parameters 14 | 15 | - **model_name**(str): Name of the model. Please navigate to [models](overview.md) for the model naming convention in DD-Ranking. 16 | - **im_size**(tuple): Image size. 17 | - **channel**(int): Number of channels of the input image. 18 | - **num_classes**(int): Number of classes. 19 | - **pretrained**(bool): Whether to load pretrained weights. 20 | - **model_path**(str): Path to the pretrained model weights. 21 | -------------------------------------------------------------------------------- /doc/models/convnet.md: -------------------------------------------------------------------------------- 1 | ## ConvNet 2 | 3 | Our [implementation](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/networks.py) of ConvNet is based on [DC](https://github.com/VICO-UoE/DatasetCondensation). 4 | 5 | By default, we use width 128, average pooling, and ReLU activation. We provide the following interface to initialize a ConvNet model: 6 | 7 |
8 | 9 | ddranking.utils.get_convnet(model_name: str, 10 | im_size: tuple, channel: int, num_classes: int, net_depth: int, net_norm: str, pretrained: bool, model_path: str) 11 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/model.py) 12 |
13 | 14 | ### Parameters 15 | 16 | - **model_name**(str): Name of the model. Please navigate to [models](overview.md) for the model naming convention in DD-Ranking. 17 | - **im_size**(tuple): Image size. 18 | - **channel**(int): Number of channels of the input image. 19 | - **num_classes**(int): Number of classes. 20 | - **net_depth**(int): Depth of the network. 21 | - **net_norm**(str): Normalization method. In ConvNet, we support `instance`, `batch`, and `group` normalization. 22 | - **pretrained**(bool): Whether to load pretrained weights. 23 | - **model_path**(str): Path to the pretrained model weights. 24 | 25 | To load a ConvNet model with a different width, activation function, or pooling method, you can use the following interface: 26 | 27 |
28 | 29 | ddranking.utils.networks.ConvNet(channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size) 30 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/networks.py) 31 |
32 | 33 | ### Parameters 34 | We only list the parameters that are not present in `get_convnet`. 35 | - **net_width**(int): Width of the network. 36 | - **net_act**(str): Activation function. We support `relu`, `leakyrelu`, and `sigmoid`. 37 | - **net_pooling**(str): Pooling method. We support `avgpooling`, `maxpooling`, and `none`. 38 | 39 | -------------------------------------------------------------------------------- /doc/models/lenet.md: -------------------------------------------------------------------------------- 1 | ## LeNet 2 | 3 | Our [implementation](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/networks.py) of LeNet is based on [DC](https://github.com/VICO-UoE/DatasetCondensation). 4 | 5 | We provide the following interface to initialize a LeNet model: 6 | 7 |
8 | 9 | ddranking.utils.get_lenet(model_name: str, im_size: tuple, channel: int, num_classes: int, pretrained: bool, model_path: str) 10 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/model.py) 11 |
12 | 13 | ### Parameters 14 | 15 | - **model_name**(str): Name of the model. Please navigate to [models](overview.md) for the model naming convention in DD-Ranking. 16 | - **im_size**(tuple): Image size. 17 | - **channel**(int): Number of channels of the input image. 18 | - **num_classes**(int): Number of classes. 19 | - **pretrained**(bool): Whether to load pretrained weights. 20 | - **model_path**(str): Path to the pretrained model weights. 21 | -------------------------------------------------------------------------------- /doc/models/mlp.md: -------------------------------------------------------------------------------- 1 | ## MLP 2 | 3 | Our [implementation](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/networks.py) of MLP is based on [DC](https://github.com/VICO-UoE/DatasetCondensation). 4 | 5 | We provide the following interface to initialize an MLP model: 6 | 7 |
8 | 9 | ddranking.utils.get_mlp(model_name: str, im_size: tuple, channel: int, num_classes: int, pretrained: bool, model_path: str) 10 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/model.py) 11 |
12 | 13 | ### Parameters 14 | 15 | - **model_name**(str): Name of the model. Please navigate to [models](overview.md) for the model naming convention in DD-Ranking. 16 | - **im_size**(tuple): Image size. 17 | - **channel**(int): Number of channels of the input image. 18 | - **num_classes**(int): Number of classes. 19 | - **pretrained**(bool): Whether to load pretrained weights. 20 | - **model_path**(str): Path to the pretrained model weights. 21 | -------------------------------------------------------------------------------- /doc/models/overview.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | DD-Ranking provides the implementation of a set of commonly used model architectures in existing dataset distillation methods. Users can flexibly use these models for main evaluation or cross-architecture evaluation. We will keep updating this section with more models. 4 | 5 | - [ConvNet](convnet.md) 6 | - [ResNet](resnet.md) 7 | - [VGG](vgg.md) 8 | - [LeNet](lenet.md) 9 | - [AlexNet](alexnet.md) 10 | - [MLP](mlp.md) 11 | 12 | Users can also define any model with `torchvision`. 13 | 14 | ## Naming Convention 15 | 16 | We use the following naming conventions for models in DD-Ranking: 17 | 18 | - `model name - model depth - norm type` (for DD-Ranking implemented models) 19 | - torchvision model names, e.g. `vgg11` and `vit_b_16` 20 | 21 | Model name and depth are required when **not using torchvision**. When norm type is not specified, we use the default normalization for the model. For example, `ResNet-18-BN` means ResNet18 with batch normalization, and `ConvNet-4` means ConvNet with depth 4 and default instance normalization. 22 | 23 | ## Pretrained Model Weights 24 | 25 | For users' convenience, we provide pretrained model weights on CIFAR10, CIFAR100, TinyImageNet, and ImageNet1K for the following models: 26 | - ConvNet-3 (CIFAR10, CIFAR100) 27 | - ConvNet-3-BN (CIFAR10, CIFAR100) 28 | - ConvNet-4 (TinyImageNet) 29 | - ConvNet-4-BN (TinyImageNet) 30 | - ResNet-18-BN (CIFAR10, CIFAR100, TinyImageNet, ImageNet1K) 31 | 32 | Users can download the weights from the following link: [Pretrained Model Weights](https://drive.google.com/drive/folders/19OnR85PRs3TZk8xS8XNr9hiokfsML4m2?usp=sharing). 33 | 34 | Users are also free to use `torchvision` pretrained models. 35 | -------------------------------------------------------------------------------- /doc/models/resnet.md: -------------------------------------------------------------------------------- 1 | ## ResNet 2 | 3 | DD-Ranking supports implementations of ResNet from both [DC](https://github.com/VICO-UoE/DatasetCondensation) and [torchvision](https://pytorch.org/vision/main/models/resnet.html). 4 | 5 | We provide the following interface to initialize a ResNet model: 6 | 7 |
8 | 9 | ddranking.utils.get_resnet(model_name: str, 10 | im_size: tuple, channel: int, num_classes: int, depth: int, batchnorm: bool, use_torchvision: bool, pretrained: bool, model_path: str) 11 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/model.py) 12 |
13 | 14 | ### Parameters 15 | 16 | - **model_name**(str): Name of the model. Please navigate to [models](overview.md) for the model naming convention in DD-Ranking. 17 | - **im_size**(tuple): Image size. 18 | - **channel**(int): Number of channels of the input image. 19 | - **num_classes**(int): Number of classes. 20 | - **depth**(int): Depth of the network. 21 | - **batchnorm**(bool): Whether to use batch normalization. 22 | - **use_torchvision**(bool): Whether to use torchvision to initialize the model. When using torchvision, the ResNet model uses batch normalization by default. 23 | - **pretrained**(bool): Whether to load pretrained weights. 24 | - **model_path**(str): Path to the pretrained model weights. 25 | 26 |
NOTE
27 | 28 |
29 | When using torchvision ResNet on image sizes smaller than 224 x 224, we make the following modifications: 30 | 31 | ```python 32 | model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False) 33 | model.maxpool = torch.nn.Identity() 34 | ``` 35 |
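For example, a torchvision-backed ResNet-18 with batch normalization for CIFAR-scale inputs could be created as in the sketch below; the argument values are illustrative, and `pretrained=False` is used so that no checkpoint path is required.

```python
from ddranking.utils import get_resnet

model = get_resnet(
    model_name="ResNet-18-BN",
    im_size=(32, 32),        # small inputs; see the note above on the modifications applied
    channel=3,
    num_classes=100,
    depth=18,
    batchnorm=True,
    use_torchvision=True,
    pretrained=False,
    model_path=None,
)
```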
36 | -------------------------------------------------------------------------------- /doc/models/vgg.md: -------------------------------------------------------------------------------- 1 | ## VGG 2 | 3 | DD-Ranking supports the VGG implementations from both [DC](https://github.com/VICO-UoE/DatasetCondensation) and [torchvision](https://pytorch.org/vision/main/models/vgg.html). 4 | 5 | We provide the following interface to initialize a VGG model: 6 | 7 |
8 | 9 | ddranking.utils.get_vgg(model_name: str, 10 | im_size: tuple, channel: int, num_classes: int, depth: int, batchnorm: bool, use_torchvision: bool, pretrained: bool, model_path: str) 11 | [**[SOURCE]**](https://github.com/NUS-HPC-AI-Lab/DD-Ranking/blob/main/ddranking/utils/model.py) 12 |
13 | 14 | ### Parameters 15 | 16 | - **model_name**(str): Name of the model. Please navigate to [models](overview.md) for the model naming convention in DD-Ranking. 17 | - **im_size**(tuple): Image size. 18 | - **channel**(int): Number of channels of the input image. 19 | - **num_classes**(int): Number of classes. 20 | - **depth**(int): Depth of the network. 21 | - **batchnorm**(bool): Whether to use batch normalization. 22 | - **use_torchvision**(bool): Whether to use torchvision to initialize the model. 23 | - **pretrained**(bool): Whether to load pretrained weights. 24 | - **model_path**(str): Path to the pretrained model weights. 25 | 26 |
NOTE
27 | 28 |
29 | When using a torchvision VGG on image sizes smaller than 224 x 224, we replace the classifier so that its input dimension matches the smaller feature maps (the snippets below assume `import torch.nn as nn` and `from collections import OrderedDict`): 30 | 31 | For a 32x32 image size: 32 | ```python 33 | model.classifier = nn.Sequential(OrderedDict([ 34 | ('fc1', nn.Linear(512 * 1 * 1, 4096)),  # the 512-channel feature map is 1x1 for 32x32 inputs 35 | ('relu1', nn.ReLU(True)), 36 | ('drop1', nn.Dropout()), 37 | ('fc2', nn.Linear(4096, 4096)), 38 | ('relu2', nn.ReLU(True)), 39 | ('drop2', nn.Dropout()), 40 | ('fc3', nn.Linear(4096, num_classes)), 41 | ])) 42 | ``` 43 | 44 | For a 64x64 image size: 45 | ```python 46 | model.classifier = nn.Sequential(OrderedDict([ 47 | ('fc1', nn.Linear(512 * 2 * 2, 4096)),  # the 512-channel feature map is 2x2 for 64x64 inputs 48 | ('relu1', nn.ReLU(True)), 49 | ('drop1', nn.Dropout()), 50 | ('fc2', nn.Linear(4096, 4096)), 51 | ('relu2', nn.ReLU(True)), 52 | ('drop2', nn.Dropout()), 53 | ('fc3', nn.Linear(4096, num_classes)), 54 | ])) 55 | ``` 56 |
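### Example

A minimal usage sketch, assuming `get_vgg` is importable from `ddranking.utils` as documented above; the argument values, including the `VGG-11-BN` model name (formed per the naming convention in [models](overview.md)), are illustrative only:

```python
from ddranking.utils import get_vgg

# Build a VGG-11 with batch normalization for 32x32, 3-channel inputs
# (e.g. CIFAR10), using the DC implementation rather than torchvision.
model = get_vgg(
    model_name="VGG-11-BN",  # illustrative name following `model name - depth - norm type`
    im_size=(32, 32),
    channel=3,
    num_classes=10,
    depth=11,
    batchnorm=True,
    use_torchvision=False,
    pretrained=False,  # set True and supply model_path to load local weights
    model_path=None,
)
```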
-------------------------------------------------------------------------------- /doc/static/configurations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/doc/static/configurations.png -------------------------------------------------------------------------------- /doc/static/history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/doc/static/history.png -------------------------------------------------------------------------------- /doc/static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/doc/static/logo.png -------------------------------------------------------------------------------- /doc/static/team/zekai.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/doc/static/team/zekai.jpg -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | "torch >= 2.0.0" 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [project] 10 | name = "ddranking" 11 | version = "0.2.0" 12 | description = "DD-Ranking: Rethinking the Evaluation of Dataset Distillation" 13 | readme = "README.md" 14 | requires-python = ">=3.8" 15 | dependencies = [ 16 | "torch", 17 | "numpy<2.0.0", 18 | "torchvision", 19 | "tqdm", 20 | "scipy<1.16.0", 21 | "kornia", 22 | "timm", 23 | "pandas" 24 | ] 25 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | packages = find_packages() 4 | 5 | setup( 6 | name='ddranking', # Package name 7 | version='0.2.0', # Version number 8 | description='DD-Ranking: Rethinking the Evaluation of Dataset Distillation', 9 | long_description=open('README.md').read(), # Use your README as the long description 10 | long_description_content_type='text/markdown', 11 | author='DD-Ranking Team', 12 | author_email='lizekai@u.nus.edu', 13 | include_dirs=['ddranking', 'configs', 'static'], 14 | include_package_data=True, 15 | packages=packages, # Automatically discover submodules 16 | install_requires=[ 17 | 'torch', 18 | 'numpy', 19 | 'torchvision', 20 | 'tqdm', 21 | 'scipy<1.16.0', 22 | 'pandas', 23 | 'kornia', 24 | 'timm' 25 | ], 26 | classifiers=[ 27 | "Programming Language :: Python :: 3.8", 28 | "Programming Language :: Python :: 3.9", 29 | "Programming Language :: Python :: 3.10", 30 | "Programming Language :: Python :: 3.11", 31 | "Programming Language :: Python :: 3.12", 32 | "Intended Audience :: Developers", 33 | "Intended Audience :: Information Technology", 34 | "Intended Audience :: Science/Research", 35 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 36 | "Topic :: Scientific/Engineering :: Information Analysis", 37 | 'License :: OSI Approved :: MIT License', 38 | 'Operating System :: OS Independent', 39 | ], 40 | python_requires='>=3.8', 41 | ) 42 | 
-------------------------------------------------------------------------------- /static/configurations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/static/configurations.png -------------------------------------------------------------------------------- /static/history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/static/history.png -------------------------------------------------------------------------------- /static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/DD-Ranking/c1ea0b3f251b7927feb25c4dafc76d5312011caf/static/logo.png --------------------------------------------------------------------------------