├── .gitattributes ├── .github ├── repo.yml └── workflows │ └── security-scan.yaml ├── .gitignore ├── CODEOWNERS ├── LICENSE ├── README.md ├── SmoothCache ├── __init__.py ├── calibration │ ├── calibration_helper.py │ └── diffuser_calibration_helper.py ├── diffuser_cache_helper.py ├── dit_cache_helper.py └── smooth_cache_helper.py ├── assets ├── SmoothCache2.png ├── TeaserFigureFlat.png ├── dit-mosaic.png ├── table1.png ├── table2.png └── table3.png ├── examples └── run_calibration.py ├── setup.py └── smoothcache_schedules ├── 30-N-2-fora.json ├── 30-N-3-threshold-0.35.json ├── 50-N-2-fora.json ├── 50-N-2-l2c.json ├── 50-N-3-threshold-0.08.json ├── 50-N-3-threshold-0.18.json ├── 70-N-2-fora.json ├── 70-N-3-threshold-0.08.json └── diffuser_schedule.json /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto 2 | 3 | # LFS file patterns, separated by spaces. 4 | # Images 5 | # Extensions> png jpg jpeg gif ico bmp 6 | # Videos 7 | # Extensions> mov mp4 mp3 flv fla swf pdf fbx 8 | # Archives 9 | # Extensions> gz zip 7z ttf rar tar.gz 10 | # Key file 11 | # Extensions> snk 12 | # Build artifacts 13 | # Extensions> dll pdb nupkg blob exe exe.old so dylib 14 | # Database files 15 | # Extensions> mdf sdf sqlite db 16 | # Other types 17 | # Extensions> bin 18 | 19 | # The rest of this file is autogenerated, don't modify it directly. Add to the list 20 | # above and run Add-RbxLfsRulesToGitattriutes.ps1 instead. 
21 | # https://confluence.rbx.com/display/BIRD/Gitattributes+and+LFS 22 | # Extension Patterns: 23 | 24 | *.[pP][nN][gG] filter=lfs diff=lfs merge=lfs -text 25 | *.[jJ][pP][gG] filter=lfs diff=lfs merge=lfs -text 26 | *.[jJ][pP][eE][gG] filter=lfs diff=lfs merge=lfs -text 27 | *.[gG][iI][fF] filter=lfs diff=lfs merge=lfs -text 28 | *.[iI][cC][oO] filter=lfs diff=lfs merge=lfs -text 29 | *.[bB][mM][pP] filter=lfs diff=lfs merge=lfs -text 30 | 31 | *.[mM][oO][vV] filter=lfs diff=lfs merge=lfs -text 32 | *.[mM][pP]4 filter=lfs diff=lfs merge=lfs -text 33 | *.[mM][pP]3 filter=lfs diff=lfs merge=lfs -text 34 | *.[fF][lL][vV] filter=lfs diff=lfs merge=lfs -text 35 | *.[fF][lL][aA] filter=lfs diff=lfs merge=lfs -text 36 | *.[sS][wW][fF] filter=lfs diff=lfs merge=lfs -text 37 | *.[pP][dD][fF] filter=lfs diff=lfs merge=lfs -text 38 | *.[fF][bB][xX] filter=lfs diff=lfs merge=lfs -text 39 | 40 | *.[gG][zZ] filter=lfs diff=lfs merge=lfs -text 41 | *.[zZ][iI][pP] filter=lfs diff=lfs merge=lfs -text 42 | *.7[zZ] filter=lfs diff=lfs merge=lfs -text 43 | *.[tT][tT][fF] filter=lfs diff=lfs merge=lfs -text 44 | *.[rR][aA][rR] filter=lfs diff=lfs merge=lfs -text 45 | *.[tT][aA][rR].[gG][zZ] filter=lfs diff=lfs merge=lfs -text 46 | 47 | *.[sS][nN][kK] filter=lfs diff=lfs merge=lfs -text 48 | 49 | *.[dD][lL][lL] filter=lfs diff=lfs merge=lfs -text 50 | *.[pP][dD][bB] filter=lfs diff=lfs merge=lfs -text 51 | *.[nN][uU][pP][kK][gG] filter=lfs diff=lfs merge=lfs -text 52 | *.[bB][lL][oO][bB] filter=lfs diff=lfs merge=lfs -text 53 | *.[eE][xX][eE] filter=lfs diff=lfs merge=lfs -text 54 | *.[eE][xX][eE].[oO][lL][dD] filter=lfs diff=lfs merge=lfs -text 55 | *.[sS][oO] filter=lfs diff=lfs merge=lfs -text 56 | *.[dD][yY][lL][iI][bB] filter=lfs diff=lfs merge=lfs -text 57 | 58 | *.[mM][dD][fF] filter=lfs diff=lfs merge=lfs -text 59 | *.[sS][dD][fF] filter=lfs diff=lfs merge=lfs -text 60 | *.[sS][qQ][lL][iI][tT][eE] filter=lfs diff=lfs merge=lfs -text 61 | *.[dD][bB] filter=lfs diff=lfs 
merge=lfs -text 62 | 63 | *.[bB][iI][nN] filter=lfs diff=lfs merge=lfs -text 64 | -------------------------------------------------------------------------------- /.github/repo.yml: -------------------------------------------------------------------------------- 1 | # Which github team is the maintainer of the repository. Maintainers have special permissions and are responsible for 2 | # the code in the repository in all meaningful ways. They may use CODEOWNERS to delegate ownership of sub-folders 3 | maintainer: roblox/ml-runtime 4 | -------------------------------------------------------------------------------- /.github/workflows/security-scan.yaml: -------------------------------------------------------------------------------- 1 | name: Security Scan 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | security: 11 | name: OSS Security SAST 12 | uses: Roblox/security-workflows/.github/workflows/oss-security-sast.yaml@main 13 | with: 14 | skip-ossf: true 15 | secrets: 16 | GITLEAKS_LICENSE: ${{ secrets.GITLEAKS_KEY }} 17 | ROBLOX_SEMGREP_GHC_POC_APP_TOKEN: ${{ secrets.ROBLOX_SEMGREP_GHC_POC_APP_TOKEN }} 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Aa][Rr][Mm]/ 27 | [Aa][Rr][Mm]64/ 28 | bld/ 29 | [Bb]in/ 30 | [Oo]bj/ 31 | [Ll]og/ 32 | [Ll]ogs/ 33 | 34 | # Visual Studio 2015/2017 cache/options directory 35 | .vs/ 36 | # Uncomment if you have tasks that create the project's static files in wwwroot 37 | #wwwroot/ 38 | 39 | # Visual Studio 2017 auto generated files 40 | Generated\ Files/ 41 | 42 | # MSTest test Results 43 | [Tt]est[Rr]esult*/ 44 | [Bb]uild[Ll]og.* 45 | 46 | # NUnit 47 | *.VisualState.xml 48 | TestResult.xml 49 | nunit-*.xml 50 | 51 | # Build Results of an ATL Project 52 | [Dd]ebugPS/ 53 | [Rr]eleasePS/ 54 | dlldata.c 55 | 56 | # Benchmark Results 57 | BenchmarkDotNet.Artifacts/ 58 | 59 | # .NET Core 60 | project.lock.json 61 | project.fragment.lock.json 62 | artifacts/ 63 | 64 | # StyleCop 65 | StyleCopReport.xml 66 | 67 | # Files built by Visual Studio 68 | *_i.c 69 | *_p.c 70 | *_h.h 71 | *.ilk 72 | *.meta 73 | *.obj 74 | *.iobj 75 | *.pch 76 | *.pdb 77 | *.ipdb 78 | *.pgc 79 | *.pgd 80 | *.rsp 81 | *.sbr 82 | *.tlb 83 | *.tli 84 | *.tlh 85 | *.tmp 86 | *.tmp_proj 87 | *_wpftmp.csproj 88 | *.log 89 | *.vspscc 90 | *.vssscc 91 | .builds 92 | *.pidb 93 | *.svclog 94 | *.scc 95 | 96 | # Chutzpah Test files 97 | _Chutzpah* 98 | 99 | # Visual C++ cache files 100 | ipch/ 101 | *.aps 102 | *.ncb 103 | *.opendb 104 | *.opensdf 105 | *.sdf 106 | *.cachefile 107 | *.VC.db 108 | *.VC.VC.opendb 109 | 110 | # Visual Studio profiler 111 | *.psess 112 | *.vsp 113 | *.vspx 114 | *.sap 115 | 116 | # Visual Studio Trace Files 117 | *.e2e 118 | 119 | # TFS 
2012 Local Workspace 120 | $tf/ 121 | 122 | # Guidance Automation Toolkit 123 | *.gpState 124 | 125 | # ReSharper is a .NET coding add-in 126 | _ReSharper*/ 127 | *.[Rr]e[Ss]harper 128 | *.DotSettings.user 129 | 130 | # TeamCity is a build add-in 131 | _TeamCity* 132 | 133 | # DotCover is a Code Coverage Tool 134 | *.dotCover 135 | 136 | # AxoCover is a Code Coverage Tool 137 | .axoCover/* 138 | !.axoCover/settings.json 139 | 140 | # Visual Studio code coverage results 141 | *.coverage 142 | *.coveragexml 143 | 144 | # NCrunch 145 | _NCrunch_* 146 | .*crunch*.local.xml 147 | nCrunchTemp_* 148 | 149 | # MightyMoose 150 | *.mm.* 151 | AutoTest.Net/ 152 | 153 | # Web workbench (sass) 154 | .sass-cache/ 155 | 156 | # Installshield output folder 157 | [Ee]xpress/ 158 | 159 | # DocProject is a documentation generator add-in 160 | DocProject/buildhelp/ 161 | DocProject/Help/*.HxT 162 | DocProject/Help/*.HxC 163 | DocProject/Help/*.hhc 164 | DocProject/Help/*.hhk 165 | DocProject/Help/*.hhp 166 | DocProject/Help/Html2 167 | DocProject/Help/html 168 | 169 | # Click-Once directory 170 | publish/ 171 | 172 | # Publish Web Output 173 | *.[Pp]ublish.xml 174 | *.azurePubxml 175 | # Note: Comment the next line if you want to checkin your web deploy settings, 176 | # but database connection strings (with potential passwords) will be unencrypted 177 | *.pubxml 178 | *.publishproj 179 | 180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 181 | # checkin your Azure Web App publish settings, but sensitive information contained 182 | # in these scripts will be unencrypted 183 | PublishScripts/ 184 | 185 | # NuGet Packages 186 | *.nupkg 187 | # NuGet Symbol Packages 188 | *.snupkg 189 | # The packages folder can be ignored because of Package Restore 190 | **/[Pp]ackages/* 191 | # except build/, which is used as an MSBuild target. 
192 | !**/[Pp]ackages/build/ 193 | # Uncomment if necessary however generally it will be regenerated when needed 194 | #!**/[Pp]ackages/repositories.config 195 | # NuGet v3's project.json files produces more ignorable files 196 | *.nuget.props 197 | *.nuget.targets 198 | 199 | # Microsoft Azure Build Output 200 | csx/ 201 | *.build.csdef 202 | 203 | # Microsoft Azure Emulator 204 | ecf/ 205 | rcf/ 206 | 207 | # Windows Store app package directories and files 208 | AppPackages/ 209 | BundleArtifacts/ 210 | Package.StoreAssociation.xml 211 | _pkginfo.txt 212 | *.appx 213 | *.appxbundle 214 | *.appxupload 215 | 216 | # Visual Studio cache files 217 | # files ending in .cache can be ignored 218 | *.[Cc]ache 219 | # but keep track of directories ending in .cache 220 | !?*.[Cc]ache/ 221 | 222 | # Others 223 | ClientBin/ 224 | ~$* 225 | *~ 226 | *.dbmdl 227 | *.dbproj.schemaview 228 | *.jfm 229 | *.pfx 230 | *.publishsettings 231 | orleans.codegen.cs 232 | 233 | # Including strong name files can present a security risk 234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 235 | #*.snk 236 | 237 | # Since there are multiple workflows, uncomment next line to ignore bower_components 238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 239 | #bower_components/ 240 | 241 | # RIA/Silverlight projects 242 | Generated_Code/ 243 | 244 | # Backup & report files from converting an old project file 245 | # to a newer Visual Studio version. 
Backup files are not needed, 246 | # because we have git ;-) 247 | _UpgradeReport_Files/ 248 | Backup*/ 249 | UpgradeLog*.XML 250 | UpgradeLog*.htm 251 | ServiceFabricBackup/ 252 | *.rptproj.bak 253 | 254 | # SQL Server files 255 | *.mdf 256 | *.ldf 257 | *.ndf 258 | 259 | # Business Intelligence projects 260 | *.rdl.data 261 | *.bim.layout 262 | *.bim_*.settings 263 | *.rptproj.rsuser 264 | *- [Bb]ackup.rdl 265 | *- [Bb]ackup ([0-9]).rdl 266 | *- [Bb]ackup ([0-9][0-9]).rdl 267 | 268 | # Microsoft Fakes 269 | FakesAssemblies/ 270 | 271 | # GhostDoc plugin setting file 272 | *.GhostDoc.xml 273 | 274 | # Node.js Tools for Visual Studio 275 | .ntvs_analysis.dat 276 | node_modules/ 277 | 278 | # Visual Studio 6 build log 279 | *.plg 280 | 281 | # Visual Studio 6 workspace options file 282 | *.opt 283 | 284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 285 | *.vbw 286 | 287 | # Visual Studio LightSwitch build output 288 | **/*.HTMLClient/GeneratedArtifacts 289 | **/*.DesktopClient/GeneratedArtifacts 290 | **/*.DesktopClient/ModelManifest.xml 291 | **/*.Server/GeneratedArtifacts 292 | **/*.Server/ModelManifest.xml 293 | _Pvt_Extensions 294 | 295 | # Paket dependency manager 296 | .paket/paket.exe 297 | paket-files/ 298 | 299 | # FAKE - F# Make 300 | .fake/ 301 | 302 | # CodeRush personal settings 303 | .cr/personal 304 | 305 | # Python Tools for Visual Studio (PTVS) 306 | __pycache__/ 307 | *.pyc 308 | 309 | # Cake - Uncomment if you are using it 310 | # tools/** 311 | # !tools/packages.config 312 | 313 | # Tabs Studio 314 | *.tss 315 | 316 | # Telerik's JustMock configuration file 317 | *.jmconfig 318 | 319 | # BizTalk build output 320 | *.btp.cs 321 | *.btm.cs 322 | *.odx.cs 323 | *.xsd.cs 324 | 325 | # OpenCover UI analysis results 326 | OpenCover/ 327 | 328 | # Azure Stream Analytics local run output 329 | ASALocalRun/ 330 | 331 | # MSBuild Binary and Structured Log 332 | *.binlog 333 | 334 | # NVidia Nsight GPU debugger 
configuration file 335 | *.nvuser 336 | 337 | # MFractors (Xamarin productivity tool) working folder 338 | .mfractor/ 339 | 340 | # Local History for Visual Studio 341 | .localhistory/ 342 | 343 | # BeatPulse healthcheck temp database 344 | healthchecksdb 345 | 346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 347 | MigrationBackup/ 348 | 349 | # Ionide (cross platform F# VS Code tools) working folder 350 | .ionide/ 351 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # This is the codeowners file. It describes the people who must approve a code review 2 | # The first entry is a path filter for which things code owners will be required on. 3 | # To have a team, prefix it with Roblox/. Note that the team must have been added to the repo. 4 | # Multiple owners on one line will be an OR operation, as long as one approves the whole line is satisfied. 5 | # For more details: https://roblox.atlassian.net/wiki/spaces/ENGEFF/pages/1549927831/Using+CODEOWNERS+files 6 | 7 | # Default catch-all for anything that isn't more specifically owned. 8 | 9 | * @roblox/ml-runtime 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 9 | 10 | ![Accelerating Diffusion Transformer inference across multiple modalities with 50 DDIM Steps on DiT-XL-256x256, 100 DPM-Solver++(3M) SDE steps for a 10s audio sample (spectrogram shown) on Stable Audio Open, 30 Rectified Flow steps on Open-Sora 480p 2s videos](assets/TeaserFigureFlat.png) 11 | 12 | **Figure 1. Accelerating Diffusion Transformer inference across multiple modalities with 50 DDIM Steps on DiT-XL-256x256, 100 DPM-Solver++(3M) SDE steps for a 10s audio sample (spectrogram shown) on Stable Audio Open, 30 Rectified Flow steps on Open-Sora 480p 2s videos** 13 | 14 | 15 | # Updates 16 | 17 | ## Release v0.1 18 | 19 | [View release notes for v0.1](https://github.com/Roblox/SmoothCache/releases/tag/v0.1) 20 | 21 | SmoothCache now supports generating cache schedules using a zero-intrusion external helper. 
See [run_calibration.py](./examples/run_calibration.py) to find out how it generates a schedule compatible with [HuggingFace Diffusers DiTPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dit/pipeline_dit.py), without requiring any changes to Diffusers implementation! 22 | 23 | 24 | # Introduction 25 | We introduce **SmoothCache**, a straightforward acceleration technique for DiT architecture models, that's both **training-free, flexible and performant**. By leveraging layer-wise representation error, our method identifies redundancies in the diffusion process, generates a static caching scheme to reuse output featuremaps and therefore reduces the need for computationally expensive operations. This solution works across different models and modalities, can be easily dropped into existing Diffusion Transformer pipelines, can be stacked on different solvers, and requires no additional training or datasets. **SmoothCache** consistently outperforms various solvers designed to accelerate the diffusion process, while matching or surpassing the performance of existing modality-specific caching techniques. 26 | 27 | > 🥯[[Arxiv]](https://arxiv.org/abs/2411.10510) 28 | 29 | ![Illustration of SmoothCache. When the layer representation loss obtained from the calibration pass is below some threshold α, the corresponding layer is cached and used in place of the same computation on a future timestep. The figure on the left shows how the layer representation error impacts whether certain layers are eligible for caching. The error of the attention (attn) layer is higher in earlier timesteps, so our schedule caches the later timesteps accordingly. The figure on the right shows the application of the caching schedule to the DiT-XL architecture. The output of the attn layer at time t − 1 is cached and re-used in place of computing FFN t − 2, since the corresponding error is below α. 
This cached output is introduced in the model using the properties of the residual connection.](assets/SmoothCache2.png) 30 | 31 | ## Quick Start 32 | 33 | ### Install 34 | ```bash 35 | pip install dit-smoothcache 36 | ``` 37 | 38 | ### Usage - Inference 39 | 40 | Inspired by [DeepCache](https://raw.githubusercontent.com/horseee/DeepCache), we have implemented drop-in SmoothCache helper classes that easily apply to [Huggingface Diffuser DiTPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/dit), and [original DiT implementations](https://github.com/facebookresearch/DiT). 41 | 42 | Generally, only 3 additional lines need to be added to the original sampler scripts: 43 | ```python 44 | from SmoothCache import <DesiredCacheHelper> 45 | cache_helper = DiffuserCacheHelper(<model>, schedule=schedule) 46 | cache_helper.enable() 47 | # Original sampler code. 48 | cache_helper.disable() 49 | ``` 50 | 51 | #### Usage example with Huggingface Diffuser DiTPipeline: 52 | ```python 53 | import json 54 | import torch 55 | from diffusers import DiTPipeline, DPMSolverMultistepScheduler 56 | 57 | # Import SmoothCacheHelper 58 | from SmoothCache import DiffuserCacheHelper 59 | 60 | # Load the DiT pipeline and scheduler 61 | pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16) 62 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 63 | pipe = pipe.to("cuda") 64 | 65 | # Initialize the DiffuserCacheHelper with the model 66 | with open("smoothcache_schedules/50-N-3-threshold-0.18.json", "r") as f: 67 | schedule = json.load(f) 68 | cache_helper = DiffuserCacheHelper(pipe.transformer, schedule=schedule) 69 | 70 | # Enable the caching helper 71 | cache_helper.enable() 72 | # Prepare the input 73 | words = ["Labrador retriever"] 74 | class_ids = pipe.get_label_ids(words) 75 | 76 | # Generate images with the pipeline 77 | generator = torch.manual_seed(33) 78 | image = pipe(class_labels=class_ids, num_inference_steps=50, 
generator=generator).images[0] 79 | 80 | # Restore the original forward method and disable the helper 81 | # disable() should be paired up with enable() 82 | cache_helper.disable() 83 | ``` 84 | 85 | #### Usage example with original DiT implementation 86 | ```python 87 | import torch 88 | 89 | torch.backends.cuda.matmul.allow_tf32 = True 90 | torch.backends.cudnn.allow_tf32 = True 91 | from torchvision.utils import save_image 92 | from diffusion import create_diffusion 93 | from diffusers.models import AutoencoderKL 94 | from download import find_model 95 | from models import DiT_models 96 | import argparse 97 | from SmoothCache import DiTCacheHelper # Import DiTCacheHelper 98 | import json 99 | 100 | # Setup PyTorch: 101 | torch.manual_seed(args.seed) 102 | torch.set_grad_enabled(False) 103 | device = "cuda" if torch.cuda.is_available() else "cpu" 104 | 105 | if args.ckpt is None: 106 | assert ( 107 | args.model == "DiT-XL/2" 108 | ), "Only DiT-XL/2 models are available for auto-download." 109 | assert args.image_size in [256, 512] 110 | assert args.num_classes == 1000 111 | 112 | # Load model: 113 | latent_size = args.image_size // 8 114 | model = DiT_models[args.model]( 115 | input_size=latent_size, num_classes=args.num_classes 116 | ).to(device) 117 | ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" 118 | state_dict = find_model(ckpt_path) 119 | model.load_state_dict(state_dict) 120 | model.eval() # important! 
121 | with open("smoothcache_schedules/50-N-3-threshold-0.18.json", "r") as f: 122 | schedule = json.load(f) 123 | cache_helper = DiTCacheHelper(model, schedule=schedule) 124 | 125 | # number of timesteps should be consistent with provided schedules 126 | diffusion = create_diffusion(str(len(schedule[cache_helper.components_to_wrap[0]]))) 127 | 128 | # Enable the caching helper 129 | cache_helper.enable() 130 | 131 | # Sample images: 132 | samples = diffusion.p_sample_loop( 133 | model.forward_with_cfg, 134 | z.shape, 135 | z, 136 | clip_denoised=False, 137 | model_kwargs=model_kwargs, 138 | progress=True, 139 | device=device, 140 | ) 141 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples 142 | samples = vae.decode(samples / 0.18215).sample 143 | 144 | # Disable the caching helper after sampling 145 | cache_helper.disable() 146 | # Save and display images: 147 | save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1)) 148 | ``` 149 | 150 | ### Usage - Cache Schedule Generation 151 | See [run_calibration.py](./examples/run_calibration.py), which generates a schedule for the self-attention module ([attn1](https://github.com/huggingface/diffusers/blob/37a5f1b3b69ed284086fb31fb1b49668cba6c365/src/diffusers/models/attention.py#L380)) 152 | from Diffusers [BasicTransformerBlock](https://github.com/huggingface/diffusers/blob/37a5f1b3b69ed284086fb31fb1b49668cba6c365/src/diffusers/models/attention.py#L261C7-L261C28) block. 153 | 154 | Note that only self-attention, and not cross-attention, is enabled in the stock config of Diffusers [DiT module](https://github.com/huggingface/diffusers/blob/37a5f1b3b69ed284086fb31fb1b49668cba6c365/src/diffusers/models/transformers/dit_transformer_2d.py#L72-L73). We leave this behavior 155 | as-is for the purpose of minimal intrusion. 156 | 157 | We welcome all contributions aimed at expanding SmoothCache's model coverage and module coverage. 
158 | 159 | ## Visualization 160 | 161 | ### 256x256 Image Generation Task 162 | 163 | ![Mosaic Image](assets/dit-mosaic.png) 164 | 165 | 166 | 167 | ## Evaluation 168 | 169 | ### Image Generation with DiT-XL/2-256x256 170 | ![Table 1. Results For DiT-XL-256x256 on using DDIM Sampling. 171 | Note that L2C is not training free](assets/table1.png) 172 | 173 | ### Video Generation with OpenSora 174 | ![Table 2. Results For OpenSora on Rectified Flow](assets/table2.png) 175 | 176 | ### Audio Generation with Stable Audio Open 177 | ![Table 3. Results For Stable Audio Open on DPMSolver++(3M) SDE on 3 datasets](assets/table3.png) 178 | 179 | 180 | # License 181 | SmoothCache is licensed under the [Apache-2.0](LICENSE) license. 182 | 183 | ## Bibtex 184 | ``` 185 | @misc{liu2024smoothcacheuniversalinferenceacceleration, 186 | title={SmoothCache: A Universal Inference Acceleration Technique for Diffusion Transformers}, 187 | author={Joseph Liu and Joshua Geddes and Ziyu Guo and Haomiao Jiang and Mahesh Kumar Nandwana}, 188 | year={2024}, 189 | eprint={2411.10510}, 190 | archivePrefix={arXiv}, 191 | primaryClass={cs.LG}, 192 | url={https://arxiv.org/abs/2411.10510}, 193 | } 194 | ``` -------------------------------------------------------------------------------- /SmoothCache/__init__.py: -------------------------------------------------------------------------------- 1 | # SmoothCache/__init__.py 2 | 3 | from .smooth_cache_helper import SmoothCacheHelper 4 | 5 | __all__ = ['SmoothCacheHelper'] 6 | 7 | # Try to import DiffuserCacheHelper 8 | try: 9 | from .diffuser_cache_helper import DiffuserCacheHelper 10 | __all__.append('DiffuserCacheHelper') 11 | except ImportError: 12 | print("Warning: DiffuserCacheHelper not imported. Ensure Diffusers is installed.") 13 | 14 | # Try to import DiTCacheHelper 15 | try: 16 | from .dit_cache_helper import DiTCacheHelper 17 | __all__.append('DiTCacheHelper') 18 | except ImportError: 19 | print("Warning: DiTCacheHelper not imported. 
Ensure necessary dependencies are installed.") 20 | 21 | # Try to import calibration helpers 22 | try: 23 | from .calibration.calibration_helper import CalibrationHelper 24 | __all__.append('CalibrationHelper') 25 | except ImportError: 26 | print("Warning: CalibrationHelper not imported.") 27 | 28 | try: 29 | from .calibration.diffuser_calibration_helper import DiffuserCalibrationHelper 30 | __all__.append('DiffuserCalibrationHelper') 31 | except ImportError: 32 | print("Warning: DiffuserCalibrationHelper not imported. Ensure Diffusers is installed.") -------------------------------------------------------------------------------- /SmoothCache/calibration/calibration_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Roblox Corporation 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import re 17 | import statistics 18 | from typing import Dict, List, Optional, Union, Type 19 | import torch 20 | import torch.nn as nn 21 | from pathlib import Path 22 | 23 | 24 | def rel_l1_loss(prev_output, cur_output): 25 | """ 26 | Compute the relative L1 loss between prev_output and cur_output as a single float. 27 | 28 | Args: 29 | prev_output (torch.Tensor): Previous layer output. Shape: [batch_size, channels, ...] 30 | cur_output (torch.Tensor): Current layer output. Shape: [batch_size, channels, ...] 
31 | 32 | Returns: 33 | float: Relative L1 loss across the entire batch, on flattened inputs, 34 | Since DiTPipeline will duplicate the batch anyway. 35 | """ 36 | output_diff = prev_output.float() - cur_output.float() 37 | numerator = torch.norm(output_diff, p=1) 38 | denominator = torch.norm(cur_output.float(), p=1) 39 | relative_l1 = numerator / denominator 40 | return relative_l1.cpu().item() 41 | 42 | class CalibrationHelper: 43 | def __init__( 44 | self, 45 | model: nn.Module, 46 | block_classes: Union[Type[nn.Module], List[Type[nn.Module]]], 47 | components_to_wrap: List[str], 48 | calibration_lookahead: int = 3, 49 | calibration_threshold: float = 0.0, 50 | schedule_length: int = 50, 51 | log_file: str = "calibration_schedule.json" 52 | ): 53 | """ 54 | Base CalibrationHelper that dynamically wraps specified components for calibration. 55 | 56 | Args: 57 | model (nn.Module): The model whose components we want to calibrate. 58 | block_classes (Union[Type[nn.Module], List[Type[nn.Module]]]): The block class(es) identifying which blocks to wrap. 59 | components_to_wrap (List[str]): Component names within each block to wrap (e.g. ['attn1', 'mlp']). 60 | calibration_lookahead (int): Number of steps to look back when computing errors. 61 | log_file (str): Path to save the generated schedule. 
62 | """ 63 | self.model = model 64 | self.block_classes = block_classes if isinstance(block_classes, list) else [block_classes] 65 | self.components_to_wrap = components_to_wrap 66 | 67 | self.calibration_lookahead = calibration_lookahead 68 | # Validate calibration_lookahead 69 | if self.calibration_lookahead <= 0: 70 | raise ValueError("calibration_lookahead must be greater than 0.") 71 | 72 | self.calibration_threshold = calibration_threshold 73 | self.schedule_length = schedule_length 74 | self.log_file = log_file 75 | 76 | # Tracking original forward methods 77 | self.original_forwards = {} 78 | 79 | # Tracking steps and outputs 80 | self.current_steps = {} 81 | self.previous_layer_outputs = {} 82 | self.calibration_results = {} 83 | 84 | # State 85 | self.enabled = False 86 | 87 | def enable(self): 88 | """ 89 | Enable calibration mode by wrapping components at runtime. 90 | After enabling, simply run your pipeline once to collect calibration data. 91 | """ 92 | if self.enabled: 93 | return 94 | self.enabled = True 95 | self.reset_state() 96 | self.wrap_components() 97 | 98 | def disable(self): 99 | """ 100 | Disable calibration mode, unwrap the components, generate the schedule, and save it. 101 | Ensures that the destination directory exists before writing the schedule JSON. 
102 | """ 103 | if not self.enabled: 104 | return 105 | self.enabled = False 106 | self.unwrap_components() 107 | generated_schedule = self.generate_schedule() 108 | 109 | log_path = Path(self.log_file) 110 | if log_path.parent: 111 | log_path.parent.mkdir(parents=True, exist_ok=True) 112 | 113 | with log_path.open("w") as f: 114 | f.write("{\n") 115 | for i, (key, value) in enumerate(generated_schedule.items()): 116 | # Serialize the list as a compact JSON list 117 | value_str = json.dumps(value, separators=(',', ':')) 118 | # Write the key-value pair 119 | f.write(f' "{key}": {value_str}') 120 | if i < len(generated_schedule) - 1: 121 | f.write(",\n") 122 | else: 123 | f.write("\n") 124 | f.write("}\n") 125 | 126 | self.reset_state() 127 | 128 | def reset_state(self): 129 | """ 130 | Reset internal state. 131 | """ 132 | self.current_steps.clear() 133 | self.previous_layer_outputs.clear() 134 | self.calibration_results.clear() 135 | 136 | def wrap_components(self): 137 | """ 138 | Wrap the specified components in the given block classes. 139 | """ 140 | for block_name, block in self.model.named_modules(): 141 | if any(isinstance(block, cls) for cls in self.block_classes): 142 | self.wrap_block_components(block, block_name) 143 | 144 | def wrap_block_components(self, block, block_name: str): 145 | """ 146 | Wrap the target components (e.g., 'attn1') in each block. 147 | """ 148 | for comp_name in self.components_to_wrap: 149 | if hasattr(block, comp_name): 150 | component = getattr(block, comp_name) 151 | full_name = f"{block_name}.{comp_name}" 152 | self.original_forwards[full_name] = component.forward 153 | wrapped_forward = self.create_wrapped_forward(full_name, component.forward) 154 | component.forward = wrapped_forward 155 | 156 | def unwrap_components(self): 157 | """ 158 | Restore original forward methods for all wrapped components. 
159 | """ 160 | for full_name, original_forward in self.original_forwards.items(): 161 | module = self.get_module_by_name(self.model, full_name) 162 | if module is not None: 163 | module.forward = original_forward 164 | self.original_forwards.clear() 165 | 166 | def create_wrapped_forward(self, full_name: str, original_forward): 167 | """ 168 | Create a wrapped forward method that intercepts outputs, computes errors, and stores them. 169 | """ 170 | def wrapped_forward(*args, **kwargs): 171 | # Increment step counter 172 | step = self.current_steps.get(full_name, 0) + 1 173 | self.current_steps[full_name] = step 174 | 175 | # Call original forward 176 | output = original_forward(*args, **kwargs) 177 | 178 | # 'output' is the layer output for this component. We treat it as a torch.Tensor 179 | # Store and compute error vs previous steps 180 | # Initialize storage if not present 181 | if full_name not in self.previous_layer_outputs: 182 | self.previous_layer_outputs[full_name] = [None] * self.calibration_lookahead 183 | if full_name not in self.calibration_results: 184 | self.calibration_results[full_name] = [[] for _ in range(self.calibration_lookahead)] 185 | 186 | current_output = output 187 | # Compare with previous outputs 188 | for j in range(self.calibration_lookahead): 189 | prev_output = self.previous_layer_outputs[full_name][j] 190 | if prev_output is not None and current_output is not None: 191 | # Compute error 192 | error = rel_l1_loss(prev_output, current_output) 193 | self.calibration_results[full_name][j].append(error) 194 | 195 | # Update previous outputs 196 | self.previous_layer_outputs[full_name].insert(0, current_output.detach().clone()) 197 | if len(self.previous_layer_outputs[full_name]) > self.calibration_lookahead: 198 | self.previous_layer_outputs[full_name].pop() 199 | 200 | return output 201 | return wrapped_forward 202 | 203 | def generate_schedule(self): 204 | """ 205 | Generate schedules for each exact component name (e.g., 'attn1', 
'mlp1', etc.) 206 | using n-row scanning logic, where n is arbitrary based on calibration_lookahead. 207 | 208 | For example, if self.calibration_results has keys: 209 | 'transformer_blocks.0.attn1', 'transformer_blocks.1.attn1', 'transformer_blocks.0.mlp1', ... 210 | we parse out the last part (e.g., 'attn1', 'mlp1') as `component_full`, 211 | and group all blocks that share that same component_full. 212 | 213 | Each group yields n arrays: row0, row1, ..., row(n-1)_list, averaged across all blocks, 214 | then scanned to produce the schedule. 215 | 216 | Returns: 217 | A dictionary like: 218 | { 219 | 'attn1': [schedule_length schedule], 220 | 'mlp1': [schedule_length schedule], 221 | ... 222 | } 223 | """ 224 | import numpy as np 225 | from collections import defaultdict 226 | 227 | # Dictionary: component_name -> list of lists for each row 228 | component_to_rows = defaultdict(list) 229 | 230 | # Step A: Collect row arrays by exact component name 231 | for full_name, sublists in self.calibration_results.items(): 232 | if len(sublists) < self.calibration_lookahead: 233 | # skip if incomplete 234 | continue 235 | 236 | # e.g., 'transformer_blocks.0.attn1' => component_full='attn1' 237 | component_full = full_name.split('.')[-1] # e.g., 'attn1' 238 | 239 | # sublists is a list of row arrays for this component 240 | component_to_rows[component_full].append(sublists) 241 | 242 | final_schedules = {} 243 | 244 | # Step B: For each component_full, average rows and produce schedule 245 | for component_full, sublist_groups in component_to_rows.items(): 246 | # Assuming each sublist_group has the same number of rows (calibration_lookahead) 247 | num_rows = len(sublist_groups[0]) if sublist_groups else 0 248 | 249 | # Average each row across all blocks 250 | averaged_rows = [] 251 | for row_idx in range(num_rows): 252 | row_arrays = [sublist[row_idx] for sublist in sublist_groups] 253 | avg_row_list = self._average_arrays(row_arrays) 254 | averaged_rows.append(avg_row_list) 
255 | 256 | schedule = self._scan_nrows_sublists(averaged_rows, self.calibration_threshold) 257 | final_schedules[component_full] = schedule 258 | 259 | print(final_schedules) 260 | return final_schedules 261 | 262 | def _average_arrays(self, array_list): 263 | """ 264 | Given a list of 1D numpy arrays of potentially different lengths, 265 | compute the average across them at each index. 266 | Returns a Python list of floats for the average. 267 | e.g. if array_list = [arrA(len=49), arrB(len=49), arrC(len=48), ...] 268 | we find max_len, sum, count -> average. 269 | """ 270 | import numpy as np 271 | if not array_list: 272 | return [] 273 | 274 | max_len = max(len(arr) for arr in array_list) 275 | sum_vals = np.zeros(max_len, dtype=float) 276 | count_vals = np.zeros(max_len, dtype=int) 277 | 278 | for arr in array_list: 279 | for i, val in enumerate(arr): 280 | sum_vals[i] += val 281 | count_vals[i] += 1 282 | 283 | avg_arr = np.zeros(max_len, dtype=float) 284 | for i in range(max_len): 285 | if count_vals[i] > 0: 286 | avg_arr[i] = sum_vals[i] / count_vals[i] 287 | return avg_arr.tolist() 288 | 289 | def _scan_nrows_sublists(self, row_lists, threshold): 290 | """ 291 | Scan through multiple rows (arbitrary number) in reverse order to produce a schedule. 292 | 293 | Parameters: 294 | row_lists (list of lists): A list where each element is a row's list of values 295 | ordered from highest priority to lowest. 296 | threshold (float): The threshold value to check against. 297 | 298 | Returns: 299 | schedule (list): The generated schedule based on the scanning logic. 
300 | """ 301 | schedule = [None] * self.schedule_length 302 | i = 0 303 | 304 | while i < self.schedule_length: 305 | idx = i 306 | used = False 307 | 308 | # Iterate through each row in reverse order (highest priority first) 309 | for row_idx in range(len(row_lists)-1, -1, -1): 310 | current_row_list = row_lists[row_idx] 311 | if idx >= len(current_row_list): 312 | continue # Skip if index is out of bounds for this row 313 | 314 | if current_row_list[idx] <= threshold: 315 | # Activate the current step 316 | schedule[i] = 1 317 | 318 | # Determine how many steps to skip based on the row priority 319 | num_skips = row_idx + 1 # More skips for higher priority rows 320 | skip_steps = [] 321 | for s in range(1, num_skips + 1): 322 | skip_step = i + s 323 | if skip_step < self.schedule_length: 324 | schedule[skip_step] = 0 325 | skip_steps.append(skip_step) 326 | 327 | # Move the index past the skipped steps 328 | i += (num_skips + 1) # Move to the step after the last skip 329 | used = True 330 | break 331 | 332 | if not used: 333 | # Fallback: Activate current step without skipping 334 | schedule[i] = 1 335 | i += 1 336 | 337 | # Override the first and last steps to be active 338 | if self.schedule_length > 0: 339 | schedule[0] = 1 340 | schedule[-1] = 1 341 | 342 | # Fill any remaining None values with 1 343 | for x in range(self.schedule_length): 344 | if schedule[x] is None: 345 | schedule[x] = 1 346 | 347 | return schedule 348 | 349 | def get_module_by_name(self, model, full_name): 350 | """ 351 | Utility to retrieve a module by full name. 
352 | """ 353 | names = full_name.split('.') 354 | module = model 355 | for name in names: 356 | if hasattr(module, name): 357 | module = getattr(module, name) 358 | else: 359 | return None 360 | return module 361 | -------------------------------------------------------------------------------- /SmoothCache/calibration/diffuser_calibration_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Roblox Corporation 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import List, Optional 16 | import torch.nn as nn 17 | from .calibration_helper import CalibrationHelper 18 | 19 | try: 20 | from diffusers.models.attention import BasicTransformerBlock 21 | except ImportError: 22 | BasicTransformerBlock = None 23 | 24 | class DiffuserCalibrationHelper(CalibrationHelper): 25 | def __init__( 26 | self, 27 | model: nn.Module, 28 | calibration_lookahead: int = 3, 29 | calibration_threshold: float = 0.0, 30 | schedule_length: int = 50, 31 | log_file: str = "calibration_schedule.json", 32 | components_to_wrap: Optional[List[str]] = None 33 | ): 34 | """ 35 | Diffuser-specific CalibrationHelper derived from CalibrationHelper. 36 | 37 | Args: 38 | model (nn.Module): The model to wrap (e.g., pipe.transformer). 39 | calibration_lookahead (int): Steps to look back for error calculation. 40 | calibration_threshold (float): Cutoff L1 error value to enable caching. 
41 | schedule_length (int): Length of the generated schedule, 1:1 mapped to pipeline timesteps. 42 | log_file (str): Path to save the generated schedule JSON. 43 | components_to_wrap (List[str], optional): List of component names to wrap. 44 | Defaults to ['attn1']. 45 | 46 | Raises: 47 | ImportError: If diffusers' BasicTransformerBlock is unavailable. 48 | """ 49 | if BasicTransformerBlock is None: 50 | raise ImportError("Diffusers library not installed or BasicTransformerBlock not found.") 51 | 52 | block_classes = [BasicTransformerBlock] 53 | if components_to_wrap is None: 54 | components_to_wrap = ['attn1'] 55 | super().__init__( 56 | model=model, 57 | block_classes=block_classes, 58 | components_to_wrap=components_to_wrap, 59 | calibration_lookahead=calibration_lookahead, 60 | calibration_threshold=calibration_threshold, 61 | schedule_length=schedule_length, 62 | log_file=log_file 63 | ) 64 | -------------------------------------------------------------------------------- /SmoothCache/diffuser_cache_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Roblox Corporation 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Helper Class for Diffusion Transformer Implemented at 16 | https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/dit""" 17 | 18 | from typing import List, Optional 19 | from .smooth_cache_helper import SmoothCacheHelper 20 | 21 | try: 22 | from diffusers.models.attention import BasicTransformerBlock 23 | except ImportError: 24 | print("Warning: Diffusers library is not installed. DiffuserCacheHelper cannot be used.") 25 | BasicTransformerBlock = None 26 | 27 | class DiffuserCacheHelper(SmoothCacheHelper): 28 | def __init__(self, model, schedule): 29 | if BasicTransformerBlock is None: 30 | raise ImportError("Diffusers library is not installed. DiffuserCacheHelper cannot be used.") 31 | block_classes = BasicTransformerBlock 32 | components_to_wrap = ['attn1'] 33 | super().__init__( 34 | model=model, 35 | block_classes=block_classes, 36 | components_to_wrap=components_to_wrap, 37 | schedule=schedule 38 | ) 39 | -------------------------------------------------------------------------------- /SmoothCache/dit_cache_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Roblox Corporation 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Helper Class for Diffusion Transformer Implemented at 16 | https://github.com/facebookresearch/DiT""" 17 | 18 | from .smooth_cache_helper import SmoothCacheHelper 19 | 20 | try: 21 | # Assuming DiTBlock is defined in 'models/dit.py' in the DiT repository 22 | from models import DiTBlock 23 | except ImportError: 24 | print("Warning: DiT library is not accessible. DiTCacheHelper cannot be used.") 25 | DiTBlock = None 26 | 27 | class DiTCacheHelper(SmoothCacheHelper): 28 | def __init__(self, model, schedule): 29 | if DiTBlock is None: 30 | raise ImportError("DiT library is not accessible. DiTCacheHelper cannot be used.") 31 | block_classes = DiTBlock 32 | components_to_wrap = ['attn', 'mlp'] 33 | super().__init__( 34 | model=model, 35 | block_classes=block_classes, 36 | components_to_wrap=components_to_wrap, 37 | schedule=schedule 38 | ) 39 | -------------------------------------------------------------------------------- /SmoothCache/smooth_cache_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Roblox Corporation 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
"""Core SmoothCache Helper Implementation"""

