├── .editorconfig ├── .github └── workflows │ ├── linter.yml │ ├── performanceTests.yml │ └── tests.yml ├── .gitignore ├── AdOptimize.png ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── adOptimizeClient ├── build.gradle.kts ├── gradle.properties ├── gradle │ └── wrapper │ │ ├── gradle-wrapper.jar │ │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── settings.gradle.kts └── src │ └── main │ └── kotlin │ ├── AdConfig.kt │ └── client │ └── main.kt ├── adoptimize-cli-compiler-plugin ├── build.gradle.kts └── src │ └── main │ ├── kotlin │ └── adoptimize │ │ ├── ADOptimizeCommandLineProcessor.kt │ │ ├── ADOptimizeCompilerConfigurationExtension.kt │ │ ├── ADOptimizeComponentRegistrar.kt │ │ ├── ADOptimizeIrGenerationExtension.kt │ │ ├── DependencyContainer.kt │ │ ├── ErrorHandling.kt │ │ └── autodiff │ │ ├── AutoDiffCodeWriter.kt │ │ ├── AutoDiffCodeWriterVendor.kt │ │ ├── BoxedReverseNodeCustomizer.kt │ │ ├── ForwardModeDifferentiator.kt │ │ ├── GuardedScope.kt │ │ ├── Metadata │ │ ├── BoxedPrimitiveInfo.kt │ │ ├── DifferentiableApi.kt │ │ ├── ReverseScalarClass.kt │ │ └── StackClass.kt │ │ ├── NodeCodeCopy │ │ ├── AutoDiffCodeWriterImpl.kt │ │ ├── AutoDiffCustomTypeInliner.kt │ │ ├── AutoDiffInliner.kt │ │ ├── AutoDiffOperationOverloadWriter.kt │ │ └── NodeAnalyzer.kt │ │ ├── NodePopulation │ │ └── CustomReverseNodePopulator.kt │ │ ├── PrimalFunctionTransformer.kt │ │ ├── ReverseForwardNodeCustomizer.kt │ │ ├── ReverseScalarClassCreator.kt │ │ ├── UnwrapppedNode │ │ ├── BackpropReplacer.kt │ │ ├── CallLowerer.kt │ │ ├── PrimalReplacer.kt │ │ ├── PropertyCopier.kt │ │ ├── Replacer.kt │ │ └── UnboxedReverseNodeCustomizer.kt │ │ ├── Util.kt │ │ ├── diffIR │ │ ├── DiffIR.kt │ │ ├── DiffIRCreator.kt │ │ ├── DiffIRTransformer.kt │ │ └── DiffIRVisitor.kt │ │ ├── forwards │ │ └── TangentRecorder.kt │ │ └── reverse │ │ └── PullbackGenerator.kt │ └── resources │ └── META-INF │ └── services │ ├── 
org.jetbrains.kotlin.compiler.plugin.CommandLineProcessor │ └── org.jetbrains.kotlin.compiler.plugin.ComponentRegistrar ├── adoptimize-common ├── build.gradle.kts └── src │ └── main │ └── kotlin │ └── adOptimizeCommon │ └── DifferentiableApi.kt ├── adoptimize-gradle-plugin ├── build.gradle.kts └── src │ └── main │ └── kotlin │ └── adoptimize │ └── gradle │ ├── ADOptimizeExtension.kt │ └── ADOptimizeGradleExtension.kt ├── adoptimize-integration-tests ├── build.gradle.kts └── src │ └── test │ ├── kotlin │ └── adoptimize │ │ ├── ADOptimizeBlackBoxTest.kt │ │ ├── ADOptimizeConfigurator.kt │ │ ├── ADOptimizeIRTest.kt │ │ ├── AbstractADOptimizeBlackBoxTest.kt │ │ └── AbstractADOptimizeIrTest.kt │ └── testData │ ├── codegen │ ├── activeArgument.kt │ ├── assignOperations.kt │ ├── constArg.kt │ ├── control_flow.kt │ ├── control_flow_derivative.kt │ ├── control_flow_nested_if.kt │ ├── diffkt.kt │ ├── elseLower.kt │ ├── exp.kt │ ├── floatFunction.kt │ ├── forwardsUnbox.kt │ ├── getValInitializer.kt │ ├── getterNoExplicitUnbox.kt │ ├── if_statement.kt │ ├── implicitParameter.kt │ ├── initialization.kt │ ├── logProb.kt │ ├── multipleOutputs.kt │ ├── multipleOutputsControlFlow.kt │ ├── nestedWhenVariable.kt │ ├── nonActiveArgument.kt │ ├── nullArgument.kt │ ├── parameterWithNonParameterizedTypeArg.kt │ ├── reverseForwardControlFlow.kt │ ├── reverseForwardControlFlowNested.kt │ ├── reverseForwardControlFlowPrimal.kt │ ├── reverseForwardControlFlowWhen.kt │ ├── reverseForwardControlFlowWhenPrimal.kt │ ├── reverseForwardDerivative.kt │ ├── reverseForwardLn.kt │ ├── reverseForwardLogProb.kt │ ├── reverseForwardNonActiveIntermediateValues.kt │ ├── reverseForwardPrimal.kt │ ├── scalarNoop.kt │ ├── simpleWhileLoop.kt │ ├── switchAssign.kt │ ├── unwrapONE.kt │ ├── unwrapZERO.kt │ └── while_statement.kt │ └── ir │ ├── control_flow.ir.txt │ ├── control_flow.kt │ ├── derivative.ir.txt │ ├── derivative.kt │ ├── firstAndSecondOrderDerivative.ir.txt │ ├── firstAndSecondOrderDerivative.kt │ 
├── floatFunction.ir.txt │ ├── floatFunction.kt │ ├── secondOrderDerivative.ir.txt │ └── secondOrderDerivative.kt ├── adoptimize-publish └── build.gradle.kts ├── build.gradle.kts ├── buildSrc ├── build.gradle.kts ├── gradle.properties ├── prepare-deps │ └── build.gradle.kts ├── settings.gradle.kts └── src │ └── main │ └── kotlin │ └── tasks.kt ├── config └── build.gradle.kts ├── differentiable-api-preprocessor-compiler-plugin ├── build.gradle.kts └── src │ └── main │ ├── kotlin │ └── diffPrep │ │ ├── DiffPrepClassLifterDelegate.kt │ │ ├── DiffPrepErrorMessagesExtension.kt │ │ ├── DiffPrepErrors.kt │ │ ├── DifferentiableApiPreprocessorCommandLineProcessor.kt │ │ ├── DifferentiableApiPreprocessorCompilerConfigurationExtension.kt │ │ ├── DifferentiableApiPreprocessorComponentRegistrar.kt │ │ ├── DifferentiableApiPreprocessorIrGenerationExtension.kt │ │ ├── ErrorHandling.kt │ │ ├── analysisHandler │ │ └── DifferentiableApiPreprocessorAnalysisHandlerExtension.kt │ │ └── metadata │ │ ├── BoxedPrimitiveInfo.kt │ │ ├── DifferentiableApi.kt │ │ ├── DifferentiableApiBuilder.kt │ │ └── StackClass.kt │ └── resources │ └── META-INF │ └── services │ ├── org.jetbrains.kotlin.compiler.plugin.CommandLineProcessor │ └── org.jetbrains.kotlin.compiler.plugin.ComponentRegistrar ├── differentiable-api-preprocessor-gradle-plugin ├── build.gradle.kts └── src │ └── main │ └── kotlin │ └── diffPrep │ └── gradle │ ├── DifferentiableApiPreprocessorExtension.kt │ └── DifferentiableApiPreprocessorGradleExtension.kt ├── differentiable-api-preprocessor-integration-tests ├── build.gradle.kts └── src │ └── test │ ├── kotlin │ └── diffPrep │ │ ├── AbstractDifferentiablePreprocessorBlackBoxTest.kt │ │ ├── AbstractDifferentiablePreprocessorIrTest.kt │ │ ├── AbstractDifferentiablePrepropessorDiagnosticTests.kt │ │ ├── DifferentiablePreprocessorBaseTest.kt │ │ ├── DifferentiablePreprocessorConfigurator.kt │ │ ├── DifferentiablePreprocessorDiagnosticTest.kt │ │ ├── DifferentiablePreprocessorIrTest.kt │ │ 
└── DifferentiablePreprocessorWithTempDirectoryTest.kt │ └── testData │ ├── diagnostics │ ├── toUnBoxFunctionInvalidSignature.kt │ ├── toUnboxClassMethods.kt │ └── validApi.kt │ ├── ir │ ├── api.ir.txt │ ├── api.kt │ ├── resources │ │ └── adoptimize.properties │ ├── typeOperator.ir.txt │ └── typeOperator.kt │ └── withTmpDir │ ├── multiModule │ └── validApi.kt │ └── singleModule │ └── validApi.kt ├── differentiable-api-preprocessor-publish └── build.gradle.kts ├── gradle.properties ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── optimizer-plugins ├── adoptimize-integration-tests │ └── testDependencies │ │ └── org │ │ └── diffApi │ │ └── 0.0.1 │ │ └── diffApi-0.0.1.jar └── bmgoptimize-integration-tests │ └── testDependencies │ └── org │ └── bmgApi │ └── 0.0.1 │ └── bmgApi-0.0.1.jar ├── plugin-generators-common ├── build.gradle.kts └── src │ ├── main │ └── kotlin │ │ └── pluginCommon │ │ ├── CopyAndReplacer.kt │ │ ├── DependencyContainer.kt │ │ ├── ErrorHandling.kt │ │ ├── ScopeSubstitutionMap.kt │ │ ├── Substitutor.kt │ │ ├── generators │ │ ├── DescriptorWrappers.kt │ │ ├── GeneratedAuthenticClass.kt │ │ ├── GeneratedAuthenticClassDescriptor.kt │ │ ├── IrBodyGenerator.kt │ │ ├── IrClassGenerator.kt │ │ ├── IrFunctionGenerator.kt │ │ ├── IrPropertyGenerator.kt │ │ ├── IrUtil.kt │ │ ├── ParameterInfo.kt │ │ └── WatchableMutableList.kt │ │ └── lowerings │ │ ├── ClassLifter.kt │ │ ├── ElseBranchLowering.kt │ │ ├── FunctionLowering.kt │ │ ├── RedundantVariableRemover.kt │ │ ├── ShallowTransformer.kt │ │ ├── UnitCastTransformer.kt │ │ ├── UnnestLowering.kt │ │ └── VariableWhenLowering.kt │ └── test │ └── kotlin │ └── pluginCommon │ └── DependencyContainerTests.kt ├── producer-consumer ├── README.md ├── build.gradle.kts ├── consumer │ ├── build.gradle.kts │ └── src │ │ └── main │ │ └── kotlin │ │ └── main.kt ├── gradle.properties ├── gradle │ └── wrapper │ │ ├── gradle-wrapper.jar │ │ └── 
gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── producer │ ├── build.gradle.kts │ └── src │ │ └── main │ │ ├── kotlin │ │ └── fauxDiffKt.kt │ │ └── resources │ │ └── adoptimize.properties ├── settings.gradle.kts └── src │ └── main │ └── resources │ └── adoptimize.properties └── settings.gradle.kts /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | [*.{kt,kts}] 3 | disabled_rules=no-wildcard-imports,filename,indent -------------------------------------------------------------------------------- /.github/workflows/linter.yml: -------------------------------------------------------------------------------- 1 | ################################# 2 | ################################# 3 | ## Super Linter GitHub Actions ## 4 | ################################# 5 | ################################# 6 | name: Lint Code Base 7 | 8 | # 9 | # Documentation: 10 | # https://docs.github.com/en/actions/learn-github-actions/workflow-syntax-for-github-actions 11 | # 12 | 13 | ################################ 14 | # Start the job on PRs to main # 15 | ################################ 16 | on: 17 | pull_request: 18 | branches: [main] 19 | 20 | ############### 21 | # Set the Job # 22 | ############### 23 | jobs: 24 | build: 25 | # Name the Job 26 | name: Lint Code Base 27 | # Set the agent to run on 28 | runs-on: ubuntu-latest 29 | 30 | ################## 31 | # Load all steps # 32 | ################## 33 | steps: 34 | ########################## 35 | # Checkout the code base # 36 | ########################## 37 | - name: Checkout Code 38 | uses: actions/checkout@v2 39 | with: 40 | # Full git history is needed to get a proper list of changed files within `super-linter` 41 | fetch-depth: 0 42 | 43 | ################################ 44 | # Run Linter against code base # 45 | ################################ 46 | - name: Lint Code Base 47 | uses: github/super-linter@v4 48 | env: 49 | VALIDATE_ALL_CODEBASE: false 50 |
VALIDATE_JSCPD: false 51 | VALIDATE_GITHUB_ACTIONS: false 52 | VALIDATE_KOTLIN_ANDROID: false 53 | VALIDATE_BASH: false 54 | VALIDATE_SHELL_SHFMT: false 55 | VALIDATE_MARKDOWN: false 56 | VALIDATE_NATURAL_LANGUAGE: false 57 | # Change to 'master' if your main branch differs 58 | DEFAULT_BRANCH: main 59 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/performanceTests.yml: -------------------------------------------------------------------------------- 1 | name: Performance Tests 2 | on: 3 | schedule: 4 | - cron: '30 23 * * *' 5 | jobs: 6 | testPublish: 7 | runs-on: macos-latest 8 | permissions: 9 | contents: read 10 | packages: write 11 | steps: 12 | - name: set up jdk 8 13 | uses: actions/setup-java@v2 14 | with: 15 | java-version: '8' 16 | distribution: 'adopt' 17 | - name: install diffkt dependencies 18 | run: | 19 | /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" 20 | brew install onednn 21 | brew install libomp 22 | - uses: actions/checkout@v2 23 | - name: client performance tests 24 | run: | 25 | ./gradlew publishToMavenLocal 26 | pushd adOptimizeClient 27 | ./gradlew test 28 | popd 29 | env: 30 | GITHUB_TOKEN: ${{ secrets.FBTOKEN }} -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Test and Publish 2 | on: 3 | push: 4 | branches: 5 | main 6 | pull_request: 7 | branches: 8 | main 9 | jobs: 10 | testPublish: 11 | runs-on: macos-11 12 | permissions: 13 | contents: read 14 | packages: write 15 | steps: 16 | - name: set up jdk 8 17 | uses: actions/setup-java@v2 18 | with: 19 | java-version: '8' 20 | distribution: 'adopt' 21 | - name: install diffkt dependencies 22 | run: | 23 | /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
24 | brew install onednn 25 | brew install libomp 26 | - uses: actions/checkout@v2 27 | - name: test 28 | run: ./gradlew integrationTests 29 | env: 30 | GITHUB_TOKEN: ${{ secrets.FBTOKEN }} 31 | - name: upload ad test results 32 | uses: actions/upload-artifact@v2 33 | if: always() 34 | with: 35 | name: adOptimizeTestHtml 36 | path: ./**/adoptimize-integration-tests/build/reports/tests/test/ 37 | - name: upload diffPrep test results 38 | uses: actions/upload-artifact@v2 39 | if: always() 40 | with: 41 | name: diffPrepTestHtml 42 | path: ./**/differentiable-api-preprocessor-integration-tests/build/reports/tests/test/ 43 | - name: client tests 44 | run: | 45 | ./gradlew publishToMavenLocal 46 | pushd adOptimizeClient 47 | ./gradlew run 48 | popd 49 | env: 50 | GITHUB_TOKEN: ${{ secrets.FBTOKEN }} 51 | - name: consumer producer tests 52 | run: | 53 | ./gradlew publishToMavenLocal 54 | pushd producer-consumer 55 | ./gradlew :producer:publishToMavenLocal 56 | ./gradlew :consumer:run 57 | popd 58 | env: 59 | GITHUB_TOKEN: ${{ secrets.FBTOKEN }} 60 | - name: publish 61 | if: success() 62 | run: ./gradlew publish -Pgroup=org.diffkt.adoptimize -Pversion=0.1.1-$(git rev-parse --short HEAD) 63 | env: 64 | GITHUB_ACTOR: ${{ secrets.DIFFKT_ACTOR }} 65 | GITHUB_TOKEN: ${{ secrets.DIFFKT_TOKEN }} 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | github.env 3 | *.jar 4 | hs_err_pid* 5 | gradle-app.setting 6 | !gradle-wrapper.jar 7 | 8 | 9 | # ignore Gradle project-specific cache directory 10 | .gradle 11 | 12 | # Ignore Gradle build output directory 13 | build 14 | 15 | .idea 16 | 17 | out 18 | 19 | # Ignore kotlin test 20 | .kotlintest 21 | 22 | # ignore the dist that is populated by build 23 | dist 24 | 25 | # ignore test build directories 26 | adoptimize-integration-tests/testDependencies/* 27 | 
bmgoptimize-integration-tests/testDependencies/* 28 | -------------------------------------------------------------------------------- /AdOptimize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/optimizer-plugins/0695dc024d4a2f6eaf0559026efcda2a66b5e810/AdOptimize.png -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to optimizer-plugins 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 
12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to optimizer-plugins, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. -------------------------------------------------------------------------------- /adOptimizeClient/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) Meta Platforms, Inc. and affiliates. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | 10 | description = "client" 11 | 12 | plugins { 13 | val ktVersion: String by System.getProperties() 14 | kotlin("jvm") version ktVersion 15 | id("meta-diffkt-adoptimize") version "0.0.1-SNAPSHOT" 16 | application 17 | } 18 | 19 | val diffKtVersion = "0.1.0-2d523b5" 20 | 21 | java { 22 | sourceCompatibility = JavaVersion.VERSION_1_8 23 | targetCompatibility = JavaVersion.VERSION_1_8 24 | toolchain { 25 | targetCompatibility = JavaVersion.VERSION_1_8 26 | } 27 | } 28 | 29 | repositories { 30 | mavenLocal() 31 | mavenCentral() 32 | maven { 33 | url = uri("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/bootstrap") 34 | } 35 | maven { 36 | url = uri("https://maven.pkg.github.com/facebookresearch/diffkt") 37 | credentials { 38 | username = System.getenv("GITHUB_ACTOR") 39 | password = System.getenv("GITHUB_TOKEN") 40 | } 41 | } 42 | } 43 | 44 | adOptimize { 45 | this.diffApi("org.diffkt.adopt", "api", diffKtVersion) 46 | this.optimizeAnnotation("config.Optimize") 47 | this.secondOrderAnnotation("config.SecondOrderOptimize") 48 | this.failOnADFail(true) 49 |
this.reverseADFunction("config.ReverseAD") 50 | } 51 | 52 | dependencies { 53 | implementation(kotlin("stdlib-jdk8")) 54 | implementation(kotlin("reflect")) 55 | implementation(group = "org.diffkt.adopt", name = "api", version = diffKtVersion) 56 | 57 | testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.8.0-M1") 58 | testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:5.7.1") 59 | testImplementation("junit", "junit", "4.12") 60 | testImplementation("org.junit.jupiter:junit-jupiter-api:5.7.1") 61 | testImplementation(group = "org.diffkt.adopt", name = "api", version = diffKtVersion) 62 | } 63 | 64 | application { 65 | mainClass.set("client.MainKt") 66 | } 67 | // Run Test tasks on the JUnit Platform (the Jupiter engine is declared in dependencies above); withType needs the explicit <Test> type argument. 68 | tasks.withType<Test> { 69 | useJUnitPlatform() 70 | } 71 | -------------------------------------------------------------------------------- /adOptimizeClient/gradle.properties: -------------------------------------------------------------------------------- 1 | # 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | # 9 | 10 | systemProp.ktVersion=1.7.0-dev-444 -------------------------------------------------------------------------------- /adOptimizeClient/gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/optimizer-plugins/0695dc024d4a2f6eaf0559026efcda2a66b5e810/adOptimizeClient/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /adOptimizeClient/gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | # 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates.
4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | # 9 | 10 | distributionBase=GRADLE_USER_HOME 11 | distributionPath=wrapper/dists 12 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.1-bin.zip 13 | zipStoreBase=GRADLE_USER_HOME 14 | zipStorePath=wrapper/dists 15 | -------------------------------------------------------------------------------- /adOptimizeClient/gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 15 | @rem 16 | 17 | @if "%DEBUG%" == "" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%" == "" set DIRNAME=. 29 | set APP_BASE_NAME=%~n0 30 | set APP_HOME=%DIRNAME% 31 | 32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter. 33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi 34 | 35 | @rem Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 37 | 38 | @rem Find java.exe 39 | if defined JAVA_HOME goto findJavaFromJavaHome 40 | 41 | set JAVA_EXE=java.exe 42 | %JAVA_EXE% -version >NUL 2>&1 43 | if "%ERRORLEVEL%" == "0" goto init 44 | 45 | echo. 46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 47 | echo. 48 | echo Please set the JAVA_HOME variable in your environment to match the 49 | echo location of your Java installation. 50 | 51 | goto fail 52 | 53 | :findJavaFromJavaHome 54 | set JAVA_HOME=%JAVA_HOME:"=% 55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 56 | 57 | if exist "%JAVA_EXE%" goto init 58 | 59 | echo. 60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 61 | echo. 62 | echo Please set the JAVA_HOME variable in your environment to match the 63 | echo location of your Java installation. 64 | 65 | goto fail 66 | 67 | :init 68 | @rem Get command-line arguments, handling Windows variants 69 | 70 | if not "%OS%" == "Windows_NT" goto win9xME_args 71 | 72 | :win9xME_args 73 | @rem Slurp the command line arguments. 74 | set CMD_LINE_ARGS= 75 | set _SKIP=2 76 | 77 | :win9xME_args_slurp 78 | if "x%~1" == "x" goto execute 79 | 80 | set CMD_LINE_ARGS=%* 81 | 82 | :execute 83 | @rem Setup the command line 84 | 85 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 86 | 87 | 88 | @rem Execute Gradle 89 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 90 | 91 | :end 92 | @rem End local scope for the variables with windows NT shell 93 | if "%ERRORLEVEL%"=="0" goto mainEnd 94 | 95 | :fail 96 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 97 | rem the _cmd.exe /c_ return code! 
98 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 99 | exit /b 1 100 | 101 | :mainEnd 102 | if "%OS%"=="Windows_NT" endlocal 103 | 104 | :omega 105 | -------------------------------------------------------------------------------- /adOptimizeClient/settings.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) Meta Platforms, Inc. and affiliates. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | 10 | rootProject.name = "adOptimizeClient" 11 | 12 | pluginManagement { 13 | repositories { 14 | mavenLocal() 15 | gradlePluginPortal() 16 | maven { 17 | url = uri("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/bootstrap") 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /adOptimizeClient/src/main/kotlin/AdConfig.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) Meta Platforms, Inc. and affiliates. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | 10 | package config 11 | // Marker annotations; build.gradle.kts names them by fully-qualified name (config.Optimize, config.SecondOrderOptimize, config.ReverseAD) in its adOptimize { } block. 12 | annotation class Optimize 13 | annotation class SecondOrderOptimize 14 | annotation class ReverseAD 15 | -------------------------------------------------------------------------------- /adOptimizeClient/src/main/kotlin/client/main.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) Meta Platforms, Inc. and affiliates. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree.
7 | * 8 | */ 9 | 10 | package client 11 | 12 | import config.Optimize 13 | import config.ReverseAD 14 | import config.SecondOrderOptimize 15 | import org.diffkt.* 16 | import java.lang.IllegalStateException 17 | import kotlin.math.cos 18 | import kotlin.math.pow 19 | 20 | @ReverseAD 21 | fun jacobian_transposed_vector_product(x: Float, f: (Float) -> Float): Float { 22 | TODO() 23 | } 24 | 25 | @Optimize 26 | fun target(a: Float): Float { 27 | val i0 = a.pow(2f) 28 | val i1 = i0 + i0 29 | val i2 = cos(i1) 30 | return i2 31 | } 32 | 33 | @SecondOrderOptimize 34 | @Optimize 35 | fun target(a: DScalar): DScalar { 36 | val i0 = a.pow(2f) 37 | val i1 = i0 + i0 38 | val i2 = cos(i1) 39 | return i2 40 | } 41 | 42 | fun nonOptimal_target(a: DScalar): DScalar { 43 | val i0 = a.pow(2f) 44 | val i1 = i0 + i0 45 | val i2 = cos(i1) 46 | return i2 47 | } 48 | // Compares the plugin-optimized `target` derivatives (first order, second order, Float) with the unoptimized `nonOptimal_target` reference; returns "OK" or a description of the first mismatch. 49 | fun box(): String { 50 | val x = FloatScalar(2.15f) 51 | val floatDerivative = jacobian_transposed_vector_product(x.value, ::target) 52 | val derivative = primalAndReverseDerivative(x, { t: DScalar -> target(t) }) 53 | val secondOrderDerivative = primalAndForwardDerivative( 54 | x = x, 55 | f = { z: DScalar -> primalAndReverseDerivative(z, ::target).second } 56 | ) 57 | val secondOrderDerivativeExpectation = primalAndForwardDerivative( 58 | x = x, 59 | f = { z: DScalar -> primalAndReverseDerivative(z, ::nonOptimal_target).second } 60 | ) 61 | val expected_derivative = primalAndReverseDerivative(x, { t: DScalar -> nonOptimal_target(t) }) 62 | val tol = 0.000001 63 | 64 | if (Math.abs(derivative.first.basePrimal().value - expected_derivative.first.basePrimal().value) > tol) { 65 | return "PRIMAL FAIL: expected ${expected_derivative.first.basePrimal().value} but got ${derivative.first.basePrimal().value}" 66 | } 67 | if (Math.abs(secondOrderDerivative.first.basePrimal().value - secondOrderDerivativeExpectation.first.basePrimal().value) > tol) { 68 | return "FIRST Derivative FAIL: expected
${secondOrderDerivativeExpectation.first.basePrimal().value} but got ${secondOrderDerivative.first.basePrimal().value}" 69 | } 70 | if (Math.abs(secondOrderDerivative.second.basePrimal().value - secondOrderDerivativeExpectation.second.basePrimal().value) > tol) { 71 | return "Second Derivative FAIL: expected ${secondOrderDerivativeExpectation.second.basePrimal().value} but got ${secondOrderDerivative.second.basePrimal().value}" 72 | } 73 | if (Math.abs(derivative.second.basePrimal().value - expected_derivative.second.basePrimal().value) > tol) { 74 | return "DERIVATIVE FAIL: expected ${expected_derivative.second.basePrimal().value} but got ${derivative.second.basePrimal().value}" 75 | } 76 | if (Math.abs(derivative.second.basePrimal().value - floatDerivative) > tol) { 77 | return "Float derivative FAIL: expected ${derivative.second.basePrimal().value} but got $floatDerivative" 78 | } 79 | 80 | return "OK" 81 | } 82 | 83 | fun main() { 84 | val outcome = box() 85 | if (outcome != "OK") { 86 | throw IllegalStateException("Box test failed: $outcome") 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | /* Build of the compiler plugin; shadowJar below packages it and is published through the shadowArtifact configuration. */ 8 | import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar 9 | description = "AD optimize cli compiler plugin" 10 | 11 | plugins { 12 | id("com.github.johnrengelman.shadow") version "6.1.0" 13 | } 14 | 15 | dependencies { 16 | val ktVersion: String by System.getProperties() 17 | compileOnly("org.jetbrains.kotlin:kotlin-stdlib:$ktVersion") 18 | api("org.jetbrains.kotlin:kotlin-compiler-embeddable:$ktVersion") 19 | compileOnly(project(":adoptimize-common")) 20 | compileOnly(project(":plugin-generators-common")) 21 | } 22 | 23 | sourceSets { 24 | main {} 25 | test { java.srcDirs("test", "tests") } 26 | } 27 | /* The shadow jar bundles the compileOnly dependencies minus kotlin-stdlib and the IntelliJ openapi artifacts, and relocates the org.jetbrains-prefixed asm and intellij packages to their plain names. NOTE(review): archiveClassifier is sources although the jar holds compiled classes - confirm consumers before changing it. NOTE(review): getPlugin() and getByName(...) appear to have lost their type arguments in this dump. */ 28 | val shadowArtifact by configurations.creating 29 | val shadowJar: ShadowJar = tasks.getByName("shadowJar") { 30 | val convention = project.convention.getPlugin() 31 | archiveClassifier.set("sources") 32 | from(convention.sourceSets.main.get().output) 33 | configurations = mutableListOf(project.configurations.compileOnly.get()) 34 | relocate("org.jetbrains.org.objectweb.asm.tree.analysis", "org.objectweb.asm.tree.analysis") 35 | relocate("org.jetbrains.kotlin.com.intellij", "com.intellij") 36 | dependencies { 37 | exclude(dependency("org.jetbrains.kotlin:kotlin-stdlib")) 38 | // and its transitive dependencies: 39 | exclude(dependency("org.jetbrains.kotlin:kotlin-stdlib-common")) 40 | exclude(dependency("org.jetbrains:annotations")) 41 | 42 | exclude(dependency("com.intellij:openapi")) 43 | // and its transitive dependencies: 44 | exclude(dependency("com.intellij:extensions")) 45 | exclude(dependency("com.intellij:annotations")) 46 | } 47 | } 48 | 49 | artifacts { 50 | add(shadowArtifact.name, shadowJar) 51 | } 52 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/ADOptimizeCompilerConfigurationExtension.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc.
and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize 9 | import org.jetbrains.kotlin.config.CompilerConfiguration 10 | import org.jetbrains.kotlin.config.JVMConfigurationKeys 11 | import org.jetbrains.kotlin.config.JvmSerializeIrMode 12 | import org.jetbrains.kotlin.extensions.CompilerConfigurationExtension 13 | /* Enables the JVM IR backend (JVMConfigurationKeys.IR) and sets SERIALIZE_IR to JvmSerializeIrMode.INLINE; presumably the plugin needs serialized IR of inline functions - confirm. */ 14 | class ADOptimizeCompilerConfigurationExtension : CompilerConfigurationExtension { 15 | override fun updateConfiguration(configuration: CompilerConfiguration) { 16 | configuration.put(JVMConfigurationKeys.IR, true) 17 | configuration.put(JVMConfigurationKeys.SERIALIZE_IR, JvmSerializeIrMode.INLINE) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/DependencyContainer.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | package adoptimize 9 | 10 | import adoptimize.autodiff.AutoDiffCodeWriterVendor 11 | import adoptimize.autodiff.BackPropFunction.DiffIRCreator 12 | import adoptimize.autodiff.BoxedReverseNodeCustomizer 13 | import adoptimize.autodiff.Metadata.DifferentiableApi 14 | import adoptimize.autodiff.Metadata.StackClass 15 | import adoptimize.autodiff.NodePopulation.CustomReverseNodePopulator 16 | import adoptimize.autodiff.ReverseScalarClassCreator 17 | import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext 18 | import org.jetbrains.kotlin.ir.IrBuiltIns 19 | import org.jetbrains.kotlin.ir.types.IrType 20 | import org.jetbrains.kotlin.ir.util.defaultType 21 | import pluginCommon.DependencyContainer 22 | import pluginCommon.generators.IrBodyGenerator 23 | import pluginCommon.generators.IrClassGenerator 24 | import pluginCommon.generators.IrFunctionGenerator 25 | import pluginCommon.generators.IrPropertyGenerator 26 | import pluginCommon.lowerings.* 27 | /* Creates the DependencyContainer for the plugin: registers the plugin context, its irBuiltIns, a RedundantVariableRemover for the reverse scalar, root differentiable, forward scalar and boxed primitive types, the DifferentiableApi and the StackClass, followed by further put() registrations. NOTE(review): the type arguments of the argument-less put() calls appear to be stripped in this dump - confirm against the real source. */ 28 | fun createAdOptimizeDependencyContainer(differentiableApi: DifferentiableApi, stackClass: StackClass, pluginContext: IrPluginContext): DependencyContainer { 29 | val container = DependencyContainer() 30 | with(container) { 31 | put(pluginContext) 32 | put(pluginContext.irBuiltIns) 33 | val redundantVariableRemover = RedundantVariableRemover( 34 | setOf( 35 | differentiableApi.reverseDiffScalarClass.clazz.defaultType, 36 | differentiableApi.rootDifferentiableType, 37 | differentiableApi.forwardDiffScalarClass.clazz.defaultType, 38 | differentiableApi.boxedPrimitiveInfo.boxedPrimitiveClass.defaultType 39 | ) 40 | ) 41 | put(redundantVariableRemover) 42 | put(differentiableApi) 43 | put(stackClass) 44 | put() 45 | put() 46 | put() 47 | put() 48 | put() 49 | put() 50 | put() 51 | put() 52 | put() 53 | put() 54 | put() 55 | put() 56 | put() 57 | } 58 | return container 59 | } 60 | --------------------------------------------------------------------------------
/adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/ErrorHandling.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize 9 | /* Exception type for automatic-differentiation failures in this plugin. */ 10 | class AutoDiffException(message: String) : Exception(message) 11 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/autodiff/AutoDiffCodeWriter.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize.autodiff 9 | 10 | import adoptimize.autodiff.diffIR.CallVariable 11 | import org.jetbrains.kotlin.ir.declarations.IrFunction 12 | import org.jetbrains.kotlin.ir.declarations.IrValueDeclaration 13 | import org.jetbrains.kotlin.ir.declarations.IrVariable 14 | import pluginCommon.ScopeSubstitutionMap 15 | /* Result types of the code writer. NOTE(review): the type arguments of DerivativeContributions appear stripped in this dump - confirm against the real source. Two WrittenDeclarations implementations are declared here: Primal (a primal declaration alone) and PrimalAndPullback (a primal with its pullback). */ 16 | typealias DerivativeContributions = List> 17 | interface WrittenDeclarations 18 | class PrimalAndPullback(val primal: IrValueDeclaration, val pullback: IrValueDeclaration) : WrittenDeclarations 19 | class Primal(val primal: IrValueDeclaration) : WrittenDeclarations 20 | /* Writes code for a single leaf call: writeBackpropCodeForLeaf returns the derivative contributions of the backprop code it writes; writeInitCodeForLeaf returns the declarations it wrote, or null. */ 21 | interface AutoDiffCodeWriter { 22 | fun writeBackpropCodeForLeaf( 23 | leaf: CallVariable, 24 | primalToLocalMap: ScopeSubstitutionMap, 25 | currentUpstream: IrVariable, 26 | backPropMethod: IrFunction, 27 | guardedScope: GuardedScope, 28 | pullback: IrValueDeclaration?
29 | ): DerivativeContributions 30 | 31 | fun writeInitCodeForLeaf( 32 | leaf: CallVariable, 33 | primalToLocalMap: ScopeSubstitutionMap, 34 | guardedScope: GuardedScope, 35 | declarationParent: org.jetbrains.kotlin.ir.declarations.IrDeclarationParent 36 | ): WrittenDeclarations? 37 | } 38 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/autodiff/AutoDiffCodeWriterVendor.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize.autodiff 9 | 10 | import adoptimize.autodiff.Metadata.DifferentiableApi 11 | import adoptimize.autodiff.NodeCodeCopy.AutoDiffCodeWriterImpl 12 | import adoptimize.autodiff.diffIR.DiffIRFunction 13 | import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext 14 | import pluginCommon.generators.IrBodyGenerator 15 | import pluginCommon.generators.IrFunctionGenerator 16 | /* Creates an AutoDiffCodeWriterImpl for a given primal function, reusing the body generator, DifferentiableApi, function generator and plugin context held by this vendor. */ 17 | class AutoDiffCodeWriterVendor( 18 | val callGenerator: IrBodyGenerator, 19 | val differentiableApi: DifferentiableApi, 20 | val functionGenerator: IrFunctionGenerator, 21 | val context: IrPluginContext 22 | ) { 23 | fun codeWriter(primalFunction: DiffIRFunction): AutoDiffCodeWriter = AutoDiffCodeWriterImpl( 24 | callGenerator, differentiableApi, primalFunction, 25 | functionGenerator, context 26 | ) 27 | } 28 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/autodiff/BoxedReverseNodeCustomizer.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize.autodiff 9 | 10 | import adoptimize.autodiff.Metadata.ActiveParameterRequirement 11 | import adoptimize.autodiff.Metadata.DifferentiableApi 12 | import adoptimize.autodiff.Metadata.ParamMapType 13 | import adoptimize.autodiff.Metadata.ParameterMap 14 | import adoptimize.autodiff.Metadata.ReverseScalarClass 15 | import adoptimize.autodiff.NodePopulation.CustomReverseNodePopulator 16 | import adoptimize.autodiff.diffIR.DiffIRFunction 17 | import org.jetbrains.kotlin.ir.symbols.IrClassSymbol 18 | import org.jetbrains.kotlin.ir.types.classifierOrFail 19 | import org.jetbrains.kotlin.ir.types.isSubtypeOfClass 20 | import org.jetbrains.kotlin.ir.util.defaultType 21 | import org.jetbrains.kotlin.name.Name 22 | import pluginCommon.generators.ParameterInfo 23 | /* Reverse-node customizer. A parameter whose class the reverse scalar type is a subtype of becomes a parameter named {baseName}Node of the reverse scalar type (ParamMapType.CastToReverse, flag true); any other parameter keeps its name and type (ParamMapType.NoOp, flag false). The node name is {primalName}Reverse and population is delegated to CustomReverseNodePopulator. NOTE(review): generic type arguments of the List return types appear stripped in this dump. */ 24 | class BoxedReverseNodeCustomizer(val differentiableApi: DifferentiableApi, val populator: CustomReverseNodePopulator) : ReverseNodeCustomizer { 25 | override fun buildParameterInfos(originValueParameter: ParameterWithIndex): List> { 26 | val type = originValueParameter.valueDescriptor.type.classifierOrFail as IrClassSymbol 27 | val baseName = correctSpecializedNames(originValueParameter.valueDescriptor.name.toString()) 28 | return if (differentiableApi.reverseDiffScalarClass.clazz.defaultType.isSubtypeOfClass(type)) { 29 | val parameterInfo = ParameterInfo(Name.identifier("${baseName}Node"), differentiableApi.reverseDiffScalarClass.clazz.defaultType) 30 | val parameterMap = ParameterMap(originValueParameter.index, ParamMapType.CastToReverse, parameterInfo.name, true) 31 | listOf(Pair(parameterInfo, parameterMap)) 32 | } else { 33 | val parameterInfo = ParameterInfo(Name.identifier(baseName), originValueParameter.valueDescriptor.type) 34 | val parameterMap = ParameterMap(originValueParameter.index, ParamMapType.NoOp, parameterInfo.name, false)
35 | listOf(Pair(parameterInfo, parameterMap)) 36 | } 37 | } 38 | 39 | override fun typeRequirements(): List = listOf(ActiveParameterRequirement.Reverse) 40 | 41 | override fun name(primalName: String): String = "${primalName}Reverse" 42 | 43 | override fun populate(primalFunction: DiffIRFunction, shellClass: ReverseScalarClass) { 44 | populator.populate(shellClass, primalFunction) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/autodiff/Metadata/BoxedPrimitiveInfo.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize.autodiff.Metadata 9 | 10 | import org.jetbrains.kotlin.ir.declarations.IrClass 11 | import org.jetbrains.kotlin.ir.declarations.IrProperty 12 | import org.jetbrains.kotlin.ir.types.IrType 13 | import org.jetbrains.kotlin.ir.util.properties 14 | /* Describes the boxed primitive class: its value property, the underlying primitive type, and (per their names) the properties holding the scalar zero and one objects. */ 15 | class BoxedPrimitiveInfo( 16 | val boxedPrimitiveClass: IrClass, 17 | val valueProperty: IrProperty, 18 | val primitiveType: IrType, 19 | val scalarZeroObjectProperty: IrProperty, 20 | val scalarOneObjectProperty: IrProperty 21 | ) { 22 | init { /* fail fast with IllegalStateException when valueProperty is not among boxedPrimitiveClass.properties */ 23 | check(boxedPrimitiveClass.properties.contains(valueProperty)) { 24 | "The value property must be a property of the boxedPrimitive" 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/autodiff/Metadata/StackClass.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize.autodiff.Metadata 9 | 10 | import org.jetbrains.kotlin.ir.declarations.IrClass 11 | import org.jetbrains.kotlin.ir.declarations.IrFunction 12 | /* IR handles for a stack class and its pop, push, notEmpty and top methods. */ 13 | class StackClass(val clazz: IrClass, val popMethod: IrFunction, val pushMethod: IrFunction, val notEmptyMethod: IrFunction, val topMethod: IrFunction) 14 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/autodiff/NodeCodeCopy/AutoDiffCodeWriterImpl.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize.autodiff.NodeCodeCopy 9 | 10 | import adoptimize.autodiff.* 11 | import adoptimize.autodiff.Metadata.DifferentiableApi 12 | import adoptimize.autodiff.diffIR.CallVariable 13 | import adoptimize.autodiff.diffIR.DiffIRFunction 14 | import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext 15 | import org.jetbrains.kotlin.ir.declarations.IrClass 16 | import org.jetbrains.kotlin.ir.declarations.IrDeclarationParent 17 | import org.jetbrains.kotlin.ir.declarations.IrFunction 18 | import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction 19 | import org.jetbrains.kotlin.ir.declarations.IrValueDeclaration 20 | import org.jetbrains.kotlin.ir.declarations.IrVariable 21 | import org.jetbrains.kotlin.ir.util.isSubclassOf 22 | import pluginCommon.ScopeSubstitutionMap 23 | import pluginCommon.generators.IrBodyGenerator 24 | import pluginCommon.generators.IrFunctionGenerator 25 | import pluginCommon.generators.overrideRoot 26 | /* Routes each leaf to a code writer: leaves without a dependencyNode go to the operation-overload writer (runtime); otherwise AutoDiffInliner handles the reverse scalar class itself and AutoDiffCustomTypeInliner handles other types. */ 27 | class AutoDiffCodeWriterImpl( 28 | val callGenerator:
IrBodyGenerator, 29 | val differentiableApi: DifferentiableApi, 30 | primalFunction: DiffIRFunction, 31 | functionGenerator: IrFunctionGenerator, 32 | context: IrPluginContext 33 | ) : AutoDiffCodeWriter { 34 | val inliner = AutoDiffInliner(callGenerator, differentiableApi, primalFunction) 35 | val inlinerCustomTypes = AutoDiffCustomTypeInliner(callGenerator, differentiableApi, primalFunction) 36 | val runtime = AutoDiffOperationOverloadWriter( 37 | callGenerator, 38 | differentiableApi, 39 | functionGenerator, 40 | context.irBuiltIns 41 | ) 42 | override fun writeBackpropCodeForLeaf( 43 | leaf: CallVariable, 44 | primalToLocalMap: ScopeSubstitutionMap, 45 | currentUpstream: IrVariable, 46 | backPropMethod: IrFunction, 47 | guardedScope: GuardedScope, 48 | pullback: IrValueDeclaration? 49 | ): DerivativeContributions { 50 | return when { 51 | leaf.callInfo.dependencyNode == null -> runtime.writeBackpropCodeForLeaf(leaf, primalToLocalMap, currentUpstream, backPropMethod, guardedScope, pullback) /* NOTE(review): the cast below throws ClassCastException when backPropMethod is not an IrSimpleFunction (e.g. a constructor) */ 52 | (backPropMethod as IrSimpleFunction).overrideRoot() == differentiableApi.reverseDiffScalarClass.backpropMethod.overrideRoot() -> inliner.writeBackpropCodeForLeaf(leaf, primalToLocalMap, currentUpstream, backPropMethod, guardedScope, pullback) 53 | else -> inlinerCustomTypes.writeBackpropCodeForLeaf(leaf, primalToLocalMap, currentUpstream, backPropMethod, guardedScope, pullback) 54 | } 55 | } 56 | /* Same routing as writeBackpropCodeForLeaf, except the inliner is chosen when declarationParent is a subclass of the reverse scalar class (smart cast after the is-check). */ 57 | override fun writeInitCodeForLeaf( 58 | leaf: CallVariable, 59 | primalToLocalMap: ScopeSubstitutionMap, 60 | guardedScope: GuardedScope, 61 | declarationParent: IrDeclarationParent 62 | ): WrittenDeclarations?
{ 63 | return when { 64 | leaf.callInfo.dependencyNode == null -> runtime.writeInitCodeForLeaf(leaf, primalToLocalMap, guardedScope, declarationParent) 65 | declarationParent is IrClass && declarationParent.isSubclassOf(differentiableApi.reverseDiffScalarClass.clazz) -> inliner.writeInitCodeForLeaf( 66 | leaf, primalToLocalMap, guardedScope, 67 | declarationParent 68 | ) 69 | else -> inlinerCustomTypes.writeInitCodeForLeaf(leaf, primalToLocalMap, guardedScope, declarationParent) 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/autodiff/diffIR/DiffIRVisitor.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize.autodiff.diffIR 9 | /* Visitor over DiffIR nodes; every method defaults to a no-op, so implementations override only the node kinds they care about. */ 10 | interface DiffIRVisitor { 11 | fun visitVariable(declaration: ActiveVariable) {} 12 | fun visitSetVariable(expression: SetValue) {} 13 | fun visitBlockStatement(expression: BlockStatement) {} 14 | fun visitBlockBodyStatement(expression: BlockBodyStatement) {} 15 | fun visitWhenStatement(expression: WhenStatement) {} 16 | fun visitLoopStatement(expression: LoopStatement) {} 17 | fun visitConditionStatement(expression: ConditionBlock) {} 18 | fun visitReturn(returnStatement: ReturnStatement) {} 19 | fun visitConstant(constantStatement: ConstantStatement) {} 20 | fun visitSetField(setField: SetField) {} 21 | fun visitCall(call: Call) {} 22 | fun visitConstructorCallVariable(constructorCallVariable: ConstructorCallVariable) {} 23 | fun visitTypeOperatorVariable(typeOperatorVariable: TypeOperatorVariable) {} 24 | fun visitPushIntermediateState(pushIntermediateStateVariable: PushIntermediateStateVariable) {} 25 | } 26 |
-------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/kotlin/adoptimize/autodiff/forwards/TangentRecorder.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize.autodiff.forwards 9 | 10 | import org.jetbrains.kotlin.ir.declarations.IrProperty 11 | import org.jetbrains.kotlin.ir.declarations.IrValueDeclaration 12 | /* Records two mappings through operator set/get: value declaration to property, and property to tangent property. get returns null for keys that were never recorded. NOTE(review): the mutableMapOf() type arguments appear stripped in this dump. */ 13 | class TangentRecorder { 14 | private val targetValueToTargetTangentProperty = mutableMapOf() 15 | private val targetPropertyToTangentProperty = mutableMapOf() 16 | operator fun set(targetValue: IrValueDeclaration, targetProperty: IrProperty) { targetValueToTargetTangentProperty[targetValue] = targetProperty } 17 | operator fun set(targetProperty: IrProperty, tangentProperty: IrProperty) { targetPropertyToTangentProperty[targetProperty] = tangentProperty } 18 | 19 | operator fun get(srcValue: IrValueDeclaration) = targetValueToTargetTangentProperty[srcValue] 20 | operator fun get(targetProperty: IrProperty) = targetPropertyToTangentProperty[targetProperty] 21 | } 22 | -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/resources/META-INF/services/org.jetbrains.kotlin.compiler.plugin.CommandLineProcessor: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | adoptimize.ADOptimizeCommandLineProcessor -------------------------------------------------------------------------------- /adoptimize-cli-compiler-plugin/src/main/resources/META-INF/services/org.jetbrains.kotlin.compiler.plugin.ComponentRegistrar: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | adoptimize.ADOptimizeComponentRegistrar -------------------------------------------------------------------------------- /adoptimize-common/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | description = "store shared constants between the preprocessor plugin that writes the properties file and the adoptimizer plugin that reads the properties file" 9 | -------------------------------------------------------------------------------- /adoptimize-common/src/main/kotlin/adOptimizeCommon/DifferentiableApi.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | package adOptimizeCommon 9 | /* Property keys shared between the writer and the reader of the adoptimize.properties file. NOTE(review): forwardClass maps to the key forwardScalarClass - the constant name and its value differ, confirm this is intended. */ 10 | const val propertiesFileName = "adoptimize.properties" 11 | const val reverseClass = "reverseClass" 12 | const val forwardClass = "forwardScalarClass" 13 | const val tangentProperty = "tangentProperty" 14 | const val boxedPrimitive = "boxedPrimitive" 15 | const val primitiveType = "primitiveType" 16 | const val primalProperty = "primalProperty" 17 | const val backpropMethod = "backpropMethod" 18 | const val upstreamProperty = "upstreamProperty" 19 | const val pushbackMethod = "pushbackMethod" 20 | const val derivativeId = "derivativeId" 21 | const val scalarRoot = "scalarRoot" 22 | const val primalAndPullbackFunction = "primalAndPullbackFunction" 23 | const val valueProperty = "valueProperty" 24 | const val scalarPlusFunction = "scalarPlusFunction" 25 | const val tensorPlusFunction = "tensorPlusFunction" 26 | const val scalarZero = "scalarZero" 27 | const val scalarOne = "scalarOne" 28 | const val stackImpl = "stackImpl" 29 | const val toUnboxFunction = "toUnboxFunction" 30 | const val toReverse = "toReverse" 31 | const val dTensor = "dTensor" 32 | const val reverseOperations = "reverseOperations" 33 | const val scalarNoop = "scalarNoop" 34 | /* Builds a reverse node name from a fully-qualified operations function name by replacing dots with underscores and appending LiftedReverseNode, e.g. a.b.f becomes a_b_fLiftedReverseNode. */ 35 | fun reverseNodeNameFromOperationsName(operationsFunctionFqName: String) = "${operationsFunctionFqName.replace('.','_')}LiftedReverseNode" 36 | -------------------------------------------------------------------------------- /adoptimize-gradle-plugin/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | /* NOTE(review): tasks.withType below appears to have lost its KotlinCompile type argument in this dump - confirm against the real source. */ 8 | import org.jetbrains.kotlin.gradle.tasks.KotlinCompile 9 | description = "AD optimize gradle plugin" 10 | 11 | plugins { 12 | id("java-gradle-plugin") 13 | } 14 | 15 | dependencies { 16 | val ktVersion: String by System.getProperties() 17 | compileOnly("org.jetbrains.kotlin:kotlin-stdlib:$ktVersion") 18 | compileOnly(kotlin("gradle-plugin-api")) 19 | implementation(project(":config")) 20 | } 21 | 22 | tasks.withType { 23 | kotlinOptions.jvmTarget = "1.8" 24 | } 25 | 26 | group = "org.meta.diffkt.adoptimize" 27 | 28 | // generate plugin descriptors in the resulting JAR's META-INF directory 29 | gradlePlugin { 30 | plugins { 31 | create("meta-diffkt-adoptimize-gradle-plugin") { 32 | id = "meta-diffkt-adoptimize" 33 | implementationClass = "adoptimize.gradle.ADOptimizeGradleSubPlugin" 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /adoptimize-gradle-plugin/src/main/kotlin/adoptimize/gradle/ADOptimizeExtension.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | package adoptimize.gradle 8 | /* Settings of the adOptimize DSL extension; each value is forwarded to the compiler plugin as a SubpluginOption by ADOptimizeGradleSubPlugin. NOTE(review): optimize and diffApi have public setters while the other properties use private set - confirm which is intended. */ 9 | open class ADOptimizeExtension { 10 | var optimize = "" 11 | var diffApi: String = "" 12 | var secondOrderOptimization = "" 13 | private set 14 | var failOnADFailFlag = false 15 | private set 16 | var reverseAD = "" 17 | private set 18 | 19 | open fun optimizeAnnotation(customOptimizeAnnotation: String) { 20 | this.optimize = customOptimizeAnnotation 21 | } 22 | /* Stores the coordinates as group:artifactId:version. */ 23 | open fun diffApi(group: String, artifactId: String, version: String) { 24 | diffApi = "$group:$artifactId:$version" 25 | } 26 | 27 | open fun secondOrderAnnotation(customSecondOrderAnnotation: String) { 28 | this.secondOrderOptimization = customSecondOrderAnnotation 29 | } 30 | 31 | open fun failOnADFail(failOnAdFail: Boolean) { 32 | this.failOnADFailFlag = failOnAdFail 33 | } 34 | 35 | open fun reverseADFunction(fqn: String) { 36 | this.reverseAD = fqn 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /adoptimize-gradle-plugin/src/main/kotlin/adoptimize/gradle/ADOptimizeGradleExtension.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | package adoptimize.gradle 9 | 10 | import config.BuildConfig 11 | import org.gradle.api.Project 12 | import org.gradle.api.provider.Provider 13 | import org.gradle.tooling.provider.model.ToolingModelBuilderRegistry 14 | import org.jetbrains.kotlin.gradle.plugin.KotlinCompilation 15 | import org.jetbrains.kotlin.gradle.plugin.KotlinCompilerPluginSupportPlugin 16 | import org.jetbrains.kotlin.gradle.plugin.SubpluginArtifact 17 | import org.jetbrains.kotlin.gradle.plugin.SubpluginOption 18 | import javax.inject.Inject 19 | /* Gradle entry point: registers the adOptimize extension and, for every Kotlin compilation (isApplicable is always true), passes its values to the compiler plugin as subplugin options. */ 20 | class ADOptimizeGradleSubPlugin @Inject internal constructor(private val registry: ToolingModelBuilderRegistry) : KotlinCompilerPluginSupportPlugin { 21 | override fun apply(target: Project) { 22 | target.extensions.create("adOptimize", ADOptimizeExtension::class.java) 23 | } 24 | 25 | override fun isApplicable(kotlinCompilation: KotlinCompilation<*>): Boolean = true 26 | 27 | override fun applyToCompilation(kotlinCompilation: KotlinCompilation<*>): Provider> { 28 | val project = kotlinCompilation.target.project 29 | 30 | val adoptimize = project.extensions.getByType(ADOptimizeExtension::class.java) 31 | /* The extension values are read inside the provider lambda, i.e. when the provider is queried rather than here. */ 32 | return project.provider { 33 | val options = mutableListOf() 34 | options += SubpluginOption("optimize", adoptimize.optimize) 35 | options += SubpluginOption("diffApi", adoptimize.diffApi) 36 | options += SubpluginOption("secondOrderOptimize", adoptimize.secondOrderOptimization) 37 | options += SubpluginOption("failOnADFail", adoptimize.failOnADFailFlag.toString()) 38 | options += SubpluginOption("reverseADFunction", adoptimize.reverseAD) 39 | options 40 | } 41 | } 42 | 43 | override fun getCompilerPluginId() = BuildConfig.ADOPTIMIZE_ID 44 | 45 | override fun getPluginArtifact(): SubpluginArtifact = 46 | SubpluginArtifact(groupId = BuildConfig.PLUGIN_GROUP, artifactId = BuildConfig.ADOPTIMIZE_ARTIFACT_ID, version = BuildConfig.PLUGIN_VERSION) 47 | } 48 |
-------------------------------------------------------------------------------- /adoptimize-integration-tests/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | repositories { 9 | kotlinBuildLocalRepo(project) 10 | } 11 | /* The dependency arrays below are read from the extra properties of the root project. */ 12 | val platformDependencies: Array by rootProject.extra 13 | val testDependencies: Array by rootProject.extra 14 | val testRuntimeDependencies: Array by rootProject.extra 15 | val kotlinStd by configurations.creating 16 | val coreDependencies: Array by rootProject.extra 17 | 18 | dependencies { 19 | val ktVersion: String by System.getProperties() 20 | val ADDiffKtVersion: String by System.getProperties() 21 | val intellijVersion: String by System.getProperties() 22 | 23 | kotlinStd("org.jetbrains.kotlin:kotlin-stdlib:$ktVersion") 24 | testImplementation(project(":adoptimize-cli-compiler-plugin", "shadowArtifact")) 25 | testImplementation("org.jetbrains.kotlin:kotlin-compiler-internal-test-framework:$ktVersion") 26 | testImplementation("org.jetbrains.kotlin:kotlin-scripting-compiler:$ktVersion") 27 | testImplementation(project(":config")) 28 | 29 | // necessary for populating the gradle cache so that the configurator can access that jar at runtime 30 | testImplementation("org.diffkt.adopt", "api", ADDiffKtVersion) 31 | 32 | testDependencies.forEach { 33 | testImplementation(it) 34 | } 35 | 36 | testRuntimeDependencies.forEach { 37 | testRuntimeOnly(it) 38 | } 39 | 40 | platformDependencies.forEach { 41 | testImplementation("com.jetbrains.intellij.platform:$it:$intellijVersion") 42 | } 43 | 44 | coreDependencies.forEach { artifactName -> 45 | testImplementation("kotlin.build:intellij-core:$intellijVersion") { 46 | artifact { 47 | name = artifactName 48 | type = "jar" 49 | extension = "jar"
50 | } 51 | } 52 | } 53 | } 54 | /* Jar of the compiled test classes, exposed through the testArtifact configuration. NOTE(review): the classifier is sources although the jar holds compiled output - confirm consumers before changing it. */ 55 | val testArtifact by configurations.creating 56 | val testJar = tasks.register("testJar") { 57 | val convention = project.convention.getPlugin() 58 | archiveClassifier.set("sources") 59 | from(convention.sourceSets.test.get().output) 60 | } 61 | 62 | artifacts { 63 | add(testArtifact.name, testJar) 64 | } 65 | 66 | projectTest { 67 | workingDir = rootDir 68 | useJUnitPlatform() 69 | } 70 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/kotlin/adoptimize/ADOptimizeIRTest.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize 9 | 10 | import org.junit.jupiter.api.Test 11 | /* IR text tests, each running one source file under homeDir; the path is relative to the test working directory, which this module sets to rootDir. NOTE(review): secondOrderDerivative and firstAndSecondOrderDerivative lack the test prefix used by the other methods. */ 12 | class ADOptimizeIRTest : AbstractADOptimizeIRTest() { 13 | val homeDir = "adoptimize-integration-tests/src/test/testData/ir" 14 | 15 | @Test 16 | fun testControlFlow() { 17 | runTest("$homeDir/control_flow.kt") 18 | } 19 | 20 | @Test 21 | fun testDerivative() { 22 | runTest("$homeDir/derivative.kt") 23 | } 24 | 25 | @Test 26 | fun secondOrderDerivative() { 27 | runTest("$homeDir/secondOrderDerivative.kt") 28 | } 29 | 30 | @Test 31 | fun firstAndSecondOrderDerivative() { 32 | runTest("$homeDir/firstAndSecondOrderDerivative.kt") 33 | } 34 | 35 | @Test 36 | fun testFloatFunction() { 37 | runTest("$homeDir/floatFunction.kt") 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/kotlin/adoptimize/AbstractADOptimizeBlackBoxTest.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc.
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package adoptimize 9 | 10 | import org.jetbrains.kotlin.test.builders.TestConfigurationBuilder 11 | import org.jetbrains.kotlin.test.model.TestModule 12 | import org.jetbrains.kotlin.test.runners.codegen.AbstractIrBlackBoxCodegenTest 13 | import org.jetbrains.kotlin.test.services.RuntimeClasspathProvider 14 | import org.jetbrains.kotlin.test.services.TestServices 15 | import java.io.File 16 | /** Base class for black-box codegen tests (test files define `box()`, which returns "OK" on success), run with [ADOptimizeConfigurator] and with the diffkt API jar and kotlin-reflect on the runtime classpath. */ 17 | abstract class AbstractADOptimizeBlackBoxTest : AbstractIrBlackBoxCodegenTest() { 18 | override fun configure(builder: TestConfigurationBuilder) { 19 | super.configure(builder) 20 | with(builder) { 21 | useCustomRuntimeClasspathProviders({ testServices: TestServices -> 22 | object : RuntimeClasspathProvider(testServices) { 23 | override fun runtimeClassPaths(module: TestModule): List<File> { 24 | return listOf(ADOptimizeConfigurator.diffApiJar, ADOptimizeConfigurator.kotlinReflect) 25 | } 26 | } 27 | }) 28 | // ADOptimizeConfigurator is defined elsewhere in this package; presumably it enables the ADOptimize plugin for each test module. 29 | useConfigurators({ testServices: TestServices -> 30 | ADOptimizeConfigurator(testServices) 31 | } 32 | ) 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/kotlin/adoptimize/AbstractADOptimizeIrTest.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | package adoptimize 9 | 10 | import org.jetbrains.kotlin.test.builders.TestConfigurationBuilder 11 | import org.jetbrains.kotlin.test.model.TestModule 12 | import org.jetbrains.kotlin.test.runners.ir.AbstractIrTextTest 13 | import org.jetbrains.kotlin.test.services.RuntimeClasspathProvider 14 | import org.jetbrains.kotlin.test.services.TestServices 15 | import java.io.File 16 | /** Base class for IR text tests, run with [ADOptimizeConfigurator] and with the diffkt API jar on the runtime classpath. */ 17 | abstract class AbstractADOptimizeIRTest : AbstractIrTextTest() { 18 | override fun configure(builder: TestConfigurationBuilder) { 19 | super.configure(builder) 20 | with(builder) { 21 | useCustomRuntimeClasspathProviders({ testServices: TestServices -> 22 | object : RuntimeClasspathProvider(testServices) { 23 | override fun runtimeClassPaths(testModule: TestModule): List<File> { 24 | return listOf(ADOptimizeConfigurator.diffApiJar) 25 | } 26 | } 27 | }) 28 | // ADOptimizeConfigurator is defined elsewhere in this package; presumably it enables the ADOptimize plugin for each test module. 29 | useConfigurators({ testServices: TestServices -> 30 | ADOptimizeConfigurator(testServices) 31 | } 32 | ) 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/activeArgument.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | /** Under test: `s` starts from the `dead` argument but is overwritten before use, so `dead` does not affect the result. */ 6 | @Optimize 7 | fun target(dead: FloatScalar, a: DScalar): DScalar { 8 | var s: DScalar = dead 9 | val i0 = a.pow(2f) 10 | s = i0 11 | val i1 = s + s 12 | val i2 = i1 * i1 13 | return i2 14 | } 15 | 16 | fun nonOptimal_target(dead: FloatScalar, a: DScalar): DScalar { 17 | var s: DScalar = dead 18 | val i0 = a.pow(2f) 19 | s = i0 20 | val i1 = s + s 21 | val i2 = i1 * i1 22 | return i2 23 | } 24 | 
/** Compares primal and reverse derivative of the optimized [target] with [nonOptimal_target] at x = 2.15; returns "OK" on success. */ 25 | fun box(): String { 26 | val x = FloatScalar(2.15f) 27 | val dead = FloatScalar(1.11f) 28 | val primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> target(dead, y) }) 29 | val primal = primal_derivative.first 30 | val derivative = primal_derivative.second 31 | 32 | val
expected_primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(dead, y) }) 33 | val expected_derivative = expected_primal_derivative.second 34 | val expected_primal = expected_primal_derivative.first 35 | val tol = 0.000001f 36 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 37 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 38 | } 39 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 40 | return "FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 41 | } 42 | return "OK" 43 | } 44 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/assignOperations.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar): DScalar { 8 | var i = 0 9 | i += 1 10 | i *= 2 11 | i /= 3 12 | i -= 4 13 | val i0 = a.pow(2f) 14 | val i1 = 2f * i0 15 | val i2 = i1 * i1 16 | return i2 17 | } 18 | 19 | fun nonOptimal_target(a: DScalar): DScalar { 20 | val i0 = a.pow(2f) 21 | val i1 = 2f * i0 22 | val i2 = i1 * i1 23 | return i2 24 | } 25 | /** Compares the optimized [target] (which also applies compound assignments to an Int local whose value is never used) with [nonOptimal_target] at x = 2.15; returns "OK" on success. */ 26 | fun box(): String { 27 | val x = FloatScalar(2.15f) 28 | val primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> target(y) }) 29 | val primal = primal_derivative.first 30 | val derivative = primal_derivative.second 31 | val expected_primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y) }) 32 | val expected_derivative = expected_primal_derivative.second 33 | val expected_primal = expected_primal_derivative.first 34 | 
val tol = 0.000001f 35 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 36 | return "FAIL: expected
${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 37 | } 38 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 39 | return "FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 40 | } 41 | return "OK" 42 | } 43 | --------------------------------------------------------------------------------
/adoptimize-integration-tests/src/test/testData/codegen/control_flow.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | /** Function under test: a `when` on the plain Float [c] that reassigns the active `var b`. */ 6 | @Optimize 7 | fun nested_if(a: DScalar, c: Float): DScalar { 8 | val i0 = a * a 9 | var b = i0 * i0 10 | when { 11 | c > 0.0 -> { 12 | b = b * i0 13 | } 14 | else -> { 15 | b = i0 * a 16 | } 17 | } 18 | val y = b * i0 19 | return y 20 | } 21 | /** Unoptimized reference copy of [nested_if]. */ 22 | fun manual_if(a: DScalar, c: Float): DScalar { 23 | val i0 = a * a 24 | var b = i0 * i0 25 | when { 26 | c > 0.0 -> { 27 | b = b * i0 28 | } 29 | else -> { 30 | b = i0 * a 31 | } 32 | } 33 | val y = b * i0 34 | return y 35 | } 36 | 37 | fun box(): String { 38 | val x = FloatScalar(1.15f) 39 | val tol = 0.000001f 40 | // c = 0.0f takes the `else` branch and c = 1.0f the `c > 0.0` branch, so both are checked. 41 | for (constant in listOf(0.0f, 1.0f)) { 42 | // nested_if is the @Optimize function under test; manual_if is the unoptimized reference. 43 | val (primal, derivative) = primalAndReverseDerivative(x, { c: DScalar -> nested_if(c, constant) }) 44 | val (expectedPrimal, expected_derivative) = primalAndReverseDerivative(x, { c: DScalar -> manual_if(c, constant) }) 45 | if (Math.abs(primal.basePrimal().value - expectedPrimal.basePrimal().value) > tol) { 46 | return "Primal FAIL: (c = $constant) expected ${expectedPrimal.basePrimal().value} but got ${primal.basePrimal().value}" 47 | } 48 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 49 | return "Derivative FAIL: (c = $constant) expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 50 | } 51 | } 52 | return "OK" 53 | } 54 | -------------------------------------------------------------------------------- 
/adoptimize-integration-tests/src/test/testData/codegen/control_flow_derivative.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 |
annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar, c: Double): DScalar { 8 | val i0 = a * a 9 | var b = i0 + i0 10 | when { 11 | c > 0.0 -> { 12 | val localVar = b + b 13 | b = localVar * b 14 | } 15 | else -> { 16 | b = b * b 17 | } 18 | } 19 | val y = b * i0 20 | return y 21 | } 22 | 23 | fun nonOptimal_target(a: DScalar, c: Double): DScalar { 24 | val i0 = a * a 25 | var b = i0 + i0 26 | when { 27 | c > 0.0 -> { 28 | val localVar = b + b 29 | b = localVar * b 30 | } 31 | else -> { 32 | b = b * b 33 | } 34 | } 35 | val y = b * i0 36 | return y 37 | } 38 | /** Compares the reverse derivative of the optimized [target] with [nonOptimal_target] at x = 1.15 for c = 0.5 and c = 0.0 (one per `when` branch); returns "OK" on success. */ 39 | fun box(): String { 40 | val x = FloatScalar(1.15f) 41 | val constant1 = 0.5 42 | val constant2 = 0.0 43 | for (element in listOf(constant1, constant2).withIndex()) { 44 | val c = element.value 45 | val derivative = primalAndReverseDerivative(x, { y: DScalar -> target(y, c) }).second 46 | val expected_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y, c) }).second 47 | val tol = 0.000001f 48 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 49 | return "FAIL: (index ${element.index}) expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 50 | } 51 | } 52 | 53 | return "OK" 54 | } 55 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/control_flow_nested_if.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | annotation class Optimize 4 | 5 | @Optimize 6 | fun target(a: DScalar, c: Double): DScalar { 7 | val i0 = a * a 8 | var b = i0 * i0 9 | var g = i0 + i0 10 | when { 11 | c > 2.25 -> { 12 | b = b * i0 13 | } 14 | else -> { 15 | when { 16 | c > 0.5 -> { 17 | val localVar = g + g 18 | g = localVar + g 19 | } 20 | else -> { 21 | g = g 
* g 22 | } 23 | } 24 | b = g * b 25 | } 26 | } 27 | val y = b * g 28 | return y 29
| } 30 | 31 | fun nonOptimal_target(a: DScalar, c: Double): DScalar { 32 | val i0 = a * a 33 | var b = i0 * i0 34 | var g = i0 + i0 35 | when { 36 | c > 2.25 -> { 37 | b = b * i0 38 | } 39 | else -> { 40 | when { 41 | c > 0.5 -> { 42 | val localVar = g + g 43 | g = localVar + g 44 | } 45 | else -> { 46 | g = g * g 47 | } 48 | } 49 | b = g * b 50 | } 51 | } 52 | val y = b * g 53 | return y 54 | } 55 | /** Compares primal and reverse derivative of the optimized [target] with [nonOptimal_target] at x = 1.15 for c = 2.5, 0.25 and 1.5 (one per path through the nested `when`); returns "OK" on success. */ 56 | fun box(): String { 57 | val x = FloatScalar(1.15f) 58 | val constant1 = 2.5 59 | val constant2 = 0.25 60 | val constant3 = 1.5 61 | for (element in listOf(constant1, constant2, constant3).withIndex()) { 62 | val c = element.value 63 | val derivativePair = primalAndReverseDerivative(x, { y: DScalar -> target(y, c) }) 64 | val expected_derivativePair = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y, c) }) 65 | val tol = 0.000001 66 | if (Math.abs(derivativePair.second.basePrimal().value - expected_derivativePair.second.basePrimal().value) > tol) { 67 | return "DERIVATIVE FAIL: (index ${element.index}) expected ${expected_derivativePair.second.basePrimal().value} but got ${derivativePair.second.basePrimal().value}" 68 | } 69 | if (Math.abs(derivativePair.first.basePrimal().value - expected_derivativePair.first.basePrimal().value) > tol) { 70 | return "PRIMAL FAIL: (index ${element.index}) expected ${expected_derivativePair.first.basePrimal().value} but got ${derivativePair.first.basePrimal().value}" 71 | } 72 | } 73 | 74 | return "OK" 75 | } 76 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/diffkt.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: 
DScalar): DScalar = (a.pow(2f) + a.pow(2f)) * (a.pow(2f) +
a.pow(2f)) 10 | /** Compares primal and reverse derivative of the optimized expression-bodied [target] with [nonOptimal_target] at x = 2.15; returns "OK" on success. */ 11 | fun box(): String { 12 | val x = FloatScalar(2.15f) 13 | val primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> target(y) }) 14 | val primal = primal_derivative.first 15 | val derivative = primal_derivative.second 16 | val expected_primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y) }) 17 | val expected_derivative = expected_primal_derivative.second 18 | val expected_primal = expected_primal_derivative.first 19 | val tol = 0.000001f 20 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 21 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 22 | } 23 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 24 | return "FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 25 | } 26 | return "OK" 27 | } 28 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/elseLower.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar): DScalar { 8 | val b = cos(a) 9 | var x = a 10 | if (b >= 0f) { 11 | x = a * b 12 | } 13 | return x 14 | } 15 | 16 | fun nonOptimal_target(a: DScalar): DScalar { 17 | val b = cos(a) 18 | var x = a 19 | if (b >= 0f) { 20 | x = a * b 21 | } 22 | return x 23 | } 24 | /** Compares the optimized [target] (an `if` without `else`) with [nonOptimal_target] for x = 2.15 (cos < 0, branch skipped) and x = 0.5 (cos >= 0, branch taken); returns "OK" on success. 
*/ 25 | fun box(): String { 26 | val x1 = FloatScalar(2.15f) 27 | val x2 = FloatScalar(0.5f) 28 | for (input in listOf(x1, x2)) { 29 | val primal_derivative = primalAndReverseDerivative(input, { y: DScalar -> target(y) }) 30 | val primal = primal_derivative.first 31 | val derivative = primal_derivative.second 32 | val expected_primal_derivative = primalAndReverseDerivative(input, { y: DScalar -> nonOptimal_target(y) }) 33 |
val expected_derivative = expected_primal_derivative.second 34 | val expected_primal = expected_primal_derivative.first 35 | val tol = 0.000001f 36 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 37 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 38 | } 39 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 40 | return "FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 41 | } 42 | } 43 | 44 | return "OK" 45 | } 46 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/exp.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | annotation class SecondOrderOptimize 6 | 7 | @SecondOrderOptimize 8 | @Optimize 9 | fun target(a: DScalar): DScalar { 10 | val x = if (a > 0f) { 11 | val w = exp(a) 12 | w 13 | } else { 14 | val z = -a 15 | z 16 | } 17 | return x 18 | } 19 | /** Unoptimized reference: must compute the same function as [target] on both branches (exp(a) when a > 0, otherwise -a). */ 20 | fun nonOptimal_target(a: DScalar): DScalar { 21 | val x = if (a > 0f) { exp(a) } else { -a } 22 | return x 23 | } 24 | /** Compares the optimized [target] with [nonOptimal_target] at x = 2.15, which only exercises the `a > 0f` branch; returns "OK" on success. 
*/ 25 | fun box(): String { 26 | val x = FloatScalar(2.15f) 27 | val primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> target(y) }) 28 | val primal = primal_derivative.first 29 | val derivative = primal_derivative.second 30 | val expected_primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y) }) 31 | val expected_derivative = expected_primal_derivative.second 32 | val expected_primal = expected_primal_derivative.first 33 | val tol = 0.000001f 34 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 35 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 36 | } 37 | if
(Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 38 | return "FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 39 | } 40 | return "OK" 41 | } 42 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/floatFunction.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | annotation class ReverseAD 6 | 7 | @Optimize 8 | fun foo(a: Float): Float { 9 | val i0 = a + a 10 | val x: Float 11 | if (i0 < 2f) { 12 | x = a * a 13 | } else { 14 | x = a 15 | } 16 | return i0 * x 17 | } 18 | 19 | fun nonOptimal_foo(a: DScalar): DScalar { 20 | val i0 = a + a 21 | val x: DScalar 22 | if (i0 < 2f) { 23 | x = a * a 24 | } else { 25 | x = a 26 | } 27 | return i0 * x 28 | } 29 | /** Body is only a `TODO()` placeholder; presumably the ADOptimize plugin supplies the implementation for this @ReverseAD function — confirm. */ 30 | @ReverseAD 31 | fun jacobian_transposed_vector_product(x: Float, f: (Float) -> Float): Float { 32 | TODO() 33 | } 34 | /** Compares the derivative of the Float-typed [foo] obtained through [jacobian_transposed_vector_product] with the diffkt reverse derivative of [nonOptimal_foo] at 2.15; returns "OK" on success. */ 35 | fun box(): String { 36 | val x = FloatScalar(2.15f) 37 | val derivative: Float = jacobian_transposed_vector_product(x.basePrimal().value, ::foo) 38 | val expected_primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_foo(y) }) 39 | val expected_derivative = expected_primal_derivative.second 40 | val expected_primal = expected_primal_derivative.first 41 | val tol = 0.000001f 42 | if (Math.abs(derivative - expected_derivative.basePrimal().value) > tol) { 43 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got $derivative" 44 | } 45 | return "OK" 46 | } 47 | 
-------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/forwardsUnbox.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize
7 | fun target(dead: FloatScalar, a: DScalar): DScalar { 8 | val c = 2f 9 | val i0 = a.pow(c) 10 | val i1 = i0 + i0 11 | val i2 = i1 * i1 12 | return i2 13 | } 14 | 15 | fun nonOptimal_target(dead: FloatScalar, a: DScalar): DScalar { 16 | val c = 2f 17 | val i0 = a.pow(c) 18 | val i1 = i0 + i0 19 | val i2 = i1 * i1 20 | return i2 21 | } 22 | /** Calls the optimized [target] directly (no derivative is taken) and compares its value with [nonOptimal_target]; returns "OK" on success. */ 23 | fun box(): String { 24 | val y = FloatScalar(2.15f) 25 | val dead = FloatScalar(3.0f) 26 | val result = target(dead, y) 27 | val expected = nonOptimal_target(dead, y) 28 | val tol = 0.000001f 29 | if (Math.abs(result.basePrimal().value - expected.basePrimal().value) > tol) { 30 | return "FAIL: expected ${expected.basePrimal().value} but got ${result.basePrimal().value}" 31 | } 32 | return "OK" 33 | } 34 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/getValInitializer.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar): DScalar { 8 | val x = a * a 9 | var s = a 10 | s = x 11 | val z1 = s * s 12 | val z2 = z1 * x 13 | return z2 14 | } 15 | 16 | fun vanilla_target(a: DScalar): DScalar { 17 | val x = a * a 18 | var s = a 19 | s = x 20 | val z1 = s * s 21 | val z2 = z1 * x 22 | return z2 23 | } 24 | /** Compares primal and reverse derivative of the optimized [target] (a `var` initialised from the parameter, then reassigned) with [vanilla_target] at x = 1.15; returns "OK" on success. 
*/ 25 | fun box(): String { 26 | val x = FloatScalar(1.15f) 27 | val derivativePair = primalAndReverseDerivative(x, { y: DScalar -> target(y) }) 28 | val expected_derivativePair = primalAndReverseDerivative(x, { y: DScalar -> vanilla_target(y) }) 29 | val tol = 0.000001f 30 | 31 | if (Math.abs(derivativePair.first.basePrimal().value - expected_derivativePair.first.basePrimal().value) > tol) { 32 | return "PRIMAL FAIL: expected ${expected_derivativePair.first.basePrimal().value} but got ${derivativePair.first.basePrimal().value}" 33 | } 34 | if (Math.abs(derivativePair.second.basePrimal().value -
expected_derivativePair.second.basePrimal().value) > tol) { 35 | return "DERIVATIVE FAIL: expected ${expected_derivativePair.second.basePrimal().value} but got ${derivativePair.second.basePrimal().value}" 36 | } 37 | return "OK" 38 | } 39 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/getterNoExplicitUnbox.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | /** Inputs read by [target]: a DScalar returned by the `singleNode()` function and a Float exponent taken from [listOfSomething]. */ 6 | class Nodes(val singleNode: DScalar, val listOfSomething: List<Float>) { 7 | fun singleNode() = singleNode 8 | } 9 | 10 | @Optimize 11 | fun target(a: DScalar, nodes: Nodes): DScalar { 12 | val x = nodes.listOfSomething[0] 13 | val operand = nodes.singleNode() 14 | val b = operand.pow(x) 15 | val c = a * b 16 | return c 17 | } 18 | 19 | fun nonOptimal_target(a: DScalar, nodes: Nodes): DScalar { 20 | val x = nodes.listOfSomething[0] 21 | val operand = nodes.singleNode() 22 | val b = operand.pow(x) 23 | val c = a * b 24 | return c 25 | } 26 | /** Compares primal and reverse derivative of the optimized [target] with [nonOptimal_target] at x = 1.15; returns "OK" on success. */ 27 | fun box(): String { 28 | val nodes = Nodes(FloatScalar(0.5f), listOf(2f)) 29 | val x = FloatScalar(1.15f) 30 | 31 | val derivativePair = primalAndReverseDerivative(x, { y: DScalar -> target(y, nodes) }) 32 | val expected_derivativePair = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y, nodes) }) 33 | val tol = 0.000001f 34 | 35 | if (Math.abs(derivativePair.first.basePrimal().value - expected_derivativePair.first.basePrimal().value) > tol) { 36 | return "PRIMAL FAIL: expected ${expected_derivativePair.first.basePrimal().value} but got ${derivativePair.first.basePrimal().value}" 37 | } 38 | if 
(Math.abs(derivativePair.second.basePrimal().value - expected_derivativePair.second.basePrimal().value) > tol) { 39 | return "DERIVATIVE FAIL: expected ${expected_derivativePair.second.basePrimal().value} but got ${derivativePair.second.basePrimal().value}" 40 | } 41 |
42 | return "OK" 43 | } 44 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/if_statement.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun if_statement(a: DScalar, c: Float): DScalar { 8 | val i0 = a * a 9 | var b = i0 * i0 10 | if (c > 2.0f) { 11 | b = b * i0 12 | } else if (c == 1.0f) { 13 | b = b * b 14 | } else { 15 | b = b * a 16 | } 17 | return b * i0 18 | } 19 | 20 | fun nonOptimal_if_statement(a: DScalar, c: Float): DScalar { 21 | val i0 = a * a 22 | var b = i0 * i0 23 | if (c > 2.0f) { 24 | b = b * i0 25 | } else if (c == 1.0f) { 26 | b = b * b 27 | } else { 28 | b = b * a 29 | } 30 | return b * i0 31 | } 32 | /** Compares the optimized [if_statement] with [nonOptimal_if_statement] for c = 1.0, 3.0 and -1.0 (one per branch of the if/else-if/else chain); returns "OK" on success. */ 33 | fun box(): String { 34 | val x = FloatScalar(1.15f) 35 | val constant1 = 1.0f 36 | val constant2 = 3.0f 37 | val constant3 = -1f 38 | for (element in listOf(constant1, constant2, constant3).withIndex()) { 39 | val c = element.value 40 | val derivativePair = primalAndReverseDerivative(x, { y: DScalar -> if_statement(y, c) }) 41 | val expected_derivativePair = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_if_statement(y, c) }) 42 | val tol = 0.000001f 43 | 44 | if (Math.abs(derivativePair.first.basePrimal().value - expected_derivativePair.first.basePrimal().value) > tol) { 45 | return "PRIMAL FAIL: (index ${element.index}) expected ${expected_derivativePair.first.basePrimal().value} but got ${derivativePair.first.basePrimal().value}" 46 | } 47 | if (Math.abs(derivativePair.second.basePrimal().value - expected_derivativePair.second.basePrimal().value) > tol) { 48 | return "DERIVATIVE FAIL: (index ${element.index}) expected ${expected_derivativePair.second.basePrimal().value} but got 
${derivativePair.second.basePrimal().value}" 49 | } 50 | } 51 | 52 | return "OK" 53 | } 54 | --------------------------------------------------------------------------------
/adoptimize-integration-tests/src/test/testData/codegen/implicitParameter.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun DScalar.target(): DScalar { 8 | return this * this 9 | } 10 | 11 | fun nonOptimal_target(a: DScalar): DScalar { 12 | return a * a 13 | } 14 | /** Compares the optimized extension [target] (the receiver is the active input) with [nonOptimal_target] at x = 2.15; returns "OK" on success. */ 15 | fun box(): String { 16 | val x = FloatScalar(2.15f) 17 | val primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> y.target() }) 18 | val primal = primal_derivative.first 19 | val derivative = primal_derivative.second 20 | val expected_primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y) }) 21 | val expected_derivative = expected_primal_derivative.second 22 | val expected_primal = expected_primal_derivative.first 23 | val tol = 0.000001f 24 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 25 | return "DERIVATIVE FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 26 | } 27 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 28 | return "PRIMAL FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 29 | } 30 | return "OK" 31 | } 32 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/initialization.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | annotation class Optimize 4 | 5 | @Optimize 6 | fun target(dead: FloatScalar, a: DScalar): DScalar { 7 | val c1 = 2f 8 | var c2 = c1 9 | val c = c2 10 | val i0 = a.pow(c) 11 | val i1 = i0 + i0 12 | val i2 = i1 * i1 13 | return i2 14 | } 15 | 16 | fun 
nonOptimal_target(dead:
FloatScalar, a: DScalar): DScalar { 17 | val i0 = a.pow(2f) 18 | val i1 = i0 + i0 19 | val i2 = i1 * i1 20 | return i2 21 | } 22 | /** Compares the reverse derivative of the optimized [target] (its exponent reaches `pow` through a chain of local copies) with [nonOptimal_target] at x = 2.15; returns "OK" on success. */ 23 | fun box(): String { 24 | val x = FloatScalar(2.15f) 25 | val dead = FloatScalar(3.0f) 26 | val derivative = primalAndReverseDerivative(x, { y: DScalar -> target(dead, y) }).second 27 | val expected_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(dead, y) }).second 28 | val tol = 0.000001 29 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 30 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 31 | } 32 | return "OK" 33 | } 34 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/logProb.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | annotation class Optimize 4 | 5 | class Normal { 6 | companion object { 7 | val lnSqrt2Pi = kotlin.math.ln(kotlin.math.sqrt(2.0 * kotlin.math.PI)).toFloat() 8 | } 9 | } 10 | /** Normal log-density of [value] for [loc] and [scale], written as single-assignment steps. Review note: i7 = i6 - (ln(scale) - lnSqrt2Pi) adds lnSqrt2Pi, whereas the textbook log-density subtracts it; the derivative w.r.t. value is unaffected — confirm the sign is intended. */ 11 | @Optimize 12 | fun logProbOf(value: DScalar, loc: FloatScalar, scale: FloatScalar): DScalar { 13 | val twoFloat = 2f 14 | val normal = Normal 15 | val variance = scale.pow(twoFloat) 16 | val constant = normal.lnSqrt2Pi 17 | val i0 = ln(scale) 18 | val i1 = i0 - constant 19 | val i2 = value - loc 20 | val i3 = i2.pow(twoFloat) 21 | val i4 = -i3 22 | val i5 = twoFloat * variance 23 | val i6 = i4 / i5 24 | val 
i7 = i6 - i1 25 | return i7 26 | } 27 | /** Unoptimized reference copy of [logProbOf]. */ 28 | fun vanillaLogProbOf(value: DScalar, loc: FloatScalar, scale: FloatScalar): DScalar { 29 | val twoFloat = 2f 30 | val normal = Normal 31 | val variance = scale.pow(twoFloat) 32 | val constant = normal.lnSqrt2Pi 33 | val i0 = ln(scale) 34 | val i1 = i0 - constant 35 | val i2 = value - loc 36 | val i3 = i2.pow(twoFloat) 37 | val i4 = -i3 38 | val i5 = twoFloat * variance 39 | val i6 = i4 / i5
40 | val i7 = i6 - i1 41 | return i7 42 | } 43 | /** Compares primal and reverse derivative (w.r.t. value) of the optimized [logProbOf] with [vanillaLogProbOf]; returns "OK" on success. */ 44 | fun box(): String { 45 | val value = FloatScalar(1.715f) 46 | val loc = FloatScalar(3.0f) 47 | val scale = FloatScalar(2.5f) 48 | val (primal, derivative) = primalAndReverseDerivative(value, { y: DScalar -> logProbOf(y, loc, scale) }) 49 | val (expected_primal, expected_derivative) = primalAndReverseDerivative(value, { y: DScalar -> vanillaLogProbOf(y, loc, scale) }) 50 | val tol = 0.000001f 51 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 52 | return "PRIMAL FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 53 | } 54 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 55 | return "DERIVATIVE FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}." 56 | } 57 | return "OK" 58 | } 59 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/multipleOutputs.kt: -------------------------------------------------------------------------------- 1 | // WITH_RUNTIME 2 | package demo 3 | import org.diffkt.* 4 | 5 | annotation class Optimize 6 | 7 | @Optimize 8 | fun target(a: DScalar): DScalar { 9 | val i2 = a * a 10 | return i2 11 | } 12 | 13 | fun targetVanilla(a: DScalar): DScalar { 14 | val i2 = a * a 15 | return i2 16 | } 17 | 18 | fun multipleOutputsUsesOptimized(a: DScalar): Pair<DScalar, DScalar> { 19 | val i0 = target(a) 20 | val i3 = a * cos(a) 21 | return Pair(i0, i3) 22 | } 23 | 24 | fun multipleOutputsUsesVanilla(a: DScalar): Pair<DScalar, DScalar> { 25 | val i0 = targetVanilla(a) 26 | val i3 = a * cos(a) 27 | return Pair(i0, i3) 28 | } 29 | /** Signature of the derivative-extraction callback handed to `extractDerivative` for the optimized pipeline below. 
*/ 30 | typealias Extractor = (input: DTensor, output: DTensor) -> DTensor 31 | fun box(): String { // checks primals and derivatives of both outputs of the optimized pipeline against the vanilla one 32 | val x = FloatScalar(0.15f) 33 | val primal_derivative = primalAndReverseDerivative( 34 | x = x, 35 | f = ::multipleOutputsUsesOptimized, 36 | extractDerivative
= { 37 | input: DScalar, 38 | output: Pair<DScalar, DScalar>, 39 | extractDerivatives: Extractor -> 40 | val dxdy1 = extractDerivatives(input, output.first) 41 | val dxdy2 = extractDerivatives(input, output.second) 42 | Pair(dxdy1, dxdy2) 43 | } 44 | ) 45 | val primal = primal_derivative.first 46 | val derivative = primal_derivative.second 47 | val expected_primal_derivative = primalAndReverseDerivative( 48 | x = x, 49 | f = ::multipleOutputsUsesVanilla, 50 | extractDerivative = { 51 | input: DScalar, output: Pair<DScalar, DScalar>, extractDerivatives: (input: DTensor, output: DTensor) -> DTensor -> 52 | val dxdy1 = extractDerivatives(input, output.first) 53 | val dxdy2 = extractDerivatives(input, output.second) 54 | Pair(dxdy1, dxdy2) 55 | } 56 | ) 57 | val expected_derivative = expected_primal_derivative.second 58 | val expected_primal = expected_primal_derivative.first 59 | val tol = 0.000001f 60 | for (output in listOf(Pair(primal.first, expected_primal.first), Pair(primal.second, expected_primal.second))) { 61 | if (Math.abs(output.first.basePrimal().value - output.second.basePrimal().value) > tol) { 62 | return "PRIMAL FAIL: expected ${output.second.basePrimal().value} but got ${output.first.basePrimal().value}" 63 | } 64 | } 65 | 66 | for (output in listOf(Pair(derivative.first as DScalar, expected_derivative.first as DScalar), Pair(derivative.second as DScalar, expected_derivative.second as DScalar))) { 67 | if (Math.abs(output.first.basePrimal().value - output.second.basePrimal().value) > tol) { 68 | return "DERIVATIVE FAIL: expected ${output.second.basePrimal().value} but got ${output.first.basePrimal().value}" 69 | } 70 | } 71 | return "OK" 72 | } 73 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/multipleOutputsControlFlow.kt: -------------------------------------------------------------------------------- 1 | // WITH_RUNTIME 2 | package demo 3 | import org.diffkt.* 4 | 5 | annotation class Optimize 6 |
7 | @Optimize 8 | fun target(a: DScalar): DScalar { 9 | var i = 0 10 | var s = a * a 11 | while (i < 3) { 12 | s = s * (s + a) 13 | i = i + 1 14 | } 15 | return s 16 | } 17 | // targetVanilla: same body as target but without @Optimize; used as the reference implementation 18 | fun targetVanilla(a: DScalar): DScalar { 19 | var i = 0 20 | var s = a * a 21 | while (i < 3) { 22 | s = s * (s + a) 23 | i = i + 1 24 | } 25 | return s 26 | } 27 | 28 | fun multipleOutputsUsesOptimized(a: DScalar): Pair { 29 | val i0 = target(a) 30 | val i3 = a * cos(a) 31 | return Pair(i0, i3) 32 | } 33 | 34 | fun multipleOutputsUsesVanilla(a: DScalar): Pair { 35 | val i0 = targetVanilla(a) 36 | val i3 = a * cos(a) 37 | return Pair(i0, i3) 38 | } 39 | // Extractor: type of the extractDerivatives callback received by the extractDerivative lambdas below 40 | typealias Extractor = (input: DTensor, output: DTensor) -> DTensor 41 | fun box(): String { // returns "OK" when the optimized and vanilla primals and derivatives agree within tol 42 | val x = FloatScalar(0.15f) 43 | val primal_derivative = primalAndReverseDerivative( 44 | x = x, 45 | f = ::multipleOutputsUsesOptimized, 46 | extractDerivative = { 47 | input: DScalar, 48 | output: Pair, 49 | extractDerivatives: Extractor -> 50 | val dxdy1 = extractDerivatives(input, output.first) 51 | val dxdy2 = extractDerivatives(input, output.second) 52 | Pair(dxdy1, dxdy2) 53 | } 54 | ) 55 | val primal = primal_derivative.first 56 | val derivative = primal_derivative.second 57 | val expected_primal_derivative = primalAndReverseDerivative( 58 | x = x, 59 | f = ::multipleOutputsUsesVanilla, 60 | extractDerivative = { 61 | input: DScalar, output: Pair, extractDerivatives: (input: DTensor, output: DTensor) -> DTensor -> 62 | val dxdy1 = extractDerivatives(input, output.first) 63 | val dxdy2 = extractDerivatives(input, output.second) 64 | Pair(dxdy1, dxdy2) 65 | } 66 | ) 67 | val expected_derivative = expected_primal_derivative.second 68 | val expected_primal = expected_primal_derivative.first 69 | val tol = 0.000001f 70 | for (output in listOf(Pair(primal.first, expected_primal.first), 
Pair(primal.second, expected_primal.second))) { // output.first = optimized (actual), output.second = vanilla (expected) 71 | if (Math.abs(output.first.basePrimal().value - output.second.basePrimal().value) > tol) { 72 | return
"PRIMAL FAIL: expected ${output.second.basePrimal().value} but got ${output.first.basePrimal().value}" 73 | } 74 | } 75 | 76 | for (output in listOf(Pair(derivative.first as DScalar, expected_derivative.first as DScalar), Pair(derivative.second as DScalar, expected_derivative.second as DScalar))) { // same (actual, expected) pairing as the primal check 77 | if (Math.abs(output.first.basePrimal().value - output.second.basePrimal().value) > tol) { 78 | return "DERIVATIVE FAIL: expected ${output.second.basePrimal().value} but got ${output.first.basePrimal().value}" 79 | } 80 | } 81 | return "OK" 82 | } 83 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/nestedWhenVariable.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | annotation class Optimize 4 | annotation class SecondOrderOptimize 5 | 6 | @Optimize 7 | @SecondOrderOptimize 8 | fun target(a: DScalar, c: Double): DScalar { 9 | val u = FloatScalar.ZERO 10 | val g1 = if (c > 0) { 11 | if (c > 0.5) { 12 | val b = if (c > 0.56) { 13 | a 14 | } else { 15 | a + a 16 | } 17 | b 18 | } else { 19 | if (c >= 0.1) { 20 | a + a 21 | } else { 22 | -a 23 | } 24 | } 25 | } else u 26 | return g1 27 | } 28 | 29 | fun nonOptimal_target(a: DScalar, c: Double): DScalar { 30 | val u = FloatScalar.ZERO 31 | val g1 = if (c > 0) { 32 | if (c > 0.5) { 33 | val b = if (c > 0.56) { 34 | a 35 | } else { 36 | a + a 37 | } 38 | b 39 | } else { 40 | if (c >= 0.1) { 41 | a + a 42 | } else { 43 | -a 44 | } 45 | } 46 | } else u 47 | return g1 48 | } 49 | 50 | fun box(): String { 51 | val x = FloatScalar(1.15f) 52 | val constant1 = 0.01 53 | val constant4 = 0.2 54 | val constant2 = 1.0 55 | val constant3 = -0.5 56 | for (element in listOf(constant1, constant2, constant3, constant4, 0.57).withIndex()) { 57 | val c = element.value 58 | val derivativePair = primalAndReverseDerivative(x, { y: DScalar -> target(y, c) }) 59 | val
expected_derivativePair = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y, c) }) 60 | val tol = 0.000001 61 | if (Math.abs(derivativePair.second.basePrimal().value - expected_derivativePair.second.basePrimal().value) > tol) { 62 | return "DERIVATIVE FAIL: (index ${element.index}) expected ${expected_derivativePair.second.basePrimal().value} but got ${derivativePair.second.basePrimal().value}" 63 | } 64 | if (Math.abs(derivativePair.first.basePrimal().value - expected_derivativePair.first.basePrimal().value) > tol) { 65 | return "PRIMAL FAIL: (index ${element.index}) expected ${expected_derivativePair.first.basePrimal().value} but got ${derivativePair.first.basePrimal().value}" 66 | } 67 | } 68 | 69 | return "OK" 70 | } 71 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/nonActiveArgument.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(notActive: FloatScalar, a: DScalar): DScalar { 8 | val constant = 5.0f 9 | val i2 = constant * a 10 | val i3 = i2 * notActive 11 | val i4 = cos(i3) 12 | return i4 13 | } 14 | 15 | fun nonOptimal_target(notActive: FloatScalar, a: DScalar): DScalar { 16 | val constant = 5.0f 17 | val i2 = constant * a 18 | val i3 = i2 * notActive 19 | val i4 = cos(i3) 20 | return i4 21 | } 22 | 23 | fun box(): String { 24 | val x = FloatScalar(2.15f) 25 | val dead = FloatScalar(3.0f) 26 | val derivative = primalAndReverseDerivative(x, { y: DScalar -> target(dead, y) }).second 27 | val expected_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(dead, y) }).second 28 | val tol = 0.000001 29 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 30 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got 
${derivative.basePrimal().value}" 31 | } 32 | return "OK" 33 | } 34 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/nullArgument.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(dead: FloatScalar, a: DScalar): DScalar { 8 | val constant = 5.0f 9 | val i2 = constant * a 10 | return i2 11 | } 12 | 13 | fun nonOptimal_target(dead: FloatScalar, a: DScalar): DScalar { 14 | val constant = 5.0f 15 | val i2 = constant * a 16 | return i2 17 | } 18 | 19 | fun box(): String { 20 | val x = FloatScalar(2.15f) 21 | val dead = FloatScalar(3.0f) 22 | val derivative = primalAndReverseDerivative(x, { y: DScalar -> target(dead, y) }).second 23 | val expected_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(dead, y) }).second 24 | val tol = 0.000001 25 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 26 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 27 | } 28 | return "OK" 29 | } 30 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/parameterWithNonParameterizedTypeArg.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | import org.diffkt.adOptimize.ToUnboxedFunction 4 | annotation class Optimize 5 | 6 | interface NodeType { 7 | val x: DScalar 8 | } 9 | class NodeTypeImpl(y: Float) : NodeType { 10 | override val x: DScalar = FloatScalar(y) 11 | } 12 | class Nodes(val children: List) 13 | 14 | @ToUnboxedFunction("demo.getChildToFloat") 15 | fun List.getChild(index: Int): DScalar = this[index].x 16 | fun List.getChildToFloat(index: Int): Float = 
this[index].x.basePrimal().value 17 | 18 | @Optimize 19 | fun target(a: DScalar, nodes: Nodes): DScalar { 20 | var i = 0 21 | var s = a 22 | while (i < nodes.children.size) { 23 | val i0 = nodes.children.getChild(i) 24 | val z = s * i0 25 | s = z 26 | i = i + 1 27 | } 28 | return s 29 | } 30 | 31 | fun nonOptimal_target(a: DScalar, nodes: Nodes): DScalar { 32 | var i = 0 33 | var s = a 34 | while (i < nodes.children.size) { 35 | val i0 = nodes.children.getChild(i) 36 | val z = s * i0 37 | s = z 38 | i = i + 1 39 | } 40 | return s 41 | } 42 | 43 | fun box(): String { 44 | val nodes = Nodes((0 until 3).map { NodeTypeImpl(it.toFloat() / 100f) }) 45 | val x = FloatScalar(1.15f) 46 | 47 | val derivativePair = primalAndReverseDerivative(x, { y: DScalar -> target(y, nodes) }) 48 | val expected_derivativePair = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y, nodes) }) 49 | val tol = 0.000001f 50 | 51 | if (Math.abs(derivativePair.first.basePrimal().value - expected_derivativePair.first.basePrimal().value) > tol) { 52 | return "PRIMAL FAIL: expected ${expected_derivativePair.first.basePrimal().value} but got ${derivativePair.first.basePrimal().value}" 53 | } 54 | if (Math.abs(derivativePair.second.basePrimal().value - expected_derivativePair.second.basePrimal().value) > tol) { 55 | return "DERIVATIVE FAIL:expected ${expected_derivativePair.second.basePrimal().value} but got ${derivativePair.second.basePrimal().value}" 56 | } 57 | 58 | return "OK" 59 | } 60 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardControlFlow.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | annotation class SecondOrderOptimize 6 | 7 | @Optimize 8 | @SecondOrderOptimize 9 | fun target(a: DScalar, L: Int): DScalar { 10 | var i = 0 11 | var s = a 12 | while (i < L) { 13 
| val z = s * s 14 | s = z 15 | i = i + 1 16 | } 17 | return s 18 | } 19 | 20 | fun nonOptimal_target(a: DScalar, L: Int): DScalar { 21 | var i = 0 22 | var s = a 23 | while (i < L) { 24 | val z = s * s 25 | s = z 26 | i = i + 1 27 | } 28 | return s 29 | } 30 | 31 | fun box(): String { 32 | val x = FloatScalar(1.15f) 33 | val constant1 = 4 34 | val constant2 = 7 35 | for (element in listOf(constant1, constant2).withIndex()) { 36 | val c = element.value 37 | val primal_derivative = primalAndForwardDerivative( 38 | x = x, 39 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> target(xx, c) }).second } 40 | ) 41 | val firstOrderDerivative = primal_derivative.first.basePrimal().value 42 | val secondOrderDerivative = primal_derivative.second.basePrimal().value 43 | val primal_derivative_expectation = primalAndForwardDerivative( 44 | x = x, 45 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> nonOptimal_target(xx, c) }).second } 46 | ) 47 | val expectedSecondOrderDerivative = primal_derivative_expectation.second.basePrimal().value 48 | val expectedFirstOrderDerivative = primal_derivative_expectation.first.basePrimal().value 49 | val tol = 0.000001f 50 | 51 | if (Math.abs(firstOrderDerivative - expectedFirstOrderDerivative) > tol) { 52 | return "FIRST ORDER DERIVATIVE FAIL: (index ${element.index}) expected $expectedFirstOrderDerivative but got $firstOrderDerivative" 53 | } 54 | if (Math.abs(secondOrderDerivative - expectedSecondOrderDerivative) > tol) { 55 | return "SECOND ORDER DERIVATIVE FAIL: (index ${element.index}) expected $expectedSecondOrderDerivative but got $secondOrderDerivative" 56 | } 57 | } 58 | 59 | return "OK" 60 | } 61 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardControlFlowNested.kt: -------------------------------------------------------------------------------- 1 | 2 | package demo 3 | import org.diffkt.* 4 | 
annotation class Optimize 5 | annotation class SecondOrderOptimize 6 | 7 | @SecondOrderOptimize 8 | fun target(a: DScalar, L: Int): DScalar { 9 | var i = 0 10 | var s = a 11 | while (i < L) { 12 | val j0 = i % 3 13 | val j1 = i % 2 14 | when { 15 | j0 == 0 -> { 16 | when { 17 | j1 == 0 -> { 18 | s = s * s 19 | } 20 | else -> { 21 | s = s * s * s 22 | } 23 | } 24 | } 25 | else -> { 26 | s = s * a 27 | } 28 | } 29 | i++ 30 | } 31 | return s 32 | } 33 | 34 | fun nonOptimal_target(a: DScalar, L: Int): DScalar { 35 | var i = 0 36 | var s = a 37 | while (i < L) { 38 | val j0 = i % 3 39 | val j1 = i % 2 40 | when { 41 | j0 == 0 -> { 42 | when { 43 | j1 == 0 -> { 44 | s = s * s 45 | } 46 | else -> { 47 | s = s * s * s 48 | } 49 | } 50 | } 51 | else -> { 52 | s = s * a 53 | } 54 | } 55 | i++ 56 | } 57 | return s 58 | } 59 | 60 | fun box(): String { 61 | val x = FloatScalar(1.15f) 62 | val constant1 = 4 63 | val constant2 = 7 64 | for (element in listOf(constant1, constant2).withIndex()) { 65 | val c = element.value 66 | val primal_derivative = primalAndForwardDerivative( 67 | x = x, 68 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> target(xx, c) }).second } 69 | ) 70 | val primal = primal_derivative.first.basePrimal().value 71 | val firstOrderDerivative = primal_derivative.second.basePrimal().value 72 | val primal_derivative_expectation = primalAndForwardDerivative( 73 | x = x, 74 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> nonOptimal_target(xx, c) }).second } 75 | ) 76 | val expectedFirstOrderDerivative = primal_derivative_expectation.second.basePrimal().value 77 | val expectedPrimal = primal_derivative_expectation.first.basePrimal().value 78 | val tol = 0.000001f 79 | 80 | if (Math.abs(primal - expectedPrimal) > tol) { 81 | return "PRIMAL FAIL: (index ${element.index}) expected $expectedPrimal but got $primal" 82 | } 83 | if (Math.abs(firstOrderDerivative - expectedFirstOrderDerivative) > tol) { 84 | return "DERIVATIVE FAIL: 
(index ${element.index}) expected $expectedFirstOrderDerivative but got $firstOrderDerivative" 85 | } 86 | } 87 | 88 | return "OK" 89 | } 90 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardControlFlowPrimal.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | annotation class SecondOrderOptimize 6 | 7 | @Optimize 8 | @SecondOrderOptimize 9 | fun target(a: DScalar, L: Int): DScalar { 10 | var i = 0 11 | var s = a 12 | while (i < L) { 13 | val z = s * s 14 | s = z 15 | i = i + 1 16 | } 17 | return s 18 | } 19 | 20 | fun nonOptimal_target(a: DScalar, L: Int): DScalar { 21 | var i = 0 22 | var s = a 23 | while (i < L) { 24 | val z = s * s 25 | s = z 26 | i = i + 1 27 | } 28 | return s 29 | } 30 | 31 | fun box(): String { 32 | val x = FloatScalar(1.15f) 33 | val constant1 = 4 34 | val constant2 = 7 35 | for (element in listOf(constant1, constant2).withIndex()) { 36 | val c = element.value 37 | val primal_derivative = primalAndForwardDerivative( 38 | x = x, 39 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> target(xx, c) }).first } 40 | ) 41 | val primal = primal_derivative.first.basePrimal().value 42 | val firstOrderDerivative = primal_derivative.second.basePrimal().value 43 | val primal_derivative_expectation = primalAndForwardDerivative( 44 | x = x, 45 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> nonOptimal_target(xx, c) }).first } 46 | ) 47 | val expectedFirstOrderDerivative = primal_derivative_expectation.second.basePrimal().value 48 | val expectedPrimal = primal_derivative_expectation.first.basePrimal().value 49 | val tol = 0.000001f 50 | 51 | if (Math.abs(primal - expectedPrimal) > tol) { 52 | return "PRIMAL FAIL: (index ${element.index}) expected $expectedPrimal but got $primal" 53 | } 54 | if 
(Math.abs(firstOrderDerivative - expectedFirstOrderDerivative) > tol) { 55 | return "DERIVATIVE FAIL: (index ${element.index}) expected $expectedFirstOrderDerivative but got $firstOrderDerivative" 56 | } 57 | } 58 | 59 | return "OK" 60 | } 61 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardControlFlowWhen.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | annotation class SecondOrderOptimize 6 | 7 | @Optimize 8 | @SecondOrderOptimize 9 | fun target(a: DScalar, c: Float): DScalar { 10 | var b = a + a 11 | when { 12 | c > 2.0f -> { b = b * b } 13 | else -> { b = b + a } 14 | } 15 | return b * a 16 | } 17 | 18 | fun nonOptimal_target(a: DScalar, c: Float): DScalar { 19 | var b = a + a 20 | when { 21 | c > 2.0f -> b = b * b 22 | else -> b = b + a 23 | } 24 | return b * a 25 | } 26 | 27 | fun box(): String { 28 | val x = FloatScalar(2f) 29 | val constant1 = 1.0f 30 | val constant2 = 3.0f 31 | val constant3 = -1f 32 | val tol = 0.000001f 33 | for (element in listOf(constant1, constant2, constant3).withIndex()) { 34 | val c = element.value 35 | val primal_derivative1 = primalAndForwardDerivative( 36 | x = x, 37 | f = { z: DScalar -> target(z, c) } 38 | ) 39 | val primal1 = primal_derivative1.first.basePrimal().value 40 | val firstOrderDerivative = primal_derivative1.second.basePrimal().value 41 | val primal_derivative_expectation1 = primalAndForwardDerivative( 42 | x = x, 43 | f = { z: DScalar -> nonOptimal_target(z, c) } 44 | ) 45 | val expectedDerivative1 = primal_derivative_expectation1.second.basePrimal().value 46 | val expectedPrimal1 = primal_derivative_expectation1.first.basePrimal().value 47 | 48 | if (Math.abs(primal1 - expectedPrimal1) > tol) { 49 | return "PRIMAL FAIL: (index ${element.index}) expected $expectedPrimal1 but got $primal1" 50 | } 51 
| if (Math.abs(firstOrderDerivative - expectedDerivative1) > tol) { 52 | return "FIRST ORDER DERIVATIVE FAIL: (index ${element.index}) expected $expectedDerivative1 but got $firstOrderDerivative" 53 | } 54 | val primal_derivative = primalAndForwardDerivative( 55 | x = x, 56 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> target(xx, c) }).second } 57 | ) 58 | val reverseFirstOrderDerivative = primal_derivative.first.basePrimal().value // primal of f, i.e. the reverse-mode first-order derivative 59 | val secondOrderDerivative = primal_derivative.second.basePrimal().value 60 | val primal_derivative_expectation = primalAndForwardDerivative( 61 | x = x, 62 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> nonOptimal_target(xx, c) }).second } 63 | ) 64 | val expectedSecondOrderDerivative = primal_derivative_expectation.second.basePrimal().value 65 | val expectedReverseFirstOrderDerivative = primal_derivative_expectation.first.basePrimal().value 66 | // tol is declared once before the loop; re-declaring it here only shadowed it 67 | 68 | if (Math.abs(reverseFirstOrderDerivative - expectedReverseFirstOrderDerivative) > tol) { 69 | return "REVERSE FIRST ORDER DERIVATIVE FAIL: (index ${element.index}) expected $expectedReverseFirstOrderDerivative but got $reverseFirstOrderDerivative" 70 | } 71 | if (Math.abs(secondOrderDerivative - expectedSecondOrderDerivative) > tol) { 72 | return "SECOND ORDER DERIVATIVE FAIL: (index ${element.index}) expected $expectedSecondOrderDerivative but got $secondOrderDerivative" 73 | } 74 | } 75 | 76 | return "OK" 77 | } 78 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardControlFlowWhenPrimal.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | annotation class SecondOrderOptimize 6 | 7 | @Optimize 8 | @SecondOrderOptimize 9 | fun target(a: 
DScalar, c: Float): DScalar { 10 | var b = a 11 | if (c > 2.0f) { 12 | b = b * b 13 | } else { 14 | b = b * a 15 | } 16 | return b 17 | } 18 | 19 | fun nonOptimal_target(a: DScalar, c: Float): DScalar { 20 | var b = a 21 | if
(c > 2.0f) { 22 | b = b * b 23 | } else { 24 | b = b * a 25 | } 26 | return b 27 | } 28 | 29 | fun box(): String { 30 | val x = FloatScalar(1.15f) 31 | val constant1 = 1.0f 32 | val constant2 = 3.0f 33 | val constant3 = -1f 34 | for (element in listOf(constant1, constant2, constant3).withIndex()) { 35 | val c = element.value 36 | val primal_derivative = primalAndForwardDerivative( 37 | x = x, 38 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> target(xx, c) }).first } 39 | ) 40 | val primal = primal_derivative.first.basePrimal().value 41 | val firstOrderDerivative = primal_derivative.second.basePrimal().value 42 | val primal_derivative_expectation = primalAndForwardDerivative( 43 | x = x, 44 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> nonOptimal_target(xx, c) }).first } 45 | ) 46 | val expectedFirstOrderDerivative = primal_derivative_expectation.second.basePrimal().value 47 | val expectedPrimal = primal_derivative_expectation.first.basePrimal().value 48 | val tol = 0.000001f 49 | 50 | if (Math.abs(primal - expectedPrimal) > tol) { 51 | return "PRIMAL FAIL: (index ${element.index}) expected $expectedPrimal but got $primal" 52 | } 53 | if (Math.abs(firstOrderDerivative - expectedFirstOrderDerivative) > tol) { 54 | return "DERIVATIVE FAIL: (index ${element.index}) expected $expectedFirstOrderDerivative but got $firstOrderDerivative" 55 | } 56 | } 57 | 58 | return "OK" 59 | } 60 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardDerivative.kt: -------------------------------------------------------------------------------- 1 | 2 | 3 | package demo 4 | import org.diffkt.* 5 | 6 | annotation class Optimize 7 | annotation class SecondOrderOptimize 8 | 9 | @SecondOrderOptimize 10 | fun target(x: DScalar): DScalar { 11 | val i0 = x * x 12 | val i1 = i0 * i0 13 | return i1 14 | } 15 | 16 | fun nonOptimal_target(x: DScalar): 
DScalar { 17 | val i0 = x * x 18 | val i1 = i0 * i0 19 | return i1 20 | } 21 | 22 | fun box(): String { 23 | val x = FloatScalar(2f) 24 | val primal_derivative = primalAndForwardDerivative( 25 | x = x, 26 | f = { z: DScalar -> primalAndReverseDerivative(z, ::target).second } 27 | ) 28 | val firstOrderDerivative = primal_derivative.first.basePrimal().value 29 | val secondOrderDerivative = primal_derivative.second.basePrimal().value 30 | val primal_derivative_expectation = primalAndForwardDerivative( 31 | x = x, 32 | f = { z: DScalar -> primalAndReverseDerivative(z, ::nonOptimal_target).second } 33 | ) 34 | val expectedSecondOrderDerivative = primal_derivative_expectation.second.basePrimal().value 35 | val expectedFirstOrderDerivative = primal_derivative_expectation.first.basePrimal().value 36 | val tol = 0.000001f 37 | return if (Math.abs(expectedFirstOrderDerivative - firstOrderDerivative) > tol || Math.abs(secondOrderDerivative - expectedSecondOrderDerivative) > tol) { 38 | "FAIL: should have instantiated an empty node: df: $firstOrderDerivative , expected: $expectedFirstOrderDerivative df2: $secondOrderDerivative, $expectedSecondOrderDerivative" 39 | } else "OK" 40 | } 41 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardLn.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | annotation class Optimize 4 | annotation class SecondOrderOptimize 5 | 6 | @SecondOrderOptimize 7 | @Optimize 8 | fun lnOf(value: DScalar): DScalar { 9 | return ln(value) 10 | } 11 | 12 | fun vanillaLnOf(value: DScalar): DScalar { 13 | return ln(value) 14 | } 15 | 16 | fun box(): String { 17 | val value = FloatScalar(1.715f) 18 | val (primal, derivative) = primalAndForwardDerivative( 19 | x = value, 20 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> lnOf(xx) }).second } 21 | ) 22 | val 
(expected_primal, expected_derivative) = primalAndForwardDerivative( 23 | x = value, 24 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> vanillaLnOf(xx) }).second } 25 | ) 26 | val tol = 0.000001f 27 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 28 | return "PRIMAL FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 29 | } 30 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 31 | return "DERIVATIVE FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}." 32 | } 33 | return "OK" 34 | } 35 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardLogProb.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | annotation class Optimize 4 | annotation class SecondOrderOptimize 5 | 6 | class Normal { 7 | companion object { 8 | val lnSqrt2Pi = kotlin.math.ln(kotlin.math.sqrt(2.0 * 3.14159)).toFloat() 9 | } 10 | } 11 | 12 | @SecondOrderOptimize 13 | @Optimize 14 | fun logProbOf(value: DScalar, loc: FloatScalar, scale: FloatScalar): DScalar { 15 | val twoFloat = 2f 16 | val normal = Normal 17 | val variance = scale.pow(twoFloat) 18 | val constant = normal.lnSqrt2Pi 19 | val i0 = ln(scale) 20 | val i1 = i0 - constant 21 | val i2 = value - loc 22 | val i3 = i2.pow(twoFloat) 23 | val i4 = -i3 24 | val i5 = twoFloat * variance 25 | val i6 = i4 / i5 26 | val i7 = i6 - i1 27 | return i7 28 | } 29 | 30 | fun vanillaLogProbOf(value: DScalar, loc: FloatScalar, scale: FloatScalar): DScalar { 31 | val twoFloat = 2f 32 | val normal = Normal 33 | val variance = scale.pow(twoFloat) 34 | val constant = normal.lnSqrt2Pi 35 | val i0 = ln(scale) 36 | val i1 = i0 - constant 37 | val i2 = value - loc 38 | val i3 = i2.pow(twoFloat) 39 | val 
i4 = -i3 40 | val i5 = twoFloat * variance 41 | val i6 = i4 / i5 42 | val i7 = i6 - i1 43 | return i7 44 | } 45 | 46 | fun box(): String { 47 | val value = FloatScalar(1.715f) 48 | val loc = FloatScalar(3.0f) 49 | val scale = FloatScalar(2.5f) 50 | val (primal, derivative) = primalAndForwardDerivative( 51 | x = value, 52 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> logProbOf(xx, loc, scale) }).second } 53 | ) 54 | val (expected_primal, expected_derivative) = primalAndForwardDerivative( 55 | x = value, 56 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> vanillaLogProbOf(xx, loc, scale) }).second } 57 | ) 58 | val tol = 0.000001f 59 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 60 | return "PRIMAL FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 61 | } 62 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 63 | return "DERIVATIVE FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}." 
64 | } 65 | return "OK" 66 | } 67 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardNonActiveIntermediateValues.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | annotation class SecondOrderOptimize 6 | 7 | @Optimize 8 | @SecondOrderOptimize 9 | fun target(a: DScalar, L: Int, nonActive: FloatScalar): DScalar { 10 | var i = 0 11 | var s = a 12 | while (i < L) { 13 | val k = nonActive + 2f 14 | val z = s * k 15 | s = z 16 | i = i + 1 17 | } 18 | return s 19 | } 20 | 21 | fun nonOptimal_target(a: DScalar, L: Int, nonActive: FloatScalar): DScalar { 22 | var i = 0 23 | var s = a 24 | while (i < L) { 25 | val k = nonActive + 2f 26 | val z = s * k 27 | s = z 28 | i = i + 1 29 | } 30 | return s 31 | } 32 | 33 | fun box(): String { 34 | val x = FloatScalar(1.15f) 35 | val constant1 = 4 36 | val constant2 = 7 37 | for (element in listOf(constant1, constant2).withIndex()) { 38 | val c = element.value 39 | val primal_derivative = primalAndForwardDerivative( 40 | x = x, 41 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> target(xx, c, x) }).second } 42 | ) 43 | val firstOrderDerivative = primal_derivative.first.basePrimal().value 44 | val secondOrderDerivative = primal_derivative.second.basePrimal().value 45 | val primal_derivative_expectation = primalAndForwardDerivative( 46 | x = x, 47 | f = { z: DScalar -> primalAndReverseDerivative(z, { xx: DScalar -> nonOptimal_target(xx, c, x) }).second } 48 | ) 49 | val expectedSecondOrderDerivative = primal_derivative_expectation.second.basePrimal().value 50 | val expectedFirstOrderDerivative = primal_derivative_expectation.first.basePrimal().value 51 | val tol = 0.000001f 52 | 53 | if (Math.abs(firstOrderDerivative - expectedFirstOrderDerivative) > tol) { 54 | return "FIRST ORDER DERIVATIVE FAIL: (index 
${element.index}) expected $expectedFirstOrderDerivative but got $firstOrderDerivative" 55 | } 56 | if (Math.abs(secondOrderDerivative - expectedSecondOrderDerivative) > tol) { 57 | return "SECOND ORDER DERIVATIVE FAIL: (index ${element.index}) expected $expectedSecondOrderDerivative but got $secondOrderDerivative" 58 | } 59 | } 60 | 61 | return "OK" 62 | } 63 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/reverseForwardPrimal.kt: -------------------------------------------------------------------------------- 1 | 2 | 3 | package demo 4 | import org.diffkt.* 5 | 6 | annotation class Optimize 7 | 8 | annotation class SecondOrderOptimize 9 | 10 | @SecondOrderOptimize 11 | fun target(x: DScalar): DScalar { 12 | val i0 = x * x 13 | val i1 = i0 * i0 14 | return i1 15 | } 16 | 17 | fun nonOptimal_target(x: DScalar): DScalar { 18 | val i0 = x * x 19 | val i1 = i0 * i0 20 | return i1 21 | } 22 | 23 | fun box(): String { 24 | val x = FloatScalar(2f) 25 | val primal_derivative = primalAndForwardDerivative( 26 | x = x, 27 | f = { z: DScalar -> primalAndReverseDerivative(z, ::target).first } 28 | ) 29 | val primal = primal_derivative.first.basePrimal().value 30 | val firstOrderDerivative = primal_derivative.second.basePrimal().value 31 | val primal_derivative_expectation = primalAndForwardDerivative( 32 | x = x, 33 | f = { z: DScalar -> primalAndReverseDerivative(z, ::nonOptimal_target).first } 34 | ) 35 | val expectedPrimal = primal_derivative_expectation.first.basePrimal().value 36 | val expectedFirstDerivative = primal_derivative_expectation.second.basePrimal().value 37 | val tol = 0.000001f 38 | return if (Math.abs(expectedPrimal - primal) > tol || Math.abs(firstOrderDerivative - expectedFirstDerivative) > tol) { 39 | "FAIL: should have instantiated an empty node: df: $primal , expected: $expectedPrimal df2: $firstOrderDerivative, $expectedFirstDerivative" 40 | } else "OK" 41 | } 42 
| -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/scalarNoop.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar, b: FloatScalar): DScalar { 8 | val scalarNoop = b.publicFauxExpandToTangent(b) 9 | val i0 = sigmoid(a) 10 | return i0 11 | } 12 | 13 | fun nonOptimal_target(a: DScalar, b: FloatScalar): DScalar { 14 | val scalarNoop = b.publicFauxExpandToTangent(b) 15 | val i0 = sigmoid(a) 16 | return i0 17 | } 18 | 19 | @org.diffkt.adOptimize.ScalarNoop 20 | fun DTensor.publicFauxExpandToTangent(tangent: DTensor): DTensor { 21 | return this 22 | } 23 | 24 | fun box(): String { 25 | val w = FloatScalar(1.23f) 26 | val x = FloatScalar(1.15f) 27 | val primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> target(y, w) }) 28 | val primal = primal_derivative.first 29 | val derivative = primal_derivative.second 30 | val expected_primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y, w) }) 31 | val expected_derivative = expected_primal_derivative.second 32 | val expected_primal = expected_primal_derivative.first 33 | val tol = 0.000001f 34 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 35 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 36 | } 37 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 38 | return "FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 39 | } 40 | return "OK" 41 | } 42 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/simpleWhileLoop.kt: 
-------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar, L: Int): DScalar { /* Squares s once per iteration, L times in total (no iterations when L <= 0). */ 8 | var i = 0 9 | var s = a 10 | while (i < L) { 11 | val z = s * s 12 | s = z 13 | i = i + 1 14 | } 15 | return s 16 | } 17 | 18 | fun nonOptimal_target(a: DScalar, L: Int): DScalar { 19 | var i = 0 20 | var s = a 21 | while (i < L) { 22 | val z = s * s 23 | s = z 24 | i = i + 1 25 | } 26 | return s 27 | } 28 | 29 | fun box(): String { 30 | val x = FloatScalar(1.15f) 31 | val constant1 = 4 32 | val constant2 = 7 33 | for (element in listOf(constant1, constant2).withIndex()) { /* Compares against the reference for loop counts 4 and 7; element.index identifies the failing case. */ 34 | val c = element.value 35 | val derivativePair = primalAndReverseDerivative(x, { y: DScalar -> target(y, c) }) 36 | val expected_derivativePair = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y, c) }) 37 | val tol = 0.000001f 38 | 39 | if (Math.abs(derivativePair.first.basePrimal().value - expected_derivativePair.first.basePrimal().value) > tol) { 40 | return "PRIMAL FAIL: (index ${element.index}) expected ${expected_derivativePair.first.basePrimal().value} but got ${derivativePair.first.basePrimal().value}" 41 | } 42 | if (Math.abs(derivativePair.second.basePrimal().value - expected_derivativePair.second.basePrimal().value) > tol) { 43 | return "DERIVATIVE FAIL: (index ${element.index}) expected ${expected_derivativePair.second.basePrimal().value} but got ${derivativePair.second.basePrimal().value}" 44 | } 45 | } 46 | 47 | return "OK" 48 | } 49 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/switchAssign.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar): DScalar { 8 | val b = cos(a) /* The sign of b selects the branch of both the if-expression and the when-expression below. */ 9 | val x = if (b >= 0f) { 10
| val i0 = a * a 11 | val i3 = i0 * b 12 | i3 13 | } else { 14 | a 15 | } 16 | val w = when { 17 | b >= 0f -> { 18 | val i1 = a * a 19 | i1 20 | } 21 | else -> { 22 | val i2 = a + a 23 | i2 24 | } 25 | } 26 | val z = w * x 27 | return z 28 | } 29 | 30 | fun nonOptimal_target(a: DScalar): DScalar { 31 | val b = cos(a) 32 | val x = if (b >= 0f) { 33 | val i0 = a * a 34 | val i3 = i0 * b 35 | i3 36 | } else { 37 | a 38 | } 39 | val w = when { 40 | b >= 0f -> { 41 | val i1 = a * a 42 | i1 43 | } 44 | else -> { 45 | val i2 = a + a 46 | i2 47 | } 48 | } 49 | val z = w * x 50 | return z 51 | } 52 | 53 | fun box(): String { 54 | val x1 = FloatScalar(2.15f) 55 | val x2 = FloatScalar(0.5f) /* cos(2.15) < 0 and cos(0.5) > 0, so both branches of target are exercised. */ 56 | for (input in listOf(x1, x2)) { 57 | val primal_derivative = primalAndReverseDerivative(input, { y: DScalar -> target(y) }) 58 | val primal = primal_derivative.first 59 | val derivative = primal_derivative.second 60 | val expected_primal_derivative = primalAndReverseDerivative(input, { y: DScalar -> nonOptimal_target(y) }) 61 | val expected_derivative = expected_primal_derivative.second 62 | val expected_primal = expected_primal_derivative.first 63 | val tol = 0.000001f 64 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 65 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 66 | } 67 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 68 | return "FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 69 | } 70 | } 71 | 72 | return "OK" 73 | } 74 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/unwrapONE.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar): DScalar { 8 |
val i0 = a.pow(2f) 9 | var i1: DScalar = FloatScalar.ONE /* Starts from the FloatScalar.ONE constant; reassigned in both branches below. */ 10 | if (i0 > 4f) { 11 | i1 = i1 + 3f 12 | } else { 13 | i1 = i1 + 4f 14 | } 15 | val i2 = i1 * i0 16 | return i2 17 | } 18 | 19 | fun nonOptimal_target(a: DScalar): DScalar { 20 | val i0 = a.pow(2f) 21 | var i1: DScalar = FloatScalar.ONE 22 | if (i0 > 4f) { 23 | i1 = i1 + 3f 24 | } else { 25 | i1 = i1 + 4f 26 | } 27 | val i2 = i1 * i0 28 | return i2 29 | } 30 | 31 | fun box(): String { 32 | val x = FloatScalar(2.15f) /* 2.15^2 > 4, so target takes the i0 > 4f branch for this input. */ 33 | val primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> target(y) }) 34 | val primal = primal_derivative.first 35 | val derivative = primal_derivative.second 36 | val expected_primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y) }) 37 | val expected_derivative = expected_primal_derivative.second 38 | val expected_primal = expected_primal_derivative.first 39 | val tol = 0.000001f 40 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 41 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 42 | } 43 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 44 | return "FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 45 | } 46 | return "OK" 47 | } 48 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/unwrapZERO.kt: -------------------------------------------------------------------------------- 1 | package demo 2 | import org.diffkt.* 3 | 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar): DScalar { 8 | val i0 = a.pow(2f) 9 | var i1: DScalar = FloatScalar.ZERO /* Starts from the FloatScalar.ZERO constant; reassigned in both branches below. */ 10 | if (i0 > 4f) { 11 | i1 = i1 + 3f 12 | } else { 13 | i1 = i1 + 4f 14 | } 15 | val i2 = i1 * i0 16 | return i2 17 | } 18 | 19 | fun nonOptimal_target(a: DScalar): DScalar { 20 | val i0 = a.pow(2f) 21 | var
i1: DScalar = FloatScalar.ZERO 22 | if (i0 > 4f) { 23 | i1 = i1 + 3f 24 | } else { 25 | i1 = i1 + 4f 26 | } 27 | val i2 = i1 * i0 28 | return i2 29 | } 30 | 31 | fun box(): String { 32 | val x = FloatScalar(2.15f) /* 2.15^2 > 4, so target takes the i0 > 4f branch for this input. */ 33 | val primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> target(y) }) 34 | val primal = primal_derivative.first 35 | val derivative = primal_derivative.second 36 | val expected_primal_derivative = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y) }) 37 | val expected_derivative = expected_primal_derivative.second 38 | val expected_primal = expected_primal_derivative.first 39 | val tol = 0.000001f 40 | if (Math.abs(derivative.basePrimal().value - expected_derivative.basePrimal().value) > tol) { 41 | return "FAIL: expected ${expected_derivative.basePrimal().value} but got ${derivative.basePrimal().value}" 42 | } 43 | if (Math.abs(primal.basePrimal().value - expected_primal.basePrimal().value) > tol) { 44 | return "FAIL: expected ${expected_primal.basePrimal().value} but got ${primal.basePrimal().value}" 45 | } 46 | return "OK" 47 | } 48 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/codegen/while_statement.kt: -------------------------------------------------------------------------------- 1 | 2 | package demo 3 | import org.diffkt.* 4 | annotation class Optimize 5 | 6 | @Optimize 7 | fun target(a: DScalar, L: Int): DScalar { 8 | var b = a * a 9 | val i0 = b + a 10 | var i = 0 11 | var s = i0 + i0 12 | while (i < L) { /* Per iteration: s is multiplied by b when i % 6 == 0, by i0 when i % 3 == 0 and i is odd, and by a otherwise. */ 13 | val j0 = i % 3 14 | val j1 = i % 2 15 | when { 16 | j0 == 0 -> { 17 | when { 18 | j1 == 0 -> { 19 | s = s * b 20 | } 21 | else -> { 22 | s = s * i0 23 | } 24 | } 25 | } 26 | else -> { 27 | s = s * a 28 | } 29 | } 30 | i++ 31 | } 32 | return s 33 | } 34 | 35 | fun nonOptimal_target(a: DScalar, L: Int): DScalar { 36 | var b = a * a 37 | val i0 = b + a 38 | var i = 0 39 | var s = i0 + i0 40 | while (i < L) { 41 | val j0 = i
% 3 42 | val j1 = i % 2 43 | when { 44 | j0 == 0 -> { 45 | when { 46 | j1 == 0 -> { 47 | s = s * b 48 | } 49 | else -> { 50 | s = s * i0 51 | } 52 | } 53 | } 54 | else -> { 55 | s = s * a 56 | } 57 | } 58 | i++ 59 | } 60 | return s 61 | } 62 | 63 | fun box(): String { 64 | val x = FloatScalar(1.15f) 65 | val constant1 = 4 66 | val constant2 = 7 67 | for (element in listOf(constant1, constant2).withIndex()) { /* Compares against the reference for loop counts 4 and 7. */ 68 | val c = element.value 69 | val derivativePair = primalAndReverseDerivative(x, { y: DScalar -> target(y, c) }) 70 | val expected_derivativePair = primalAndReverseDerivative(x, { y: DScalar -> nonOptimal_target(y, c) }) 71 | val tol = 0.000001f 72 | 73 | if (Math.abs(derivativePair.first.basePrimal().value - expected_derivativePair.first.basePrimal().value) > tol) { 74 | return "PRIMAL FAIL: (index ${element.index}) expected ${expected_derivativePair.first.basePrimal().value} but got ${derivativePair.first.basePrimal().value}" 75 | } 76 | if (Math.abs(derivativePair.second.basePrimal().value - expected_derivativePair.second.basePrimal().value) > tol) { 77 | return "DERIVATIVE FAIL: (index ${element.index}) expected ${expected_derivativePair.second.basePrimal().value} but got ${derivativePair.second.basePrimal().value}" 78 | } 79 | } 80 | 81 | return "OK" 82 | } 83 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/ir/control_flow.kt: -------------------------------------------------------------------------------- 1 | // WITH_RUNTIME 2 | // SKIP_KT_DUMP 3 | package demo 4 | import org.diffkt.* 5 | 6 | annotation class Optimize 7 | 8 | @Optimize 9 | fun nested_if(a: DScalar, c: Float): DScalar { /* Branches on the plain Float parameter c with a two-armed when-expression. */ 10 | val temp0 = 0.0f 11 | val i0 = a * a 12 | var b = i0 * i0 13 | when { 14 | c > temp0 -> { 15 | b = b * i0 16 | } 17 | else -> { 18 | b = i0 * a 19 | } 20 | } 21 | val y = b * i0 22 | return y 23 | } 24 | --------------------------------------------------------------------------------
/adoptimize-integration-tests/src/test/testData/ir/derivative.kt: -------------------------------------------------------------------------------- 1 | // WITH_RUNTIME 2 | // SKIP_KT_DUMP 3 | package demo 4 | import org.diffkt.* 5 | 6 | annotation class Optimize 7 | 8 | @Optimize 9 | fun target(a: DScalar): DScalar { /* Computes (a^2 + a^2)^2, i.e. 4 * a^4. */ 10 | val c = 2f 11 | val i0 = a.pow(c) 12 | val i1 = i0 + i0 13 | val i2 = i1 * i1 14 | return i2 15 | } 16 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/ir/firstAndSecondOrderDerivative.kt: -------------------------------------------------------------------------------- 1 | // SKIP_KT_DUMP 2 | package demo 3 | import org.diffkt.* 4 | 5 | annotation class Optimize 6 | annotation class SecondOrderOptimize 7 | 8 | @Optimize 9 | @SecondOrderOptimize 10 | fun target(a: DScalar): DScalar { /* Computes (a^2 + a^2)^2, i.e. 4 * a^4. */ 11 | val c = 2f 12 | val i0 = a.pow(c) 13 | val i1 = i0 + i0 14 | val i2 = i1 * i1 15 | return i2 16 | } 17 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/ir/floatFunction.kt: -------------------------------------------------------------------------------- 1 | // WITH_RUNTIME 2 | // SKIP_KT_DUMP 3 | package demo 4 | import org.diffkt.* 5 | 6 | annotation class Optimize 7 | annotation class ReverseAD 8 | 9 | @Optimize 10 | fun foo(a: Float): Float { 11 | return a * a 12 | } 13 | 14 | @ReverseAD 15 | fun jacobian_transposed_vector_product(x: Float, f: (Float) -> Float): Float { 16 | TODO() /* Placeholder body; presumably the call is handled by the compiler plugin (TODO confirm). */ 17 | } 18 | 19 | fun box() { 20 | val derivative = jacobian_transposed_vector_product(2.15f, ::foo) 21 | } 22 | -------------------------------------------------------------------------------- /adoptimize-integration-tests/src/test/testData/ir/secondOrderDerivative.kt: -------------------------------------------------------------------------------- 1 | // WITH_RUNTIME 2 | // SKIP_KT_DUMP 3 | package demo 4 | import
org.diffkt.* 5 | 6 | annotation class Optimize 7 | annotation class SecondOrderOptimize 8 | 9 | @SecondOrderOptimize 10 | fun target(a: DScalar): DScalar { /* Computes (a^2 + a^2)^2, i.e. 4 * a^4. */ 11 | val c = 2f 12 | val i0 = a.pow(c) 13 | val i1 = i0 + i0 14 | val i2 = i1 * i1 15 | return i2 16 | } 17 | -------------------------------------------------------------------------------- /adoptimize-publish/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar 9 | description = "AD optimize cli compiler plugin" 10 | 11 | plugins { 12 | id("com.github.johnrengelman.shadow") version "6.1.0" 13 | } 14 | 15 | dependencies { 16 | val ktVersion: String by System.getProperties() 17 | compileOnly("org.jetbrains.kotlin:kotlin-stdlib:$ktVersion") 18 | api("org.jetbrains.kotlin:kotlin-compiler-embeddable:$ktVersion") 19 | compileOnly(project(":adoptimize-cli-compiler-plugin")) 20 | compileOnly(project(":plugin-generators-common")) 21 | compileOnly(project(":adoptimize-common")) 22 | } 23 | 24 | val fatJarArtifact by configurations.creating 25 | val shadowFatJar: ShadowJar = tasks.getByName("shadowJar") { /* Unclassified fat jar including the three project dependencies listed below. */ 26 | val convention = project.convention.getPlugin() 27 | from(convention.sourceSets.main.get().output) 28 | archiveClassifier.set("") 29 | configurations = mutableListOf(project.configurations.compileOnly.get()) 30 | dependencies { 31 | include(project(":adoptimize-cli-compiler-plugin")) 32 | include(project(":plugin-generators-common")) 33 | include(project(":adoptimize-common")) 34 | } 35 | } 36 | 37 | val publishArtifact = artifacts.add(fatJarArtifact.name, shadowFatJar) 38 | 39 | publishing { 40 | publications { 41 | create("diffkt-adoptimize-publishing") { 42 | artifactId =
System.getProperty("ADOptimizeArtifactID") /* Set by buildSrc/build.gradle.kts. */ 43 | artifact(publishArtifact) 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /buildSrc/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | plugins { 9 | `kotlin-dsl` 10 | `java-gradle-plugin` 11 | } 12 | 13 | repositories { 14 | mavenCentral() 15 | mavenLocal() 16 | maven { 17 | url = uri("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/bootstrap") 18 | } 19 | } 20 | 21 | dependencies { 22 | val ktVersion: String by System.getProperties() 23 | implementation("org.jetbrains.kotlin:kotlin-gradle-plugin:$ktVersion") 24 | } 25 | 26 | tasks["build"].dependsOn(":prepare-deps:build") /* The ids and version below are published as JVM system properties for the other build scripts to read. */ 27 | 28 | val adOptimizeArtifactID = "meta-diffkt-adoptimize-compiler-plugin" 29 | val diffPrepCompilerPluginArtifactID = "meta-diffkt-differentiable-api-preprocessor-compiler-plugin" 30 | System.setProperty("ADOptimizeArtifactID", adOptimizeArtifactID) 31 | System.setProperty("diffPrepCompilerPluginArtifactID", diffPrepCompilerPluginArtifactID) 32 | System.setProperty("ADDiffKtVersion", "0.1.0-2d523b5") 33 | -------------------------------------------------------------------------------- /buildSrc/gradle.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | systemProp.ktVersion=1.7.0-dev-444 9 | systemProp.intellijVersion=203.8084.24 10 | systemProp.kotlinBuildDir=kotlin-build-dependencies 11 | systemProp.repoDirName=repo 12 | systemProp.kotlinBuildGroupId=kotlin.build -------------------------------------------------------------------------------- /buildSrc/settings.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | pluginManagement { 9 | repositories { 10 | gradlePluginPortal() 11 | } 12 | } 13 | 14 | include(":prepare-deps") /* The buildSrc build task depends on :prepare-deps:build (see buildSrc/build.gradle.kts). */ 15 | -------------------------------------------------------------------------------- /buildSrc/src/main/kotlin/tasks.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | import org.gradle.api.Project 9 | import org.gradle.api.Task 10 | import org.gradle.api.artifacts.dsl.RepositoryHandler 11 | import org.gradle.api.artifacts.repositories.IvyArtifactRepository 12 | import org.gradle.api.tasks.TaskProvider 13 | import org.gradle.api.tasks.testing.Test 14 | import org.gradle.kotlin.dsl.provideDelegate 15 | import java.io.File 16 | /** Registers (or reconfigures) the Test task [taskName]: points idea.home.path at the IntelliJ distribution inside the local Kotlin build-dependency repository, then applies [body]. */ 17 | fun Project.projectTest( 18 | taskName: String = "test", 19 | body: Test.() -> Unit = {} 20 | ): TaskProvider = getOrCreateTask(taskName) { 21 | val version = System.getProperty("intellijVersion") /* Resolved via kotlinBuildLocalRepoDir (Gradle user home + kotlinBuildDir + repoDirName) instead of a hard-coded user.home path, so a custom GRADLE_USER_HOME is honoured; with the default settings the path is unchanged. */ 22 | systemProperty("idea.home.path", kotlinBuildLocalRepoDir().resolve(System.getProperty("kotlinBuildGroupId") ?: "kotlin.build").resolve("ideaIC/$version/artifacts").absolutePath) 23 | systemProperty("idea.ignore.disabled.plugins", "true") 24 | body() 25 | } 26 | /** Returns the task named [taskName] configured with [body], registering it first when it does not exist yet. */ 27 | inline fun Project.getOrCreateTask(taskName: String, noinline body: T.() -> Unit): TaskProvider = 28 | if (tasks.names.contains(taskName)) tasks.named(taskName, T::class.java).apply { configure(body) } 29 | else tasks.register(taskName, T::class.java, body) 30 | 31 | val kotlinBuildDir: String by System.getProperties() 32 | val repoDirName: String by System.getProperties() 33 | 34 | private fun Project.kotlinBuildLocalRepoDir(): File = rootProject.gradle.gradleUserHomeDir.resolve(kotlinBuildDir).resolve(repoDirName) 35 | /** Ivy repository rooted at the local Kotlin build-dependency directory, the same root projectTest resolves idea.home.path against. */ 36 | fun RepositoryHandler.kotlinBuildLocalRepo(project: Project): IvyArtifactRepository = ivy { 37 | url = project.kotlinBuildLocalRepoDir().toURI() 38 | 39 | patternLayout { 40 | ivy("[organisation]/[module]/[revision]/[module].ivy.xml") 41 | ivy("[organisation]/[module]/[revision]/ivy/[module].ivy.xml") 42 | ivy("[organisation]/ideaIC/[revision]/ivy/[module].ivy.xml") // bundled plugins 43 | 44 | artifact("[organisation]/[module]/[revision]/artifacts/lib/[artifact](-[classifier]).[ext]") 45 | artifact("[organisation]/[module]/[revision]/artifacts/[artifact](-[classifier]).[ext]") 46 |
artifact("[organisation]/intellij-core/[revision]/artifacts/[artifact](-[classifier]).[ext]") 47 | artifact("[organisation]/ideaIC/[revision]/artifacts/plugins/[module]/lib/[artifact](-[classifier]).[ext]") // bundled plugins 48 | artifact("[organisation]/sources/[artifact]-[revision](-[classifier]).[ext]") 49 | artifact("[organisation]/[module]/[revision]/[artifact](-[classifier]).[ext]") 50 | } 51 | 52 | metadataSources { 53 | ivyDescriptor() 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /config/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | plugins { 9 | id("com.github.gmazzo.buildconfig") version "3.0.2" 10 | } 11 | 12 | buildConfig { 13 | packageName("config") 14 | buildConfigField("String", "PLUGIN_GROUP", "\"$group\"") 15 | 16 | buildConfigField("String", "ADOPTIMIZE_ID", "\"adoptimize\"") 17 | buildConfigField("String", "DIFFPREP_ID", "\"differentiable-api-preprocessor\"") 18 | 19 | buildConfigField("String", "ADOPTIMIZE_ARTIFACT_ID", "\"${System.getProperty("ADOptimizeArtifactID")}\"") 20 | buildConfigField("String", "DIFFPREP_ARTIFACT_ID", "\"${System.getProperty("diffPrepCompilerPluginArtifactID")}\"") 21 | 22 | buildConfigField("String", "PLUGIN_VERSION", "\"$version\"") 23 | 24 | buildConfigField("String", "AD_DIFFKT_VERSION", "\"${System.getProperty("ADDiffKtVersion")}\"") 25 | buildConfigField("String", "KOTLIN_VERSION", "\"${System.getProperty("ktVersion")}\"") 26 | } 27 | 28 | publishing { 29 | publications { 30 | create("meta-adoptimize-config") { 31 | from(components["java"]) 32 | artifactId = "meta-adoptimize-config" 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-compiler-plugin/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
/differentiable-api-preprocessor-compiler-plugin/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar 9 | description = "plugin to prep a differentiable api for optimization" 10 | 11 | plugins { 12 | id("com.github.johnrengelman.shadow") version "6.1.0" 13 | } 14 | 15 | dependencies { 16 | val ktVersion: String by System.getProperties() 17 | compileOnly("org.jetbrains.kotlin:kotlin-stdlib:$ktVersion") 18 | api("org.jetbrains.kotlin:kotlin-compiler-embeddable:$ktVersion") 19 | compileOnly(project(":adoptimize-common")) 20 | compileOnly(project(":plugin-generators-common")) 21 | } 22 | 23 | sourceSets { 24 | main {} 25 | test { java.srcDirs("test", "tests") } 26 | } 27 | 28 | val shadowArtifact by configurations.creating 29 | val shadowJar: ShadowJar = tasks.getByName("shadowJar") { 30 | val convention = project.convention.getPlugin() 31 | archiveClassifier.set("sources") 32 | from(convention.sourceSets.main.get().output) 33 | configurations = mutableListOf(project.configurations.compileOnly.get()) 34 | relocate("org.jetbrains.org.objectweb.asm.tree.analysis", "org.objectweb.asm.tree.analysis") 35 | relocate("org.jetbrains.kotlin.com.intellij", "com.intellij") 36 | dependencies { 37 | exclude(dependency("org.jetbrains.kotlin:kotlin-stdlib")) 38 | // and its transitive dependencies: 39 | exclude(dependency("org.jetbrains.kotlin:kotlin-stdlib-common")) 40 | exclude(dependency("org.jetbrains:annotations")) 41 | 42 | exclude(dependency("com.intellij:openapi")) 43 | // and its transitive dependencies: 44 | exclude(dependency("com.intellij:extensions")) 45 | exclude(dependency("com.intellij:annotations")) 46 | } 47 | } 48 | 49 | artifacts { 50 | 
add(shadowArtifact.name, shadowJar) 51 | } 52 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-compiler-plugin/src/main/kotlin/diffPrep/DiffPrepClassLifterDelegate.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package diffPrep 9 | 10 | import adOptimizeCommon.reverseNodeNameFromOperationsName 11 | import diffPrep.metadata.DifferentiableApi 12 | import org.jetbrains.kotlin.ir.declarations.IrClass 13 | import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction 14 | import org.jetbrains.kotlin.ir.types.isSubtypeOfClass 15 | import org.jetbrains.kotlin.ir.util.defaultType 16 | import org.jetbrains.kotlin.ir.util.kotlinFqName 17 | import pluginCommon.generators.ClassFunctionAttributes 18 | import pluginCommon.generators.overrideRoot 19 | import pluginCommon.lowerings.ClassLifterDelegate 20 | 21 | class DiffPrepClassLifterDelegate( 22 | val differentiableApi: DifferentiableApi 23 | ) : ClassLifterDelegate() { 24 | override fun shouldLiftClass(clazz: IrClass): Boolean { 25 | return clazz.parent is IrSimpleFunction && clazz.defaultType.isSubtypeOfClass(differentiableApi.reverseDifferentiableScalar.clazz.symbol) 26 | } 27 | 28 | override fun liftedClassName(originalClass: IrClass): String = reverseNodeNameFromOperationsName(originalClass.parent.kotlinFqName.toString()) 29 | 30 | override fun customizeCopyOfMethod(oldMethod: IrSimpleFunction): ClassFunctionAttributes { 31 | val attributes = ClassFunctionAttributes(oldMethod) 32 | if (oldMethod.overrideRoot() == differentiableApi.reverseDifferentiableScalar.backpropMethod.overrideRoot()) { 33 | attributes.isInline = true 34 | } 35 | return attributes 36 | } 37 | } 38 | 
-------------------------------------------------------------------------------- /differentiable-api-preprocessor-compiler-plugin/src/main/kotlin/diffPrep/DiffPrepErrorMessagesExtension.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package diffPrep 9 | 10 | import org.jetbrains.kotlin.diagnostics.rendering.DefaultErrorMessages 11 | import org.jetbrains.kotlin.diagnostics.rendering.DiagnosticFactoryToRendererMap 12 | import org.jetbrains.kotlin.diagnostics.rendering.Renderers 13 | /** Supplies the message templates for the diagnostics declared in Errors; each {0} argument is rendered with Renderers.TO_STRING. */ 14 | class DiffPrepErrorMessagesExtension : DefaultErrorMessages.Extension { 15 | private val _map: DiagnosticFactoryToRendererMap by lazy { /* Built on first access and reused by getMap. */ 16 | val renderMap = DiagnosticFactoryToRendererMap() 17 | renderMap.put(Errors.NO_UNBOXEDFUNCTION_FOUND, "The annotation references a function whose signature is not compatible: {0}", Renderers.TO_STRING) 18 | renderMap.put(Errors.ANNOTATION_REFERENCES_UNRESOLVED_DECLARATIONS, "The annotation references unresolved declarations: {0}", Renderers.TO_STRING) 19 | renderMap.put(Errors.INVALID_SIGNATURE_ANNOTATION, "The annotation does not have the expected members: {0}", Renderers.TO_STRING) 20 | renderMap 21 | } 22 | 23 | override fun getMap(): DiagnosticFactoryToRendererMap { 24 | return _map 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-compiler-plugin/src/main/kotlin/diffPrep/DiffPrepErrors.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | package diffPrep 9 | 10 | import org.jetbrains.kotlin.diagnostics.DiagnosticFactory1 11 | import org.jetbrains.kotlin.diagnostics.Severity 12 | import org.jetbrains.kotlin.psi.KtAnnotationEntry 13 | import org.jetbrains.kotlin.psi.KtClass 14 | 15 | class Errors { /* Error-severity diagnostic factories of the preprocessor; their message templates are registered in DiffPrepErrorMessagesExtension. */ 16 | companion object { 17 | val NO_UNBOXEDFUNCTION_FOUND = DiagnosticFactory1.create(Severity.ERROR).also { it.initializeName("NO_UNBOXEDFUNCTION_FOUND") } 18 | val INVALID_SIGNATURE_ANNOTATION: DiagnosticFactory1 = DiagnosticFactory1.create(Severity.ERROR).also { it.initializeName("INVALID_SIGNATURE_ANNOTATION") } 19 | val ANNOTATION_REFERENCES_UNRESOLVED_DECLARATIONS: DiagnosticFactory1 = DiagnosticFactory1.create(Severity.ERROR).also { it.initializeName("ANNOTATION_REFERENCES_UNRESOLVED_DECLARATIONS") } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-compiler-plugin/src/main/kotlin/diffPrep/DifferentiableApiPreprocessorCompilerConfigurationExtension.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package diffPrep 9 | 10 | import org.jetbrains.kotlin.config.CompilerConfiguration 11 | import org.jetbrains.kotlin.config.JVMConfigurationKeys 12 | import org.jetbrains.kotlin.config.JvmSerializeIrMode 13 | import org.jetbrains.kotlin.extensions.CompilerConfigurationExtension 14 | 15 | class DifferentiableApiPreprocessorCompilerConfigurationExtension : CompilerConfigurationExtension { 16 | override fun updateConfiguration(configuration: CompilerConfiguration) { /* Enables the JVM IR backend and sets IR serialization to JvmSerializeIrMode.INLINE (presumably so IR of inline functions is serialized; confirm). */ 17 | configuration.put(JVMConfigurationKeys.IR, true) 18 | configuration.put(JVMConfigurationKeys.SERIALIZE_IR, JvmSerializeIrMode.INLINE) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-compiler-plugin/src/main/kotlin/diffPrep/ErrorHandling.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | package diffPrep 9 | /* Exception carrying only a message; presumably raised when differentiable-API preprocessing fails (confirm with callers). */ 10 | class DiffApiPrepException(message: String) : Exception(message) 11 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-compiler-plugin/src/main/kotlin/diffPrep/metadata/BoxedPrimitiveInfo.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | package diffPrep.metadata 9 | 10 | import org.jetbrains.kotlin.ir.declarations.IrClass 11 | import org.jetbrains.kotlin.ir.declarations.IrProperty 12 | import org.jetbrains.kotlin.ir.types.IrType 13 | import org.jetbrains.kotlin.ir.util.IrMessageLogger 14 | import org.jetbrains.kotlin.ir.util.companionObject 15 | import org.jetbrains.kotlin.ir.util.properties 16 | 17 | /** 18 | * Describes the boxed primitive of a differentiable API: the box class, its value property, the unboxed 19 | * primitive type, and the zero/one constants expected on the box class companion object. 20 | * The init block only reports warnings through messageLogger when these declarations do not belong together. 21 | */ 22 | class BoxedPrimitiveInfo( 23 | val boxedPrimitiveClass: IrClass, 24 | val valueProperty: IrProperty, 25 | val primitiveType: IrType, 26 | val companionZero: IrProperty, 27 | val companionOne: IrProperty, 28 | messageLogger: IrMessageLogger 29 | ) { 30 | init { 31 | fun warn(text: String) = messageLogger.report(IrMessageLogger.Severity.WARNING, text, null) 32 | 33 | if (valueProperty !in boxedPrimitiveClass.properties) { 34 | warn("The value property must be a property of the boxedPrimitive.") 35 | } 36 | /* A missing companion object fails both companion checks. */ 37 | val companion = boxedPrimitiveClass.companionObject() 38 | if (companion == null || companionZero !in companion.properties) { 39 | warn("The zero property must be a property of the boxedPrimitive's companion.") 40 | } 41 | if (companion == null || companionOne !in companion.properties) { 42 | warn("The one property must be a property of the boxedPrimitive's companion.") 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-compiler-plugin/src/main/kotlin/diffPrep/metadata/DifferentiableApi.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates.
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package diffPrep.metadata 9 | 10 | import org.jetbrains.kotlin.ir.declarations.IrClass 11 | import org.jetbrains.kotlin.ir.declarations.IrProperty 12 | import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction 13 | /* IR handles of the reverse-mode scalar class, its upstream property, and its backprop and pushback methods. */ 14 | class ReverseDifferentiableScalarMetadata( 15 | val clazz: IrClass, 16 | val upstreamProperty: IrProperty, 17 | val backpropMethod: IrSimpleFunction, 18 | val pushbackMethod: IrSimpleFunction, 19 | ) 20 | /* IR handles of the forward-mode scalar class and its tangent property. */ 21 | class ForwardDifferentiableScalarMetadata( 22 | val clazz: IrClass, 23 | val tangentProperty: IrProperty 24 | ) 25 | /* Aggregates the IR classes, properties and functions that describe a differentiable API. */ 26 | class DifferentiableApi( 27 | val reverseDifferentiableScalar: ReverseDifferentiableScalarMetadata, 28 | val forwardDifferentiableScalar: ForwardDifferentiableScalarMetadata, 29 | val derivativeId: IrProperty, 30 | val primalProperty: IrProperty, 31 | val scalarPlusFunction: IrSimpleFunction, 32 | val tensorPlusFunction: IrSimpleFunction, 33 | val scalarRoot: IrClass, 34 | val primalAndPullbackFunction: IrSimpleFunction, 35 | val boxedPrimitiveInfo: BoxedPrimitiveInfo, 36 | val dTensorRoot: IrClass, 37 | val reverseOperations: IrClass, 38 | val stackClass: StackClass, 39 | val scalarNoopClass: IrClass 40 | ) 41 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-compiler-plugin/src/main/kotlin/diffPrep/metadata/StackClass.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep.metadata
9 |
10 | import org.jetbrains.kotlin.ir.declarations.IrClass
11 | import org.jetbrains.kotlin.ir.declarations.IrFunction
12 |
13 | /** The stack implementation class ([clazz]) together with its pop, push, notEmpty and top methods. */
14 | class StackClass(
15 |     val clazz: IrClass,
16 |     val popMethod: IrFunction,
17 |     val pushMethod: IrFunction,
18 |     val notEmptyMethod: IrFunction,
19 |     val topMethod: IrFunction,
20 | )
21 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-compiler-plugin/src/main/resources/META-INF/services/org.jetbrains.kotlin.compiler.plugin.CommandLineProcessor:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | diffPrep.DifferentiableApiPreprocessorCommandLineProcessor
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-compiler-plugin/src/main/resources/META-INF/services/org.jetbrains.kotlin.compiler.plugin.ComponentRegistrar:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | diffPrep.DifferentiableApiPreprocessorComponentRegistrar
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-gradle-plugin/build.gradle.kts:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | description = "differentiable api gradle plugin"
9 | // `java-gradle-plugin` supplies the `gradlePlugin {}` block used at the end of this script.
10 | plugins {
11 |     id("java-gradle-plugin")
12 | }
13 | // compileOnly entries are compiled against but not declared as runtime dependencies of this plugin.
14 | dependencies {
15 |     val ktVersion: String by System.getProperties()
16 |     compileOnly("org.jetbrains.kotlin:kotlin-stdlib:$ktVersion")
17 |     compileOnly(kotlin("gradle-plugin-api"))
18 |     implementation(project(":config"))
19 | }
20 | // NOTE(review): `:config` presumably provides the `config.BuildConfig` constants the sub-plugin reads.
21 | group = "org.meta.diffkt.adoptimize"
22 | // generate plugin descriptors in the resulting JAR's META-INF directory
23 | gradlePlugin {
24 |     plugins {
25 |         create("meta-diffkt-differentiable-api-preprocessor-gradle-plugin") {
26 |             id = "meta-diffkt-differentiable-api-preprocessor"
27 |             implementationClass = "diffPrep.gradle.DifferentiableApiPreprocessorGradleSubPlugin"
28 |         }
29 |     }
30 | }
31 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-gradle-plugin/src/main/kotlin/diffPrep/gradle/DifferentiableApiPreprocessorExtension.kt:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep.gradle
9 |
10 | /**
11 |  * Gradle extension, registered as `differentiableApiPreprocessor`, holding the settings that are
12 |  * forwarded to the differentiable-api-preprocessor compiler plugin. Every setting defaults to the
13 |  * empty string and can be assigned directly or through the function-style setter for the same
14 |  * setting below.
15 |  */
16 | open class DifferentiableApiPreprocessorExtension {
17 |     var reverse: String = ""
18 |     var stackImpl: String = ""
19 |     var primalAndPullback: String = ""
20 |     var boxedPrimitive: String = ""
21 |     var unboxedFunction: String = ""
22 |     var scalarRoot: String = ""
23 |     var resourcesPath: String = ""
24 |     var toReverse: String = ""
25 |     var dTensor: String = ""
26 |     var reverseScalarOperations: String = ""
27 |     var scalarNoop: String = ""
28 |     var forwardDifferentiable: String = ""
29 |
30 |     // Function-style setters for the build-script DSL; each one only assigns its property.
31 |     open fun reverseAnnotation(customReverseAnnotation: String) { reverse = customReverseAnnotation }
32 |     open fun stackImplAnnotation(customStackImpl: String) { stackImpl = customStackImpl }
33 |     open fun primalAndPullbackAnnotation(customPrimalAndPullback: String) { primalAndPullback = customPrimalAndPullback }
34 |     open fun boxedPrimitive(customBoxedPrimitive: String) { boxedPrimitive = customBoxedPrimitive }
35 |     open fun unboxedFunction(customUnboxedFunction: String) { unboxedFunction = customUnboxedFunction }
36 |     open fun scalarRoot(customScalarRoot: String) { scalarRoot = customScalarRoot }
37 |     open fun resourcesPath(resourcesPath: String) { this.resourcesPath = resourcesPath }
38 |     open fun toReverseAnnotation(toReverseAnnotation: String) { toReverse = toReverseAnnotation }
39 |     open fun dTensorAnnotation(dTensorAnnotation: String) { dTensor = dTensorAnnotation }
40 |     open fun reverseScalarOperationsAnnotation(operationsAnnotationFqn: String) { reverseScalarOperations = operationsAnnotationFqn }
41 |     open fun scalarNoop(scalarNoopFqn: String) { scalarNoop = scalarNoopFqn }
42 |     open fun forwardDifferentiable(forwardDifferentiableFqn: String) { forwardDifferentiable = forwardDifferentiableFqn }
43 | }
44 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-gradle-plugin/src/main/kotlin/diffPrep/gradle/DifferentiableApiPreprocessorGradleExtension.kt:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep.gradle
9 |
10 | import config.BuildConfig
11 | import org.gradle.api.Project
12 | import org.gradle.api.provider.Provider
13 | import org.gradle.tooling.provider.model.ToolingModelBuilderRegistry
14 | import org.jetbrains.kotlin.gradle.plugin.KotlinCompilation
15 | import org.jetbrains.kotlin.gradle.plugin.KotlinCompilerPluginSupportPlugin
16 | import org.jetbrains.kotlin.gradle.plugin.SubpluginArtifact
17 | import org.jetbrains.kotlin.gradle.plugin.SubpluginOption
18 | import javax.inject.Inject
19 |
20 | /**
21 |  * Kotlin Gradle sub-plugin that registers the `differentiableApiPreprocessor` extension and passes
22 |  * its settings to the differentiable-api-preprocessor compiler plugin as subplugin options.
23 |  */
24 | class DifferentiableApiPreprocessorGradleSubPlugin @Inject internal constructor(
25 |     // Supplied by Gradle's injection; not referenced by this class.
26 |     private val registry: ToolingModelBuilderRegistry
27 | ) : KotlinCompilerPluginSupportPlugin {
28 |
29 |     override fun apply(target: Project) {
30 |         target.extensions.create("differentiableApiPreprocessor", DifferentiableApiPreprocessorExtension::class.java)
31 |     }
32 |
33 |     override fun isApplicable(kotlinCompilation: KotlinCompilation<*>): Boolean = true
34 |
35 |     /**
36 |      * Maps every extension setting to a [SubpluginOption]. The list is built inside
37 |      * `project.provider`, so the extension is read when the provider is queried rather than when
38 |      * this method runs.
39 |      * NOTE(review): the option keys ("toReverseNode", "DTensorRoot", "forward", ...) presumably
40 |      * mirror the compiler plugin's CommandLineProcessor; keep both sides in sync.
41 |      */
42 |     override fun applyToCompilation(kotlinCompilation: KotlinCompilation<*>): Provider<List<SubpluginOption>> {
43 |         val project = kotlinCompilation.target.project
44 |         val extension = project.extensions.getByType(DifferentiableApiPreprocessorExtension::class.java)
45 |
46 |         return project.provider {
47 |             listOf(
48 |                 SubpluginOption("reverse", extension.reverse),
49 |                 SubpluginOption("stackImpl", extension.stackImpl),
50 |                 SubpluginOption("primalAndPullback", extension.primalAndPullback),
51 |                 SubpluginOption("boxedPrimitive", extension.boxedPrimitive),
52 |                 SubpluginOption("unboxedFunction", extension.unboxedFunction),
53 |                 SubpluginOption("scalarRoot", extension.scalarRoot),
54 |                 SubpluginOption("resourcesPath", extension.resourcesPath),
55 |                 SubpluginOption("toReverseNode", extension.toReverse),
56 |                 SubpluginOption("DTensorRoot", extension.dTensor),
57 |                 SubpluginOption("reverseOperations", extension.reverseScalarOperations),
58 |                 SubpluginOption("scalarNoop", extension.scalarNoop),
59 |                 SubpluginOption("forward", extension.forwardDifferentiable),
60 |             )
61 |         }
62 |     }
63 |
64 |     // we assume the gradle plugin id and the compiler plugin id are the same.
65 |     override fun getCompilerPluginId() = BuildConfig.DIFFPREP_ID
66 |
67 |     override fun getPluginArtifact(): SubpluginArtifact = SubpluginArtifact(
68 |         groupId = BuildConfig.PLUGIN_GROUP,
69 |         artifactId = BuildConfig.DIFFPREP_ARTIFACT_ID,
70 |         version = BuildConfig.PLUGIN_VERSION,
71 |     )
72 | }
73 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-integration-tests/build.gradle.kts:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | repositories {
9 |     kotlinBuildLocalRepo(project)
10 | }
11 |
12 | // Dependency name lists published by the root project through its extra properties.
13 | val platformDependencies: Array<String> by rootProject.extra
14 | val testDependencies: Array<String> by rootProject.extra
15 | val testRuntimeDependencies: Array<String> by rootProject.extra
16 | val kotlinStd by configurations.creating
17 | val coreDependencies: Array<String> by rootProject.extra
18 |
19 | dependencies {
20 |     val ktVersion: String by System.getProperties()
21 |     val intellijVersion: String by System.getProperties()
22 |
23 |     kotlinStd("org.jetbrains.kotlin:kotlin-stdlib:$ktVersion")
24 |     testImplementation(project(":differentiable-api-preprocessor-compiler-plugin", "shadowArtifact"))
25 |     testImplementation("org.jetbrains.kotlin:kotlin-compiler-internal-test-framework:$ktVersion")
26 |     testImplementation("org.jetbrains.kotlin:kotlin-scripting-compiler:$ktVersion")
27 |     testImplementation("one.util:streamex:0.7.3")
28 |     testImplementation(project(":adoptimize-common"))
29 |
30 |     testDependencies.forEach {
31 |         testImplementation(it)
32 |     }
33 |
34 |     testRuntimeDependencies.forEach {
35 |         testRuntimeOnly(it)
36 |     }
37 |
38 |     platformDependencies.forEach {
39 |         testImplementation("com.jetbrains.intellij.platform:$it:$intellijVersion")
40 |     }
41 |
42 |     // Each entry selects one jar artifact out of the intellij-core module.
43 |     coreDependencies.forEach { artifactName ->
44 |         testImplementation("kotlin.build:intellij-core:$intellijVersion") {
45 |             artifact {
46 |                 name = artifactName
47 |                 type = "jar"
48 |                 extension = "jar"
49 |             }
50 |         }
51 |     }
52 | }
53 |
54 | // Exposes the compiled test classes as a jar through the `testArtifact` configuration.
55 | val testArtifact by configurations.creating
56 | val testJar = tasks.register<Jar>("testJar") {
57 |     val convention = project.convention.getPlugin<JavaPluginConvention>()
58 |     // NOTE(review): the jar holds compiled test output, not sources; the "sources" classifier
59 |     // looks misleading. Confirm no consumer depends on it before changing.
60 |     archiveClassifier.set("sources")
61 |     from(convention.sourceSets.test.get().output)
62 | }
63 |
64 | artifacts {
65 |     add(testArtifact.name, testJar)
66 | }
67 |
68 | projectTest {
69 |     workingDir = rootDir
70 |     useJUnitPlatform()
71 | }
72 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-integration-tests/src/test/kotlin/diffPrep/AbstractDifferentiablePreprocessorBlackBoxTest.kt:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep
9 |
10 | import org.jetbrains.kotlin.test.builders.TestConfigurationBuilder
11 | import org.jetbrains.kotlin.test.runners.codegen.AbstractIrBlackBoxCodegenTest
12 |
13 | /**
14 |  * Base class for IR black-box codegen tests run with the differentiable-api-preprocessor plugin
15 |  * configured; the configurator receives `resourcesDirectoryFromTestFile` to locate each test's resources.
16 |  */
17 | abstract class AbstractDifferentiablePreprocessorBlackBoxTest :
18 |     AbstractIrBlackBoxCodegenTest(),
19 |     DifferentiablePreprocessorBaseTest {
20 |
21 |     override fun configure(builder: TestConfigurationBuilder) {
22 |         super.configure(builder)
23 |         builder.useConfigurators({ testServices ->
24 |             DifferentiablePreprocessorConfigurator(testServices, ::resourcesDirectoryFromTestFile)
25 |         })
26 |     }
27 | }
28 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-integration-tests/src/test/kotlin/diffPrep/AbstractDifferentiablePreprocessorIrTest.kt:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep
9 |
10 | import org.jetbrains.kotlin.test.builders.TestConfigurationBuilder
11 | import org.jetbrains.kotlin.test.runners.ir.AbstractIrTextTest
12 |
13 | /**
14 |  * Base class for IR text tests run with the differentiable-api-preprocessor plugin configured;
15 |  * the configurator receives `resourcesDirectoryFromTestFile` to locate each test's resources.
16 |  */
17 | abstract class AbstractDifferentiablePreprocessorIrTest :
18 |     AbstractIrTextTest(),
19 |     DifferentiablePreprocessorBaseTest {
20 |
21 |     override fun configure(builder: TestConfigurationBuilder) {
22 |         super.configure(builder)
23 |         builder.useConfigurators({ testServices ->
24 |             DifferentiablePreprocessorConfigurator(testServices, ::resourcesDirectoryFromTestFile)
25 |         })
26 |     }
27 | }
28 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-integration-tests/src/test/kotlin/diffPrep/AbstractDifferentiablePrepropessorDiagnosticTests.kt:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep
9 |
10 | import org.jetbrains.kotlin.test.builders.TestConfigurationBuilder
11 | import org.jetbrains.kotlin.test.runners.AbstractDiagnosticTest
12 | import java.io.File
13 |
14 | /**
15 |  * Base class for diagnostic tests run with the differentiable-api-preprocessor plugin configured.
16 |  * It does not implement DifferentiablePreprocessorBaseTest, so the `resources` sibling of each
17 |  * test file is computed inline.
18 |  */
19 | abstract class AbstractDifferentiablePrepropessorDiagnosticTests : AbstractDiagnosticTest() {
20 |     override fun configure(builder: TestConfigurationBuilder) {
21 |         super.configure(builder)
22 |         builder.useConfigurators({ testServices ->
23 |             DifferentiablePreprocessorConfigurator(testServices, { testFile: File -> testFile.resolveSibling("resources") })
24 |         })
25 |     }
26 | }
27 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-integration-tests/src/test/kotlin/diffPrep/DifferentiablePreprocessorBaseTest.kt:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep
9 |
10 | import java.io.File
11 |
12 | /** Helpers shared by the differentiable-api-preprocessor integration tests. */
13 | interface DifferentiablePreprocessorBaseTest {
14 |     /** Test-data directory, as a path relative to the `user.dir` system property. */
15 |     val homeDir: String
16 |
17 |     /** [homeDir] resolved against the `user.dir` system property. */
18 |     fun srcProjectRoot(): File {
19 |         val workingDirectory = System.getProperty("user.dir")
20 |         return File(workingDirectory, homeDir)
21 |     }
22 |
23 |     /** The directory named `resources` that sits next to [testFile]. */
24 |     fun resourcesDirectoryFromTestFile(testFile: File): File {
25 |         return testFile.resolveSibling("resources")
26 |     }
27 | }
28 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-integration-tests/src/test/kotlin/diffPrep/DifferentiablePreprocessorDiagnosticTest.kt:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep
9 |
10 | import org.junit.jupiter.api.Test
11 |
12 | /** Diagnostic tests for the files under [homeDir]. */
13 | class DifferentiablePreprocessorDiagnosticTest : AbstractDifferentiablePrepropessorDiagnosticTests() {
14 |     val homeDir = "differentiable-api-preprocessor-integration-tests/src/test/testData/diagnostics"
15 |
16 |     /** Runs the diagnostic test for [fileName] inside [homeDir]. */
17 |     private fun runDiagnosticTestFor(fileName: String) {
18 |         runTest("$homeDir/$fileName")
19 |     }
20 |
21 |     @Test
22 |     fun testValidApi() {
23 |         runDiagnosticTestFor("validApi.kt")
24 |     }
25 |
26 |     @Test
27 |     fun testToUnBoxFunctionInvalidSignature() {
28 |         runDiagnosticTestFor("toUnBoxFunctionInvalidSignature.kt")
29 |     }
30 |
31 |     @Test
32 |     fun testToUnboxClassMethods() {
33 |         runDiagnosticTestFor("toUnboxClassMethods.kt")
34 |     }
35 | }
36 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-integration-tests/src/test/kotlin/diffPrep/DifferentiablePreprocessorIrTest.kt:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep
9 |
10 | import org.junit.jupiter.api.Test
11 |
12 | /** IR text tests for the files under [homeDir]. */
13 | class DifferentiablePreprocessorIrTest : AbstractDifferentiablePreprocessorIrTest() {
14 |     override val homeDir: String = "differentiable-api-preprocessor-integration-tests/src/test/testData/ir"
15 |
16 |     /** Runs the IR text test for [fileName] inside [homeDir]. */
17 |     private fun runIrTestFor(fileName: String) {
18 |         runTest("$homeDir/$fileName")
19 |     }
20 |
21 |     @Test
22 |     fun testApi() {
23 |         runIrTestFor("api.kt")
24 |     }
25 |
26 |     @Test
27 |     fun testTypeOperator() {
28 |         runIrTestFor("typeOperator.kt")
29 |     }
30 | }
31 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-integration-tests/src/test/kotlin/diffPrep/DifferentiablePreprocessorWithTempDirectoryTest.kt:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) Meta Platforms, Inc. and affiliates.
3 |  *
4 |  * This source code is licensed under the MIT license found in the
5 |  * LICENSE file in the root directory of this source tree.
6 |  */
7 |
8 | package diffPrep
9 |
10 | import org.junit.jupiter.api.Assertions.assertEquals
11 | import org.junit.jupiter.api.Assertions.assertTrue
12 | import org.junit.jupiter.api.Test
13 | import java.io.File
14 |
15 | /**
16 |  * Black-box tests checking that compiling a withTmpDir test file writes the expected adoptimize
17 |  * properties file into the `resources` directory next to that file.
18 |  */
19 | class DifferentiablePreprocessorWithTempDirectoryTest : AbstractDifferentiablePreprocessorBlackBoxTest() {
20 |     override val homeDir: String = "differentiable-api-preprocessor-integration-tests/src/test/testData/withTmpDir"
21 |
22 |     @Test
23 |     fun singleModule() {
24 |         runAdOptimizePropertiesTest("singleModule/validApi.kt")
25 |     }
26 |
27 |     @Test
28 |     fun multiModule() {
29 |         runAdOptimizePropertiesTest("multiModule/validApi.kt")
30 |     }
31 |
32 |     /**
33 |      * Compiles [testPath] (relative to [homeDir]) and compares the generated properties file with
34 |      * the expected contents. The generated file is deleted in `finally`, even when the test fails,
35 |      * so a leftover from one run cannot satisfy the existence check of the next.
36 |      */
37 |     private fun runAdOptimizePropertiesTest(testPath: String) {
38 |         val relativePath = "$homeDir/$testPath"
39 |         val testFile = File(srcProjectRoot(), testPath)
40 |         val adOptimizePropertiesFile = File(resourcesDirectoryFromTestFile(testFile), adOptimizeCommon.propertiesFileName)
41 |         try {
42 |             runTest(relativePath)
43 |             // JUnit assertions run whether or not the JVM was started with -ea (Kotlin's assert does not).
44 |             assertTrue(adOptimizePropertiesFile.exists()) { "The properties file was not written to the source directory" }
45 |             val contents = adOptimizePropertiesFile.readText()
46 |             val expectation = """
47 |                 reverseClass=demo.ReverseNode
48 |                 forwardScalarClass=demo.ForwardNode
49 |                 primalProperty=primal
50 |                 upstreamProperty=upstream
51 |                 tangentProperty=tangent
52 |                 backpropMethod=backpropogate
53 |                 pushbackMethod=pushback
54 |                 derivativeId=derivativeID
55 |                 scalarRoot=demo.DifferentiableDouble
56 |                 primalAndPullbackFunction=demo.primalAndPullback
57 |                 boxedPrimitive=demo.DDouble
58 |                 valueProperty=value
59 |                 primitiveType=Double
60 |                 stackImpl=demo.Stack
61 |                 scalarPlusFunction=demo.plus
62 |                 tensorPlusFunction=demo.plus
63 |                 toUnboxFunction=demo.ToUnboxedFunction
64 |                 toReverse=demo.ToReverse
65 |                 dTensor=demo.DiffTensor
66 |                 reverseOperations=demo.ReverseScalarOperations
67 |                 scalarZero=demo.DDouble.Companion.ZERO
68 |                 scalarOne=demo.DDouble.Companion.ONE
69 |                 scalarNoop=demo.ScalarNoop
70 |
71 |             """.trimIndent()
72 |             assertEquals(expectation, contents) {
73 |                 "The properties file did not contain the expected contents. Expected \n`$expectation` but got \n`$contents`"
74 |             }
75 |         } finally {
76 |             adOptimizePropertiesFile.delete()
77 |         }
78 |     }
79 | }
80 |
--------------------------------------------------------------------------------
/differentiable-api-preprocessor-integration-tests/src/test/testData/diagnostics/toUnboxClassMethods.kt:
--------------------------------------------------------------------------------
1 | // WITH_RUNTIME 2 | // !DIAGNOSTICS: -UNUSED_PARAMETER -DEBUG_INFO_MISSING_UNRESOLVED -DEBUG_INFO_LEAKING_THIS 3 | // SKIP_TXT 4 | // FIR_IDENTICAL 5 | package demo 6 | 7 | annotation class ReverseDifferentiable( 8 | val primalField:String, 9 | val upstreamField:String, 10 | val backpropogateMethod:String, 11 | val pushbackMethod:String, 12 | val derivativeID:String) 13 | 14 | annotation class ScalarRoot 15 | annotation class PrimalAndPullback 16 | annotation class StackImpl 17 | annotation class BoxedPrimitive(val valueField:String) 18 | annotation class ToUnboxedFunction(val functionName:String) 19 | annotation class DTensorRoot 20 | annotation class ScalarNoop 21 | annotation class ForwardDifferentiable(val tangentProperty:String) 22 | 23 | @DTensorRoot 24 | class DiffTensor 25 | 26 | open class DerivativeID(private val seq:Int) : Comparable { 27 | override fun compareTo(other: DerivativeID): Int { 28 | return this.seq.compareTo(other.seq) 29 | } 30 | } 31 | 32 | class ReverseDerivativeID(s:Int) : DerivativeID(s) { 33 | val backpropogateWorkList: java.util.Stack = java.util.Stack() 34 | } 35 | 36 | val zeroDerivativeID = DerivativeID(0) 37 | 38 | @ScalarRoot 39 | sealed class DifferentiableDouble { 40 | abstract val primal:DifferentiableDouble 41 | open fun value():Double = primal.value() 42 | abstract val derivativeID: DerivativeID 43 | fun zero():DifferentiableDouble = DDouble(0.0) 44 | } 45 | 46 | @BoxedPrimitive("value") 47 | class DDouble(val value:Double) : DifferentiableDouble() { 48 | override val derivativeID = zeroDerivativeID 49 | override val
primal: DifferentiableDouble = this 50 | override fun value(): Double = this.value 51 | } 52 | 53 | @ReverseDifferentiable("primal", "upstream", "backpropogate", "pushback", "derivativeID") 54 | abstract class ReverseNode(d: ReverseDerivativeID) : DifferentiableDouble() { 55 | var upstream:DifferentiableDouble = DDouble(0.0) 56 | abstract override val derivativeID:ReverseDerivativeID 57 | abstract fun backpropogate() 58 | open override val primal: DifferentiableDouble = DDouble(0.0) 59 | fun pushback(value:DifferentiableDouble) {} 60 | } 61 | 62 | @ForwardDifferentiable("tangent") 63 | class ForwardNode(d: DerivativeID, override val primal: DifferentiableDouble, val tangent:DifferentiableDouble) : DifferentiableDouble() { 64 | override val derivativeID: DerivativeID = d 65 | } 66 | 67 | @StackImpl 68 | class StackForCompiler { 69 | fun pop():T {TODO()} 70 | fun push(d:T){} 71 | fun top():T {TODO()} 72 | fun notEmpty():Boolean {TODO()} 73 | } 74 | 75 | @PrimalAndPullback 76 | fun primalAndPullback(operand:DifferentiableDouble, operator:(DifferentiableDouble) -> DifferentiableDouble):DifferentiableDouble {TODO()} 77 | 78 | @ToUnboxedFunction("kotlin.Double.plus") 79 | operator fun DifferentiableDouble.plus(other:DifferentiableDouble):DifferentiableDouble {TODO()} 80 | /* Member functions annotated with @ToUnboxedFunction: doSomething (one parameter) names the one-parameter doSomethingToDoubles, while referencesMismatchParameterCount (two parameters) names a function that takes none. */ 81 | class DoesNotMatter { 82 | fun doSomethingToDoubles(a:Double):Double {TODO()} 83 | @ToUnboxedFunction("demo.DoesNotMatter.doSomethingToDoubles") 84 | fun doSomething(a:DifferentiableDouble):DifferentiableDouble {TODO()} 85 | 86 | @ToUnboxedFunction("demo.DoesNotMatter.referencedMismatchParameterCount") 87 | fun referencesMismatchParameterCount(x:DifferentiableDouble, z:Double):DifferentiableDouble{TODO()} 88 | fun referencedMismatchParameterCount():Double {TODO()} 89 | } 90 | 91 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-integration-tests/src/test/testData/diagnostics/validApi.kt: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | // WITH_RUNTIME 2 | // !DIAGNOSTICS: -UNUSED_PARAMETER -DEBUG_INFO_MISSING_UNRESOLVED -DEBUG_INFO_LEAKING_THIS 3 | // SKIP_TXT 4 | // FIR_IDENTICAL 5 | package demo 6 | /* Declares the differentiable-API annotations plus annotated declarations for the scalar root, boxed primitive, reverse and forward nodes, stack and primalAndPullback; the file name suggests no plugin diagnostics are expected here (confirm). */ 7 | annotation class ReverseDifferentiable( 8 | val primalField: String, 9 | val upstreamField: String, 10 | val backpropogateMethod: String, 11 | val pushbackMethod: String, 12 | val derivativeID: String 13 | ) 14 | 15 | annotation class ScalarRoot 16 | annotation class PrimalAndPullback 17 | annotation class StackImpl 18 | annotation class BoxedPrimitive(val valueField: String) 19 | annotation class ToUnboxedFunction(val functionName: String) 20 | annotation class DTensorRoot 21 | annotation class ToReverse(val fqClass: String) 22 | annotation class ScalarNoop 23 | annotation class ForwardDifferentiable(val tangentProperty: String) 24 | 25 | @DTensorRoot 26 | class DiffTensor 27 | operator fun DiffTensor.plus(other: DiffTensor): DiffTensor { TODO() } 28 | 29 | open class DerivativeID(private val seq: Int) : Comparable { 30 | override fun compareTo(other: DerivativeID): Int { 31 | return this.seq.compareTo(other.seq) 32 | } 33 | } 34 | 35 | class ReverseDerivativeID(s: Int) : DerivativeID(s) { 36 | val backpropogateWorkList: java.util.Stack = java.util.Stack() 37 | } 38 | 39 | val zeroDerivativeID = DerivativeID(0) 40 | 41 | @ScalarRoot 42 | sealed class DifferentiableDouble { 43 | abstract val primal: DifferentiableDouble 44 | open fun value(): Double = primal.value() 45 | abstract val derivativeID: DerivativeID 46 | fun zero(): DifferentiableDouble = DDouble(0.0) 47 | } 48 | 49 | @BoxedPrimitive("value") 50 | class DDouble(val value: Double) : DifferentiableDouble() { 51 | override val derivativeID = zeroDerivativeID 52 | override val primal: DifferentiableDouble = this 53 | override fun value(): Double = this.value 54 | companion object { 55 | val ZERO = DDouble(0.0) 56 | } 57 | } 58 | 59 |
@ReverseDifferentiable("primal", "upstream", "backpropogate", "pushback", "derivativeID") 60 | abstract class ReverseNode(d: ReverseDerivativeID) : DifferentiableDouble() { 61 | var upstream: DifferentiableDouble = DDouble(0.0) 62 | abstract override val derivativeID: ReverseDerivativeID 63 | abstract fun backpropogate() 64 | open override val primal: DifferentiableDouble = DDouble(0.0) 65 | fun pushback(value: DifferentiableDouble) {} 66 | } 67 | 68 | @ForwardDifferentiable("tangent") 69 | class ForwardNode(d: DerivativeID, override val primal: DifferentiableDouble, val tangent: DifferentiableDouble) : DifferentiableDouble() { 70 | override val derivativeID: DerivativeID = d 71 | } 72 | 73 | @StackImpl 74 | class StackForCompiler { 75 | fun pop(): T { TODO() } 76 | fun push(d: T) {} 77 | fun top(): T { TODO() } 78 | fun notEmpty(): Boolean { TODO() } 79 | } 80 | 81 | operator fun DifferentiableDouble.plus(other: DifferentiableDouble): DifferentiableDouble { TODO() } 82 | 83 | @PrimalAndPullback 84 | fun primalAndPullback(operand: DifferentiableDouble, operator: (DifferentiableDouble) -> DifferentiableDouble): DifferentiableDouble { TODO() } 85 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-integration-tests/src/test/testData/ir/resources/adoptimize.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | reverseClass=demo.ReverseNode 9 | forwardScalarClass=demo.ForwardNode 10 | primalProperty=primal 11 | upstreamProperty=upstream 12 | tangentProperty=tangent 13 | backpropMethod=backpropogate 14 | pushbackMethod=pushback 15 | derivativeId=derivativeID 16 | scalarRoot=demo.DifferentiableDouble 17 | primalAndPullbackFunction=demo.primalAndPullback 18 | boxedPrimitive=demo.DDouble 19 | valueProperty=value 20 | primitiveType=Double 21 | stackImpl=demo.StackForCompiler 22 | scalarPlusFunction=demo.plus 23 | tensorPlusFunction=demo.plus 24 | toUnboxFunction=demo.ToUnboxedFunction 25 | toReverse=demo.ToReverse 26 | dTensor=demo.DiffTensor 27 | reverseOperations=demo.ReverseScalarOperations 28 | scalarZero=demo.DDouble.Companion.ZERO 29 | scalarOne=demo.DDouble.Companion.ONE 30 | scalarNoop=demo.ScalarNoop 31 | -------------------------------------------------------------------------------- /differentiable-api-preprocessor-publish/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 |  */
7 |
8 | import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar
9 | description = "Differentiable API preprocessor cli compiler plugin"
10 |
11 | plugins {
12 |     id("com.github.johnrengelman.shadow") version "6.1.0"
13 | }
14 |
15 | dependencies {
16 |     val ktVersion: String by System.getProperties()
17 |     compileOnly("org.jetbrains.kotlin:kotlin-stdlib:$ktVersion")
18 |     api("org.jetbrains.kotlin:kotlin-compiler-embeddable:$ktVersion")
19 |     // Bundled into the shadow jar below through `include(...)` rather than exposed as dependencies.
20 |     compileOnly(project(":differentiable-api-preprocessor-compiler-plugin"))
21 |     compileOnly(project(":plugin-generators-common"))
22 |     compileOnly(project(":adoptimize-common"))
23 | }
24 |
25 | // Fat jar: this module's main output plus the included compileOnly projects, with no classifier.
26 | // It is registered on `fatJarArtifact` and published as the compiler-plugin artifact below.
27 | val fatJarArtifact by configurations.creating
28 | val shadowFatJar: ShadowJar = tasks.getByName<ShadowJar>("shadowJar") {
29 |     val convention = project.convention.getPlugin<JavaPluginConvention>()
30 |     from(convention.sourceSets.main.get().output)
31 |     archiveClassifier.set("")
32 |     configurations = mutableListOf(project.configurations.compileOnly.get())
33 |     dependencies {
34 |         include(project(":differentiable-api-preprocessor-compiler-plugin"))
35 |         include(project(":plugin-generators-common"))
36 |         include(project(":adoptimize-common"))
37 |     }
38 | }
39 |
40 | val publishArtifact = artifacts.add(fatJarArtifact.name, shadowFatJar)
41 |
42 | publishing {
43 |     publications {
44 |         create<MavenPublication>("diffkt-diffPrep-publishing") {
45 |             // The artifact id comes from a system property (e.g. a systemProp.* entry or -D flag).
46 |             artifactId = System.getProperty("diffPrepCompilerPluginArtifactID")
47 |             artifact(publishArtifact)
48 |         }
49 |     }
50 | }
51 |
--------------------------------------------------------------------------------
/gradle.properties:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | systemProp.junit4Version = 4.12 9 | systemProp.junitPlatformLauncherVersion = 1.8.0-M1 10 | systemProp.junitJupiterVersion=5.7.1 11 | systemProp.trove4jVersion=1.0.20200330 12 | 13 | systemProp.group=facebook 14 | systemProp.version=0.0.1-SNAPSHOT 15 | 16 | 17 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/optimizer-plugins/0695dc024d4a2f6eaf0559026efcda2a66b5e810/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | distributionBase=GRADLE_USER_HOME 9 | distributionPath=wrapper/dists 10 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip 11 | zipStoreBase=GRADLE_USER_HOME 12 | zipStorePath=wrapper/dists 13 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 
6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 15 | @rem 16 | 17 | @if "%DEBUG%" == "" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem Standard Gradle wrapper launcher: runs org.gradle.wrapper.GradleWrapperMain from gradle\wrapper\gradle-wrapper.jar 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%" == "" set DIRNAME=. 29 | set APP_BASE_NAME=%~n0 30 | set APP_HOME=%DIRNAME% 31 | 32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter. 33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi 34 | 35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 37 | 38 | @rem Find java.exe 39 | if defined JAVA_HOME goto findJavaFromJavaHome 40 | 41 | set JAVA_EXE=java.exe 42 | %JAVA_EXE% -version >NUL 2>&1 43 | if "%ERRORLEVEL%" == "0" goto execute 44 | 45 | echo. 46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 47 | echo. 48 | echo Please set the JAVA_HOME variable in your environment to match the 49 | echo location of your Java installation. 50 | 51 | goto fail 52 | 53 | :findJavaFromJavaHome 54 | set JAVA_HOME=%JAVA_HOME:"=% 55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 56 | 57 | if exist "%JAVA_EXE%" goto execute 58 | 59 | echo. 60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 61 | echo.
62 | echo Please set the JAVA_HOME variable in your environment to match the 63 | echo location of your Java installation. 64 | 65 | goto fail 66 | 67 | :execute 68 | @rem Setup the command line 69 | 70 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 71 | 72 | 73 | @rem Execute Gradle 74 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* 75 | 76 | :end 77 | @rem End local scope for the variables with windows NT shell 78 | if "%ERRORLEVEL%"=="0" goto mainEnd 79 | 80 | :fail 81 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 82 | rem the _cmd.exe /c_ return code! 83 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 84 | exit /b 1 85 | 86 | :mainEnd 87 | if "%OS%"=="Windows_NT" endlocal 88 | 89 | :omega 90 | -------------------------------------------------------------------------------- /optimizer-plugins/adoptimize-integration-tests/testDependencies/org/diffApi/0.0.1/diffApi-0.0.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/optimizer-plugins/0695dc024d4a2f6eaf0559026efcda2a66b5e810/optimizer-plugins/adoptimize-integration-tests/testDependencies/org/diffApi/0.0.1/diffApi-0.0.1.jar -------------------------------------------------------------------------------- /optimizer-plugins/bmgoptimize-integration-tests/testDependencies/org/bmgApi/0.0.1/bmgApi-0.0.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/optimizer-plugins/0695dc024d4a2f6eaf0559026efcda2a66b5e810/optimizer-plugins/bmgoptimize-integration-tests/testDependencies/org/bmgApi/0.0.1/bmgApi-0.0.1.jar -------------------------------------------------------------------------------- /plugin-generators-common/build.gradle.kts:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | description = "Shared codegen utilities among plugins" 9 | 10 | dependencies { 11 | val ktVersion: String by System.getProperties() 12 | compileOnly("org.jetbrains.kotlin:kotlin-stdlib:$ktVersion") 13 | api("org.jetbrains.kotlin:kotlin-compiler-embeddable:$ktVersion") 14 | 15 | api("org.jetbrains.kotlin:kotlin-reflect:$ktVersion") 16 | testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.8.0-M1") 17 | testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:5.7.1") 18 | testImplementation("junit", "junit", "4.12") 19 | testImplementation("org.junit.jupiter:junit-jupiter-api:5.7.1") 20 | } 21 | 22 | tasks.withType { 23 | useJUnitPlatform() 24 | } 25 | -------------------------------------------------------------------------------- /plugin-generators-common/src/main/kotlin/pluginCommon/ErrorHandling.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package pluginCommon 9 | 10 | class PluginCodegenException(message: String) : Exception(message) 11 | -------------------------------------------------------------------------------- /plugin-generators-common/src/main/kotlin/pluginCommon/ScopeSubstitutionMap.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | package pluginCommon 9 | 10 | import org.jetbrains.kotlin.ir.declarations.IrValueDeclaration 11 | 12 | class ScopeSubstitutionMap { 13 | private val scopeMaps = java.util.ArrayDeque>() 14 | 15 | operator fun get(src: IrValueDeclaration): IrValueDeclaration? { 16 | for (scope in scopeMaps) { 17 | val maybeTarget = scope[src] 18 | if (maybeTarget != null) { 19 | return maybeTarget 20 | } 21 | } 22 | return null 23 | } 24 | 25 | operator fun set(src: IrValueDeclaration, target: IrValueDeclaration) { 26 | scopeMaps.first.put(src, target) 27 | } 28 | 29 | fun push() { 30 | scopeMaps.push(mutableMapOf()) 31 | } 32 | 33 | fun pop(): Map { 34 | return scopeMaps.pop() 35 | } 36 | 37 | fun top() = scopeMaps.first 38 | } 39 | -------------------------------------------------------------------------------- /plugin-generators-common/src/main/kotlin/pluginCommon/Substitutor.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package pluginCommon 9 | 10 | import org.jetbrains.kotlin.ir.declarations.IrValueDeclaration 11 | 12 | interface Substitutor { 13 | operator fun get(key: TKey): TValue? 14 | operator fun set(key: TKey, value: TValue) 15 | 16 | companion object { 17 | fun emptySubstitutor() = object : Substitutor { 18 | override fun get(key: TK): TV? = null 19 | override fun set(key: TK, value: TV) {} 20 | } 21 | } 22 | } 23 | 24 | class MapWrapper(val map: MutableMap) : Substitutor { 25 | override fun get(key: TKey): TValue? 
= map[key] 26 | override fun set(key: TKey, value: TValue) { map[key] = value } 27 | fun contains(key: TKey): Boolean = get(key) != null 28 | } 29 | 30 | class ScopeSubstitutionMapSubstitutor(val substitutionMap: ScopeSubstitutionMap) : Substitutor { 31 | override fun get(key: IrValueDeclaration): IrValueDeclaration? = substitutionMap[key] 32 | override fun set(key: IrValueDeclaration, value: IrValueDeclaration) { substitutionMap[key] = value } 33 | } 34 | -------------------------------------------------------------------------------- /plugin-generators-common/src/main/kotlin/pluginCommon/generators/DescriptorWrappers.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package pluginCommon.generators 9 | 10 | import org.jetbrains.kotlin.descriptors.* 11 | import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI 12 | import org.jetbrains.kotlin.ir.declarations.* 13 | import org.jetbrains.kotlin.ir.descriptors.IrBasedClassConstructorDescriptor 14 | import org.jetbrains.kotlin.ir.descriptors.IrBasedFieldDescriptor 15 | import org.jetbrains.kotlin.ir.descriptors.IrBasedPropertyDescriptor 16 | import org.jetbrains.kotlin.ir.descriptors.IrBasedSimpleFunctionDescriptor 17 | import org.jetbrains.kotlin.ir.symbols.* 18 | import org.jetbrains.kotlin.ir.util.IdSignature 19 | import org.jetbrains.kotlin.resolve.constants.ConstantValue 20 | 21 | internal open class LateInitIRSymbol(val descriptorCreator: (TIR) -> TD) : IrBindableSymbol { 22 | @ObsoleteDescriptorBasedAPI 23 | override val descriptor: TD 24 | get() = lateInitDescriptor ?: throw IllegalStateException("the descriptor of the symbol has not yet been bound") 25 | private var lateInitDescriptor: TD? = null 26 | private var _owner: TIR? 
= null 27 | 28 | @ObsoleteDescriptorBasedAPI 29 | override val hasDescriptor: Boolean 30 | get() = lateInitDescriptor != null 31 | override val isBound: Boolean 32 | get() = _owner != null 33 | override val owner: TIR 34 | get() = _owner ?: throw IllegalStateException("The symbol is unbound!") 35 | override val signature: IdSignature? = null 36 | 37 | override fun bind(owner: TIR) { 38 | _owner = owner 39 | lateInitDescriptor = descriptorCreator(owner) 40 | } 41 | 42 | override var privateSignature: IdSignature? = null 43 | } 44 | 45 | internal class LateInitFunctionSymbol : LateInitIRSymbol(::FunctionIrBasedDescriptorWrapper), IrSimpleFunctionSymbol 46 | internal class LateInitPropertySymbol : LateInitIRSymbol(::PropertyIrBasedDescriptorWrapper), IrPropertySymbol 47 | internal class LateInitFieldSymbol : LateInitIRSymbol(::FieldIrBasedDescriptorWrapper), IrFieldSymbol 48 | internal class LateInitConstructorSymbol : LateInitIRSymbol(::ConstructorIrBasedDescriptorWrapper), IrConstructorSymbol 49 | 50 | internal class ConstructorIrBasedDescriptorWrapper(c: IrConstructor) : IrBasedClassConstructorDescriptor(c) { 51 | override fun hasStableParameterNames(): Boolean { 52 | return true 53 | } 54 | } 55 | 56 | internal class FunctionIrBasedDescriptorWrapper(f: IrSimpleFunction) : IrBasedSimpleFunctionDescriptor(f) { 57 | override fun hasStableParameterNames(): Boolean { 58 | return true 59 | } 60 | } 61 | 62 | internal class PropertyIrBasedDescriptorWrapper(p: IrProperty) : IrBasedPropertyDescriptor(p) { 63 | private var backingField: FieldDescriptor? = null 64 | internal fun initialize() { 65 | backingField = owner.backingField?.descriptor as? FieldDescriptor 66 | } 67 | override fun getCompileTimeInitializer(): ConstantValue<*>? = null 68 | override fun getBackingField(): FieldDescriptor? { 69 | return backingField 70 | } 71 | 72 | override fun getDelegateField(): FieldDescriptor? 
{ 73 | return null 74 | } 75 | 76 | override fun getContextReceiverParameters(): List = emptyList() 77 | } 78 | 79 | internal class FieldIrBasedDescriptorWrapper(f: IrField) : IrBasedFieldDescriptor(f), FieldDescriptor { 80 | override fun getCompileTimeInitializer(): ConstantValue<*>? { 81 | return null 82 | } 83 | 84 | override val correspondingProperty: PropertyDescriptor get() = owner.descriptor 85 | } 86 | -------------------------------------------------------------------------------- /plugin-generators-common/src/main/kotlin/pluginCommon/generators/ParameterInfo.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package pluginCommon.generators 9 | 10 | import org.jetbrains.kotlin.ir.types.IrType 11 | import org.jetbrains.kotlin.name.Name 12 | 13 | class ParameterInfo(val name: Name, val tpe: IrType) 14 | -------------------------------------------------------------------------------- /plugin-generators-common/src/main/kotlin/pluginCommon/lowerings/ElseBranchLowering.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | package pluginCommon.lowerings 9 | 10 | import org.jetbrains.kotlin.ir.IrBuiltIns 11 | import org.jetbrains.kotlin.ir.IrElement 12 | import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET 13 | import org.jetbrains.kotlin.ir.declarations.IrFunction 14 | import org.jetbrains.kotlin.ir.expressions.IrElseBranch 15 | import org.jetbrains.kotlin.ir.expressions.IrWhen 16 | import org.jetbrains.kotlin.ir.expressions.impl.IrBlockImpl 17 | import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl 18 | import org.jetbrains.kotlin.ir.expressions.impl.IrElseBranchImpl 19 | import org.jetbrains.kotlin.ir.types.isUnit 20 | import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid 21 | import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid 22 | import org.jetbrains.kotlin.ir.visitors.acceptVoid 23 | 24 | class ElseBranchLowering(val builtIns: IrBuiltIns) : FunctionLowering { 25 | override fun lower(function: IrFunction): IrFunction { 26 | function.acceptVoid(object : IrElementVisitorVoid { 27 | override fun visitElement(element: IrElement) { 28 | element.acceptChildrenVoid(this) 29 | } 30 | 31 | override fun visitWhen(expression: IrWhen) { 32 | if (expression.branches.filterIsInstance().isEmpty() && expression.type.isUnit()) { 33 | expression.branches.add( 34 | IrElseBranchImpl( 35 | UNDEFINED_OFFSET, UNDEFINED_OFFSET, 36 | IrConstImpl.boolean( 37 | UNDEFINED_OFFSET, UNDEFINED_OFFSET, builtIns.booleanType, true 38 | ), 39 | IrBlockImpl( 40 | UNDEFINED_OFFSET, 41 | UNDEFINED_OFFSET, builtIns.unitType, null, emptyList() 42 | ) 43 | ) 44 | ) 45 | } 46 | } 47 | }) 48 | return function 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /plugin-generators-common/src/main/kotlin/pluginCommon/lowerings/FunctionLowering.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package pluginCommon.lowerings 9 | 10 | import org.jetbrains.kotlin.ir.declarations.IrAnonymousInitializer 11 | import org.jetbrains.kotlin.ir.declarations.IrFunction 12 | 13 | interface FunctionLowering { 14 | fun lower(declaration: IrFunction): IrFunction 15 | } 16 | 17 | interface AnonymousInitializerLowering { 18 | fun lower(declaration: IrAnonymousInitializer): IrAnonymousInitializer 19 | } 20 | 21 | interface DeclarationWithBodyLowering : FunctionLowering, AnonymousInitializerLowering 22 | -------------------------------------------------------------------------------- /plugin-generators-common/src/main/kotlin/pluginCommon/lowerings/UnitCastTransformer.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | package pluginCommon.lowerings 9 | 10 | import org.jetbrains.kotlin.ir.IrBuiltIns 11 | import org.jetbrains.kotlin.ir.IrElement 12 | import org.jetbrains.kotlin.ir.IrStatement 13 | import org.jetbrains.kotlin.ir.declarations.IrFunction 14 | import org.jetbrains.kotlin.ir.expressions.* 15 | import org.jetbrains.kotlin.ir.expressions.impl.IrSetValueImpl 16 | import org.jetbrains.kotlin.ir.types.isUnit 17 | import pluginCommon.PluginCodegenException 18 | 19 | class UnitCastTransformer(val builtIns: IrBuiltIns) : FunctionLowering { 20 | override fun lower(function: IrFunction): IrFunction { 21 | val target = function.body ?: throw PluginCodegenException("No Body to transform") 22 | val unitCastExpressionMapper = object : ExpressionMapper { 23 | override fun mapTypeOperatorCall( 24 | parent: IrElement, 25 | expression: IrTypeOperatorCall 26 | ): ExpressionMapper.ImageStatement? { 27 | return when { 28 | expression.type.isUnit() -> { 29 | val statementsToAddToParent = mutableListOf() 30 | when (val argument = expression.argument) { 31 | is IrBlock -> { 32 | val statements = argument.statements 33 | statements.removeLast() 34 | 35 | for (i in 0 until statements.size) { 36 | statementsToAddToParent.add(statements[i]) 37 | } 38 | 39 | val newExpression = statements[statements.size - 1] 40 | if (newExpression is IrSetValueImpl && newExpression.type.isUnit()) { 41 | return ExpressionMapper.ImageStatement(newExpression, statementsToAddToParent) 42 | } 43 | 44 | return null 45 | } 46 | else -> null 47 | } 48 | } 49 | else -> null 50 | } 51 | } 52 | } 53 | function.body = target.accept(ShallowTransformer(unitCastExpressionMapper), null) as IrBody 54 | return function 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /plugin-generators-common/src/test/kotlin/pluginCommon/DependencyContainerTests.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta 
Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | package pluginCommon 9 | 10 | import org.junit.jupiter.api.Test 11 | import kotlin.reflect.full.createType 12 | 13 | // if I add Bar and Foo is dependent on Bat that is dependent on Bar, then 14 | // the container should be able to instantiate Foo without me adding Bar. 15 | // However, the container should not store the instance of Bar. 16 | class DependencyContainerSuccessTest { 17 | interface Bar 18 | class Bat(bar: Bar) 19 | class BarImpl : Bar 20 | class Foo(bat: Bat) 21 | @Test 22 | fun getSuccess() { 23 | val container = DependencyContainer() 24 | container.put(BarImpl()) 25 | 26 | val foo = container.get() 27 | 28 | assert(container.singletonServices.contains(Bat::class.createType()) == false) 29 | assert(container.singletonServices.contains(Foo::class.createType()) == false) 30 | } 31 | } 32 | 33 | class DependencyContainerGetFailureTest { 34 | interface Bar 35 | interface Interface1 36 | 37 | class Bat(bar: Bar, somethingElse: Interface1) 38 | class Bat2(string: String, int: Int) 39 | class BarImpl : Bar 40 | class Target1(bat: Bat) 41 | class Target2(bar2: Bat2) 42 | 43 | @Test 44 | fun getFailureOnInterface() { 45 | val container = DependencyContainer() 46 | container.put(BarImpl()) 47 | 48 | try { 49 | val target1 = container.get() 50 | assert(false, { "An exception should have been thrown because Interface1 cannot be instantiated." }) 51 | } catch (e: PluginCodegenException) { 52 | println(e.message) 53 | } 54 | } 55 | 56 | @Test 57 | fun getFailureOnPrimitive() { 58 | val container = DependencyContainer() 59 | try { 60 | val target2 = container.get() 61 | assert(false, { "An exception should have been thrown because Interface1 cannot be instantiated." 
}) 62 | } catch (e: PluginCodegenException) { 63 | } 64 | } 65 | } 66 | 67 | class DependencyContainerRecursionFailureTest { 68 | interface Bar 69 | class BarImpl(b: BarImpl) : Bar 70 | @Test 71 | fun recursionFailureOnRecursion() { 72 | val container = DependencyContainer() 73 | try { 74 | val instance = container.get() 75 | assert(false, { "An exception should have been thrown because Interface1 cannot be instantiated." }) 76 | } catch (e: PluginCodegenException) { 77 | assert(e.message?.contains("Recursion!") == true) 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /producer-consumer/README.md: -------------------------------------------------------------------------------- 1 | To debug this project locally, run the following commands: 2 | ``` 3 | # in project root 4 | ./gradlew publishToMavenLocal 5 | cd producer-consumer 6 | ./gradlew :producer:publishToMavenLocal 7 | ./gradlew :consumer:run 8 | ``` 9 | 10 | If you would like to debug the compilation process (this was useful for me to debug serialization/deserialization), follow these steps: 11 | 1. Download the Kotlin compiler. 12 | 2. Publish the Kotlin compiler locally. 13 | 3. Update the version of the compiler used in this repo by editing the gradle.properties file. 14 | 4. Remote debug from this repo by using this command: ```./gradlew [task] --no-daemon -Dorg.gradle.debug=true -Dkotlin.compiler.execution.strategy="in-process" -Dkotlin.daemon.jvm.options="-Xdebug,-Xrunjdwp:transport=dt_socket,address=5005,server=y,suspend=n"```, where task refers to either consumer:run or producer:compileKotlin. 15 | 5. Set up a Remote JVM Debug configuration in the compiler project and run it. -------------------------------------------------------------------------------- /producer-consumer/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) Meta Platforms, Inc. and affiliates.
4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | 10 | import org.jetbrains.kotlin.gradle.tasks.KotlinCompile 11 | plugins { 12 | val ktVersion: String by System.getProperties() 13 | kotlin("jvm") version ktVersion 14 | } 15 | 16 | allprojects { 17 | group = "test" 18 | apply(plugin = "java") 19 | apply(plugin = "kotlin") 20 | 21 | dependencies { 22 | implementation(kotlin("stdlib")) 23 | } 24 | 25 | java { 26 | sourceCompatibility = JavaVersion.VERSION_1_8 27 | targetCompatibility = JavaVersion.VERSION_1_8 28 | } 29 | 30 | tasks.withType() { 31 | kotlinOptions.jvmTarget = "1.8" 32 | kotlinOptions.freeCompilerArgs += "-Xserialize-ir=inline" 33 | kotlinOptions.freeCompilerArgs += "-XXLanguage:+ProperCheckAnnotationsTargetInTypeUsePositions" 34 | } 35 | 36 | repositories { 37 | mavenLocal() 38 | mavenCentral() 39 | maven { 40 | url = uri("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/bootstrap") 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /producer-consumer/consumer/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) Meta Platforms, Inc. and affiliates. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | * 8 | */ 9 | 10 | plugins { 11 | `maven-publish` 12 | kotlin("jvm") 13 | val pluginVersions: String by System.getProperties() 14 | id("meta-diffkt-adoptimize") version pluginVersions 15 | application 16 | } 17 | 18 | application { 19 | mainClass.set("consumer.MainKt") 20 | } 21 | 22 | adOptimize { 23 | this.diffApi("test", "producer", "1.0-SNAPSHOT") 24 | this.optimizeAnnotation("consumer.Optimize") 25 | } 26 | 27 | dependencies { 28 | implementation(kotlin("stdlib-jdk8")) 29 | implementation(kotlin("reflect")) 30 | implementation("test", "producer", "1.0-SNAPSHOT") 31 | } 32 | -------------------------------------------------------------------------------- /producer-consumer/consumer/src/main/kotlin/main.kt: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) Meta Platforms, Inc. and affiliates. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | * 8 | */ 9 | 10 | package consumer 11 | 12 | import demo.* 13 | 14 | annotation class Optimize 15 | 16 | @Optimize 17 | fun squared(a: DifferentiableDouble): DifferentiableDouble { 18 | return a * a 19 | } 20 | 21 | fun squaredVanilla(a: DifferentiableDouble): DifferentiableDouble { 22 | return a * a 23 | } 24 | 25 | fun main() { 26 | val x = DDouble(0.12) 27 | 28 | // reverse 29 | run { 30 | val optimized = primalAndReverseDerivative(x, ::squared) 31 | val vanilla = primalAndReverseDerivative(x, ::squaredVanilla) 32 | if (optimized.first.value() != vanilla.first.value()) { 33 | throw IllegalStateException("PRIMAL FAIL: expected ${vanilla.first.value()} but got ${optimized.first.value()}") 34 | } 35 | if (optimized.second.value() != vanilla.second.value()) { 36 | throw IllegalStateException("DERIVATIVE FAIL: expected ${vanilla.second.value()} but got ${optimized.second.value()}") 37 | } 38 | } 39 | 40 | // forward 41 | run { 42 | val optimized = primalAndForwardDerivative(x, ::squared) 43 | val vanilla = primalAndForwardDerivative(x, ::squaredVanilla) 44 | if (optimized.first.value() != vanilla.first.value()) { 45 | throw IllegalStateException("PRIMAL FAIL: expected ${vanilla.first.value()} but got ${optimized.first.value()}") 46 | } 47 | if (optimized.second.value() != vanilla.second.value()) { 48 | throw IllegalStateException("DERIVATIVE FAIL: expected ${vanilla.second.value()} but got ${optimized.second.value()}") 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /producer-consumer/gradle.properties: -------------------------------------------------------------------------------- 1 | # 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # 8 | # 9 | 10 | systemProp.ktVersion=1.7.0-dev-444 11 | systemProp.pluginVersions=0.0.1-SNAPSHOT -------------------------------------------------------------------------------- /producer-consumer/gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/optimizer-plugins/0695dc024d4a2f6eaf0559026efcda2a66b5e810/producer-consumer/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /producer-consumer/gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | # 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | # 9 | 10 | distributionBase=GRADLE_USER_HOME 11 | distributionPath=wrapper/dists 12 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.1-bin.zip 13 | zipStoreBase=GRADLE_USER_HOME 14 | zipStorePath=wrapper/dists 15 | -------------------------------------------------------------------------------- /producer-consumer/gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 15 | @rem 16 | 17 | @if "%DEBUG%" == "" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%" == "" set DIRNAME=. 29 | set APP_BASE_NAME=%~n0 30 | set APP_HOME=%DIRNAME% 31 | 32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter. 33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi 34 | 35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 37 | 38 | @rem Find java.exe 39 | if defined JAVA_HOME goto findJavaFromJavaHome 40 | 41 | set JAVA_EXE=java.exe 42 | %JAVA_EXE% -version >NUL 2>&1 43 | if "%ERRORLEVEL%" == "0" goto execute 44 | 45 | echo. 46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 47 | echo. 48 | echo Please set the JAVA_HOME variable in your environment to match the 49 | echo location of your Java installation. 50 | 51 | goto fail 52 | 53 | :findJavaFromJavaHome 54 | set JAVA_HOME=%JAVA_HOME:"=% 55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 56 | 57 | if exist "%JAVA_EXE%" goto execute 58 | 59 | echo. 60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 61 | echo. 62 | echo Please set the JAVA_HOME variable in your environment to match the 63 | echo location of your Java installation. 
64 | 65 | goto fail 66 | 67 | :execute 68 | @rem Setup the command line 69 | 70 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 71 | 72 | 73 | @rem Execute Gradle 74 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* 75 | 76 | :end 77 | @rem End local scope for the variables with windows NT shell 78 | if "%ERRORLEVEL%"=="0" goto mainEnd 79 | 80 | :fail 81 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 82 | rem the _cmd.exe /c_ return code! 83 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 84 | exit /b 1 85 | 86 | :mainEnd 87 | if "%OS%"=="Windows_NT" endlocal 88 | 89 | :omega 90 | -------------------------------------------------------------------------------- /producer-consumer/producer/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) Meta Platforms, Inc. and affiliates. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | * 8 | */ 9 | 10 | plugins { 11 | `maven-publish` 12 | val pluginVersions: String by System.getProperties() 13 | id("meta-diffkt-differentiable-api-preprocessor") version pluginVersions 14 | } 15 | 16 | group = "org.example" 17 | version = "1.0-SNAPSHOT" 18 | 19 | differentiableApiPreprocessor { 20 | this.stackImplAnnotation("demo.StackImpl") 21 | this.boxedPrimitive("demo.BoxedPrimitive") 22 | this.scalarRoot("demo.ScalarRoot") 23 | this.primalAndPullbackAnnotation("demo.PrimalAndPullback") 24 | this.reverseAnnotation("demo.ReverseDifferentiable") 25 | this.unboxedFunction("demo.ToUnboxedFunction") 26 | val userDir = System.getProperty("user.dir") 27 | val pathToResources = "$userDir/src/main/resources" 28 | this.resourcesPath(pathToResources) 29 | this.toReverseAnnotation("demo.ToReverse") 30 | this.dTensorAnnotation("demo.DTensorRoot") 31 | this.reverseScalarOperationsAnnotation("demo.ReverseOperations") 32 | this.scalarNoop("demo.ScalarNoop") 33 | this.forwardDifferentiable("demo.ForwardDifferentiable") 34 | } 35 | 36 | tasks { 37 | withType { 38 | kotlinOptions.jvmTarget = "1.8" 39 | } 40 | 41 | withType> { 42 | kotlinOptions { 43 | freeCompilerArgs = freeCompilerArgs + "-Xserialize-ir=inline" + "-opt-in=org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI" 44 | } 45 | } 46 | } 47 | 48 | publishing { 49 | publications { 50 | create("maven") { 51 | groupId = "test" 52 | artifactId = "producer" 53 | version = project.version.toString() 54 | from(components["java"]) 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /producer-consumer/producer/src/main/resources/adoptimize.properties: -------------------------------------------------------------------------------- 1 | reverseClass=demo.ReverseNode 2 | forwardScalarClass=demo.ForwardNode 3 | primalProperty=primal 4 | upstreamProperty=upstream 5 | tangentProperty=tangent 6 | backpropMethod=backpropagate 7 | pushbackMethod=pushback 8 | 
derivativeId=derivativeID 9 | scalarRoot=demo.DifferentiableDouble 10 | primalAndPullbackFunction=demo.primalAndPullback 11 | boxedPrimitive=demo.DDouble 12 | valueProperty=value 13 | primitiveType=Double 14 | stackImpl=demo.StackForCompiler 15 | scalarPlusFunction=demo.plus 16 | tensorPlusFunction=demo.plus 17 | toUnboxFunction=demo.ToUnboxedFunction 18 | toReverse=demo.ToReverse 19 | dTensor=demo.DiffTensor 20 | reverseOperations=demo.ReverseScalarOperations 21 | scalarZero=demo.DDouble.Companion.ZERO 22 | scalarOne=demo.DDouble.Companion.ONE 23 | scalarNoop=demo.ScalarNoop 24 | -------------------------------------------------------------------------------- /producer-consumer/settings.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) Meta Platforms, Inc. and affiliates. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | 10 | rootProject.name = "producer-consumer" 11 | 12 | pluginManagement { 13 | repositories { 14 | gradlePluginPortal() 15 | mavenCentral() 16 | mavenLocal() 17 | maven { 18 | url = uri("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/bootstrap") 19 | } 20 | } 21 | } 22 | 23 | include(":consumer") 24 | include(":producer") 25 | -------------------------------------------------------------------------------- /producer-consumer/src/main/resources/adoptimize.properties: -------------------------------------------------------------------------------- 1 | # 2 | # 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # 8 | # 9 | 10 | reverseClass=demo.ReverseNode 11 | forwardScalarClass=demo.ForwardNode 12 | primalProperty=primal 13 | upstreamProperty=upstream 14 | tangentProperty=tangent 15 | backpropMethod=backpropagate 16 | pushbackMethod=pushback 17 | derivativeId=derivativeID 18 | scalarRoot=demo.DifferentiableDouble 19 | primalAndPullbackFunction=demo.primalAndPullback 20 | boxedPrimitive=demo.DDouble 21 | valueProperty=value 22 | primitiveType=Double 23 | stackImpl=demo.StackForCompiler 24 | scalarPlusFunction=demo.plus 25 | tensorPlusFunction=demo.plus 26 | toUnboxFunction=demo.ToUnboxedFunction 27 | toReverse=demo.ToReverse 28 | dTensor=demo.DiffTensor 29 | reverseOperations=demo.ReverseScalarOperations 30 | scalarZero=demo.DDouble.Companion.ZERO 31 | scalarOne=demo.DDouble.Companion.ONE 32 | scalarNoop=demo.ScalarNoop 33 | -------------------------------------------------------------------------------- /settings.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | rootProject.name = "optimizer-plugins" 9 | 10 | include(":adoptimize-cli-compiler-plugin") 11 | include(":adoptimize-gradle-plugin") 12 | include(":adoptimize-integration-tests") 13 | include(":adoptimize-publish") 14 | include(":adoptimize-common") 15 | 16 | include(":differentiable-api-preprocessor-compiler-plugin") 17 | include(":differentiable-api-preprocessor-gradle-plugin") 18 | include(":differentiable-api-preprocessor-integration-tests") 19 | include(":differentiable-api-preprocessor-publish") 20 | 21 | include(":config") 22 | include("plugin-generators-common") 23 | 24 | pluginManagement { 25 | repositories { 26 | gradlePluginPortal() 27 | mavenCentral() 28 | maven { 29 | url = uri("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/bootstrap") 30 | } 31 | } 32 | } 33 | --------------------------------------------------------------------------------