from typing import Dict, Any, Optional, List, Union, Type
import torch
import torch.nn as nn


class SmoothCacheHelper:
    """Caches outputs of selected transformer-block components across timesteps.

    Named sub-modules (e.g. ``attn``, ``mlp``) of the given block classes are
    wrapped so that, on timesteps whose schedule entry is 0, the component's
    most recently computed output is returned from cache instead of being
    recomputed.
    """

    def __init__(
        self,
        model: nn.Module,
        block_classes: Union[Type[nn.Module], List[Type[nn.Module]]],
        components_to_wrap: List[str],
        schedule: Dict[str, List[int]],
        verbose: bool = True,
    ):
        """
        Generalized SmoothCacheHelper to wrap specified components in specified block classes.

        Args:
            model (nn.Module): The model to wrap.
            block_classes (Type[nn.Module] or List[Type[nn.Module]]): The block class(es) to search for.
            components_to_wrap (List[str]): The names of the components within the blocks to wrap.
            schedule (Dict[str, List[int]]): A dictionary mapping component names to lists of 0/1 values
                for each timestep. 1 means run normally, 0 means use cached result.
            verbose (bool): When True (default, preserving previous behavior), print a
                line on every wrapped forward call saying whether a cached or a freshly
                computed result was returned.
        """
        self.model = model
        self.block_classes = block_classes if isinstance(block_classes, list) else [block_classes]
        self.components_to_wrap = components_to_wrap
        self.schedule = schedule
        self.verbose = verbose

        # full_name -> original (unwrapped) forward method; populated by enable().
        self.original_forwards = {}
        # full_name -> last computed output, reused on skipped steps.
        self.cache = {}
        # Per-module step counters: each wrapped module advances independently.
        self.current_steps = {}
        self.start_steps = {}

    def enable(self):
        """Wrap the configured components. Idempotent: a second call is a no-op.

        Without the guard, calling enable() twice would record the *wrapped*
        forwards as "originals", making disable() unable to restore the model
        and double-counting every step.
        """
        if self.original_forwards:
            return
        self.reset_state()
        self.wrap_components()

    def disable(self):
        """Restore all original forward methods and clear cached state."""
        self.unwrap_components()
        self.reset_state()

    def reset_state(self):
        """Clear per-module step counters and the output cache."""
        self.current_steps = {}
        self.start_steps = {}
        self.cache.clear()

    def is_skip_step(self, full_name):
        """Return True if the module named *full_name* should reuse its cache now.

        The schedule key is resolved in priority order: the per-block key
        ``"<component>-<block index>"`` (e.g. ``"attn-3"``), then the generic
        component key (e.g. ``"attn"``). Modules with no schedule entry, or
        steps outside the schedule's length, always run normally.
        """
        names = full_name.split('.')
        component_name = names[-1]                            # e.g. 'attn' or 'mlp'
        block_index = names[-2] if len(names) >= 2 else ''    # e.g. '0', '1', ...

        schedule_key_with_index = f"{component_name}-{block_index}"
        if schedule_key_with_index in self.schedule:
            # Schedule specific to this block takes precedence.
            schedule_key = schedule_key_with_index
        elif component_name in self.schedule:
            # Fall back to the general schedule for the component type.
            schedule_key = component_name
        else:
            return False

        # current_steps is incremented *before* this check, so subtract 1 to
        # index the schedule from 0.
        current_step = self.current_steps.get(full_name, 0) - 1
        schedule_list = self.schedule[schedule_key]
        if not 0 <= current_step < len(schedule_list):
            return False

        # 1 means run normally, 0 means use cached result (skip computation).
        return schedule_list[current_step] == 0

    def wrap_components(self):
        """Wrap the configured components inside every matching block of the model."""
        for block_name, block in self.model.named_modules():
            if any(isinstance(block, cls) for cls in self.block_classes):
                self.wrap_block_components(block, block_name)

    def wrap_block_components(self, block, block_name):
        """Replace each listed component's forward with a caching wrapper."""
        for comp_name in self.components_to_wrap:
            if not hasattr(block, comp_name):
                continue
            component = getattr(block, comp_name)
            full_name = f"{block_name}.{comp_name}"
            if full_name in self.original_forwards:
                continue  # Already wrapped; never capture a wrapper as "original".
            # Store original forward, then swap in the caching wrapper.
            self.original_forwards[full_name] = component.forward
            component.forward = self.create_wrapped_forward(full_name, component.forward)

    def unwrap_components(self):
        """Restore every wrapped component's original forward method."""
        for full_name, original_forward in self.original_forwards.items():
            module = self.get_module_by_name(self.model, full_name)
            if module is not None:
                module.forward = original_forward
        # Clear original_forwards to avoid accumulating stale state.
        self.original_forwards.clear()

    def create_wrapped_forward(self, full_name, original_forward):
        """Build a forward wrapper that serves cached output on skipped steps."""
        def wrapped_forward(*args, **kwargs):
            # Lazily initialize this module's step counters.
            if full_name not in self.current_steps:
                self.current_steps[full_name] = 0
                self.start_steps[full_name] = None

            self.current_steps[full_name] += 1

            if self.is_skip_step(full_name) and full_name in self.cache:
                # Skipped step: reuse the most recent output.
                if self.verbose:
                    print("Returning cached result for", full_name, "at step", self.current_steps[full_name])
                return self.cache[full_name]

            # Normal step: compute, cache, and return the fresh output.
            output = original_forward(*args, **kwargs)
            self.cache[full_name] = output
            if self.verbose:
                print("returning normal result for ", full_name, " at step ", self.current_steps[full_name])
            return output
        return wrapped_forward

    def get_module_by_name(self, model, full_name):
        """Resolve a dotted path to a sub-module, or return None if absent."""
        module = model
        for name in full_name.split('.'):
            if not hasattr(module, name):
                return None
            module = getattr(module, name)
        return module
https://git-lfs.github.com/spec/v1 2 | oid sha256:7d223edb05346086b865d0609644d0dd3a789c9769c36da7d10d01d6d55edb7b 3 | size 6459503 4 | -------------------------------------------------------------------------------- /assets/dit-mosaic.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a4db68e9ec248f8dbeafe10d3023dbd0da2ca08579844324d62d6afd31068378 3 | size 9330260 4 | -------------------------------------------------------------------------------- /assets/table1.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:911d7a4ba0ec96ffec53a6bd9c02e29a3efb6c893a6d7b087e7275253f812f2d 3 | size 99052 4 | -------------------------------------------------------------------------------- /assets/table2.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:78719d226b0f8a521ea041f6ed034e45b9cf7436773fa608e4ab67958f29490f 3 | size 30527 4 | -------------------------------------------------------------------------------- /assets/table3.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:85ed39d3fd4ef933f34c27298f0fcde0be5b944bdf572e188f45e0afcd43cab8 3 | size 68178 4 | -------------------------------------------------------------------------------- /examples/run_calibration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Roblox Corporation 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
import torch
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
from SmoothCache import DiffuserCalibrationHelper


def main():
    """Calibrate SmoothCache on DiT-XL/2-256 and write the derived schedule."""
    # Load the DiT pipeline and swap in a DPM-Solver multistep scheduler.
    pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    num_inference_steps = 50

    # The calibration helper observes the run and derives a caching schedule.
    calibration_helper = DiffuserCalibrationHelper(
        model=pipe.transformer,
        calibration_lookahead=3,
        calibration_threshold=0.15,
        schedule_length=num_inference_steps,  # must match num_inference_steps used below
        log_file="smoothcache_schedules/diffuser_schedule.json",
    )
    calibration_helper.enable()

    # Run the pipeline exactly as usual while calibration is active.
    words = ["Labrador retriever", "combination lock", "cassette player"]
    class_ids = pipe.get_label_ids(words)
    generator = torch.manual_seed(33)
    result = pipe(
        class_labels=class_ids,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )
    images = result.images

    # Disabling the helper finalizes and saves the schedule.
    calibration_helper.disable()

    print("Calibration complete. Schedule saved to smoothcache_schedules/diffuser_schedule.json")

    for prompt, image in zip(words, images):
        image.save(prompt + '.png')


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- 1 | {"mlp": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1], 2 | "attn": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]} -------------------------------------------------------------------------------- /smoothcache_schedules/50-N-2-l2c.json: -------------------------------------------------------------------------------- 1 | {"attn-0": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-0": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-1": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1], "mlp-1": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-2": [1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-2": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-3": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-3": [1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-4": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-4": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-5": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-5": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-6": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-6": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-7": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-7": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-8": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-8": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-9": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-9": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-10": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1], "mlp-10": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-11": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-11": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-12": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-12": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-13": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-13": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-14": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-14": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-15": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-15": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-16": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-16": [1, 0, 1, 0, 1, 0, 1, 0, 1, 
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-17": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-17": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-18": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-18": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-19": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-19": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-20": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-20": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-21": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-21": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-22": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-22": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-23": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-23": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-24": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-24": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-25": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-25": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-26": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-26": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "attn-27": [1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "mlp-27": [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} -------------------------------------------------------------------------------- /smoothcache_schedules/50-N-3-threshold-0.08.json: -------------------------------------------------------------------------------- 1 | {"attn": [1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 2 | "mlp": [1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} -------------------------------------------------------------------------------- /smoothcache_schedules/50-N-3-threshold-0.18.json: -------------------------------------------------------------------------------- 1 | {"attn": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1], 2 | "mlp": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]} -------------------------------------------------------------------------------- /smoothcache_schedules/70-N-2-fora.json: -------------------------------------------------------------------------------- 1 | {"attn": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1], 2 | "mlp": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]} -------------------------------------------------------------------------------- /smoothcache_schedules/70-N-3-threshold-0.08.json: -------------------------------------------------------------------------------- 1 | {"attn": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1], 2 | "mlp": [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]} -------------------------------------------------------------------------------- /smoothcache_schedules/diffuser_schedule.json: -------------------------------------------------------------------------------- 1 | { 2 | "attn1": [1,0,1,0,0,1,0,1,0,1,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,1,1] 3 | } 4 | --------------------------------------------------------------------------------