├── .gitignore ├── docs ├── .gitignore ├── src │ ├── functions.md │ ├── .vitepress │ │ ├── theme │ │ │ ├── index.ts │ │ │ └── style.css │ │ └── config.mts │ ├── index.md │ ├── examples.md │ └── atoms.md ├── package.json ├── make.jl └── Project.toml ├── .JuliaFormatter.toml ├── .github ├── workflows │ ├── SpellCheck.yml │ ├── CompatHelper.yml │ ├── TagBot.yml │ ├── Downgrade.yml │ ├── CI.yml │ ├── FormatCheck.yml │ └── Documenter.yml └── dependabot.yml ├── src ├── canon.jl ├── lianalg.jl ├── SymbolicAnalysis.jl ├── gdcp │ ├── lorentz.jl │ ├── gdcp_rules.jl │ └── spd.jl ├── rules.jl └── atoms.jl ├── test ├── runtests.jl ├── Project.toml ├── test.jl ├── lorentz.jl └── dgp.jl ├── LICENSE ├── README.md ├── Project.toml └── .typos.toml /.gitignore: -------------------------------------------------------------------------------- 1 | Manifest.toml 2 | .DS_Store -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | node_modules/ 3 | package-lock.json 4 | Manifest.toml -------------------------------------------------------------------------------- /.JuliaFormatter.toml: -------------------------------------------------------------------------------- 1 | style = "sciml" 2 | format_markdown = true 3 | format_docstrings = true 4 | -------------------------------------------------------------------------------- /docs/src/functions.md: -------------------------------------------------------------------------------- 1 | # Special functions 2 | 3 | Since some atoms are not available in the base language or in other packages, 4 | we have implemented them here.
5 | 6 | ```@autodocs 7 | Modules=[SymbolicAnalysis] 8 | Pages=["atoms.jl", "spd.jl"] 9 | ``` 10 | -------------------------------------------------------------------------------- /.github/workflows/SpellCheck.yml: -------------------------------------------------------------------------------- 1 | name: Spell Check 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | typos-check: 7 | name: Spell Check with Typos 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout Actions Repository 11 | uses: actions/checkout@v6 12 | - name: Check spelling 13 | uses: crate-ci/typos@v1.31.1 -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" # Location of package manifests 6 | schedule: 7 | interval: "weekly" 8 | ignore: 9 | - dependency-name: "crate-ci/typos" 10 | update-types: ["version-update:semver-patch", "version-update:semver-minor"] 11 | -------------------------------------------------------------------------------- /src/canon.jl: -------------------------------------------------------------------------------- 1 | """ 2 | canonize(ex) 3 | 4 | Rewrite known matrix sub-expressions of `ex` into the package's named atoms: 5 | `(adjoint(x) * (Y * x))[1]` becomes `quad_form(x, Y)`, and the indexed product 6 | `((adjoint(B) * X) * B)[Base.OneTo(size(B, 2)), Base.OneTo(size(B, 1))]` becomes 7 | `conjugation(X, B)`. The rule chain is applied bottom-up (`Postwalk`) and then 8 | top-down (`Prewalk`). 9 | 10 | Canonicalization is best-effort: if either pass throws, the error is logged at 11 | `Debug` level and the expression as rewritten so far is returned. 12 | """ 13 | function canonize(ex) 14 | rs = [@rule (adjoint(~x) * (~Y * ~x))[1] => quad_form(~x, ~Y) 15 | @rule ((adjoint(~B) * ~X) * ~B)[Base.OneTo(size(~B, 2)), Base.OneTo(size( 16 | ~B, 1))] => conjugation(~X, ~B)] 17 | try 18 | rc = SymbolicUtils.Chain(rs) 19 | ex = SymbolicUtils.Postwalk(rc)(ex) 20 | ex = SymbolicUtils.Prewalk(rc)(ex) 21 | return ex 22 | catch err 23 | # A failed rewrite must not abort the analysis, but the error should not 24 | # vanish silently either: enable it with `JULIA_DEBUG=SymbolicAnalysis`. 25 | @debug "canonize: rewriting failed; returning the expression as rewritten so far" exception=(err, catch_backtrace()) 26 | return ex 27 | end 28 | end 29 | -------------------------------------------------------------------------------- /docs/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "scripts": { 3 | "docs:dev": "vitepress dev build/.documenter", 4 |
"docs:build": "vitepress build build/.documenter", 5 | "docs:preview": "vitepress preview build/.documenter" 6 | }, 7 | "dependencies": { 8 | "@shikijs/transformers": "^1.1.7", 9 | "markdown-it": "^14.1.0", 10 | "markdown-it-footnote": "^4.0.0", 11 | "markdown-it-mathjax3": "^4.3.2", 12 | "vitepress-plugin-tabs": "^0.5.0" 13 | }, 14 | "devDependencies": { 15 | "vitepress": "^1.1.4" 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /.github/workflows/CompatHelper.yml: -------------------------------------------------------------------------------- 1 | name: CompatHelper 2 | on: 3 | schedule: 4 | - cron: 0 0 * * * 5 | workflow_dispatch: 6 | jobs: 7 | CompatHelper: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Pkg.add("CompatHelper") 11 | run: julia -e 'using Pkg; Pkg.add("CompatHelper")' 12 | - name: CompatHelper.main() 13 | env: 14 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 15 | COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} 16 | run: julia -e 'using CompatHelper; CompatHelper.main()' 17 | -------------------------------------------------------------------------------- /test/runtests.jl: -------------------------------------------------------------------------------- 1 | using SymbolicAnalysis: 2 | propagate_curvature, 3 | propagate_sign, 4 | propagate_gcurvature, 5 | getcurvature, 6 | getsign, 7 | getgcurvature 8 | 9 | using SafeTestsets, Test 10 | 11 | @testset "DCP" begin 12 | include("test.jl") 13 | end 14 | 15 | @testset "DGCP - SPD Manifold" begin 16 | include("dgp.jl") 17 | end 18 | 19 | @testset "DGCP - Lorentz Manifold" begin 20 | include("lorentz.jl") 21 | end 22 | -------------------------------------------------------------------------------- /docs/src/.vitepress/theme/index.ts: -------------------------------------------------------------------------------- 1 | // .vitepress/theme/index.ts 2 | import { h } from 'vue' 3 | import type { Theme } from 'vitepress' 4 | import DefaultTheme from
'vitepress/theme' 5 | 6 | import { enhanceAppWithTabs } from 'vitepress-plugin-tabs/client' 7 | import './style.css' 8 | 9 | export default { 10 | extends: DefaultTheme, 11 | Layout() { 12 | return h(DefaultTheme.Layout, null, { 13 | // https://vitepress.dev/guide/extending-default-theme#layout-slots 14 | }) 15 | }, 16 | enhanceApp({ app, router, siteData }) { 17 | enhanceAppWithTabs(app) 18 | } 19 | } satisfies Theme -------------------------------------------------------------------------------- /test/Project.toml: -------------------------------------------------------------------------------- 1 | [deps] 2 | ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" 3 | LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" 4 | Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" 5 | Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5" 6 | Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" 7 | OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" 8 | OptimizationManopt = "e57b7fff-7ee7-4550-b4f0-90e9476e9fb6" 9 | PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" 10 | Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" 11 | SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" 12 | Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" 13 | Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" 14 | Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -------------------------------------------------------------------------------- /docs/make.jl: -------------------------------------------------------------------------------- 1 | using Documenter, DocumenterVitepress 2 | 3 | using SymbolicAnalysis 4 | 5 | makedocs(; 6 | modules = [SymbolicAnalysis], 7 | authors = "Vaibhav Dixit, Shashi Gowda", 8 | repo = "https://github.com/Vaibhavdixit02/SymbolicAnalysis.jl", 9 | sitename = "SymbolicAnalysis.jl", 10 | format = DocumenterVitepress.MarkdownVitepress( 11 | repo = "https://github.com/Vaibhavdixit02/SymbolicAnalysis.jl", 12 | devurl = "dev" 13 | ), 14 | pages = [ 15 | "Home" => "index.md", 16 |
"Examples" => "examples.md", 17 | "Atoms" => "atoms.md", 18 | "Special Functions" => "functions.md" 19 | ], 20 | warnonly = true 21 | ) 22 | 23 | deploydocs(; repo = "github.com/Vaibhavdixit02/SymbolicAnalysis.jl", push_preview = true) 24 | -------------------------------------------------------------------------------- /.github/workflows/TagBot.yml: -------------------------------------------------------------------------------- 1 | name: TagBot 2 | on: 3 | issue_comment: 4 | types: 5 | - created 6 | workflow_dispatch: 7 | inputs: 8 | lookback: 9 | default: 3 10 | permissions: 11 | actions: read 12 | checks: read 13 | contents: write 14 | deployments: read 15 | issues: read 16 | discussions: read 17 | packages: read 18 | pages: read 19 | pull-requests: read 20 | repository-projects: read 21 | security-events: read 22 | statuses: read 23 | jobs: 24 | TagBot: 25 | if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' 26 | runs-on: ubuntu-latest 27 | steps: 28 | - uses: JuliaRegistries/TagBot@v1 29 | with: 30 | token: ${{ secrets.GITHUB_TOKEN }} 31 | ssh: ${{ secrets.DOCUMENTER_KEY }} 32 | -------------------------------------------------------------------------------- /src/lianalg.jl: -------------------------------------------------------------------------------- 1 | # NOTE(review): this file is not `include`d from src/SymbolicAnalysis.jl; confirm whether it is loaded elsewhere or is dead code. 2 | using LinearAlgebra 3 | using Symbolics: Num, simplify 4 | 5 | """ 6 | LinearAlgebra.ishermitian(A::AbstractMatrix{Num}; kwargs...) 7 | 8 | Symbolic Hermitian test. Returns `false` when the axes of `A` differ; otherwise 9 | returns `true` only if, for every `j >= i`, `simplify(A[i, j] - adjoint(A[j, i]))` 10 | is `isapprox` to `0.0`. Keyword arguments are forwarded to `isapprox`. 11 | """ 12 | function LinearAlgebra.ishermitian(A::AbstractMatrix{Num}; kwargs...) 13 | indsm, indsn = axes(A) 14 | # A matrix with mismatched axes cannot be Hermitian. 15 | if indsm != indsn 16 | return false 17 | end 18 | # Visiting only j >= i checks every (i, j)/(j, i) pair exactly once. 19 | for i in indsn, j in i:last(indsn) 20 | d = simplify(A[i, j] - adjoint(A[j, i])) 21 | if !isapprox(d, 0.0; kwargs...) 22 | return false 23 | end 24 | end 25 | return true 26 | end 27 | 28 | ## Numbers 29 | """ 30 | LinearAlgebra._chol!(x::Num, uplo) 31 | 32 | Scalar Cholesky step for a symbolic number (`uplo` is ignored). Returns 33 | `(rval, info)`, where `rval` is `sqrt(abs(real(x)))` converted to the promoted type 34 | and `info` is `0` when `real(x) - abs(x)` simplifies to (approximately) zero, 35 | i.e. `x` is real and nonnegative, and `1` otherwise. 36 | """ 37 | function LinearAlgebra._chol!(x::Num, uplo) 38 | rx = real(x) 39 | rxr = sqrt(abs(rx)) 40 | rval = convert(promote_type(typeof(x), typeof(rxr)), rxr) 41 | # Zero exactly when `x` equals its own nonnegative real part. 42 | d = rx - abs(x) |> simplify 43 | # `BlasInt` is not exported by LinearAlgebra, so it must be qualified here. 44 | info = isapprox(d, 0.0) ? 0 : 1 45 | return (rval, convert(LinearAlgebra.BlasInt, info)) 46 | end 47 |
-------------------------------------------------------------------------------- /.github/workflows/Downgrade.yml: -------------------------------------------------------------------------------- 1 | name: Downgrade 2 | on: 3 | pull_request: 4 | branches: 5 | - master 6 | paths-ignore: 7 | - 'docs/**' 8 | push: 9 | branches: 10 | - master 11 | paths-ignore: 12 | - 'docs/**' 13 | jobs: 14 | test: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | downgrade_mode: ['alldeps'] 19 | julia-version: ['1.10'] 20 | steps: 21 | - uses: actions/checkout@v6 22 | - uses: julia-actions/setup-julia@v2 23 | with: 24 | version: ${{ matrix.julia-version }} 25 | - uses: julia-actions/julia-downgrade-compat@v2 26 | with: 27 | skip: Pkg,TOML 28 | - uses: julia-actions/julia-buildpkg@v1 29 | - uses: julia-actions/julia-runtest@v1 30 | with: 31 | ALLOW_RERESOLVE: false -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Vaibhav Dixit and contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SymbolicAnalysis.jl 2 | 3 | 4 | 5 | Symbolics.jl-based function property propagation for optimization 6 | 7 | SymbolicAnalysis is a package for implementing the Disciplined Programming approach to optimization. 8 | As demonstrated by the [DCP framework](https://dcp.stanford.edu/) and its follow-ups for further classes of 9 | functions, such as DGP and DQP (see https://www.cvxpy.org/tutorial/index.html), symbolic representations of problems can be leveraged 10 | to identify and facilitate building convex expressions (or expressions with similar function properties). 11 | 12 | This package aims to use the expression graph rewriting and metadata propagation supported by Symbolics.jl to support 13 | propagation of several of these properties, currently limited to Euclidean convexity and geodesic convexity on the Symmetric 14 | Positive Definite and Lorentz manifolds. This package provides an easier-to-extend implementation of function properties than the previous 15 | implementations [CVXPY](https://www.cvxpy.org/index.html) and [Convex.jl](https://github.com/jump-dev/Convex.jl), as well as a 16 | more performant implementation of the function property propagation.
17 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: 5 | - main 6 | tags: ['*'] 7 | pull_request: 8 | workflow_dispatch: 9 | concurrency: 10 | # Skip intermediate builds: always. 11 | # Cancel intermediate builds: only if it is a pull request build. 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} 14 | jobs: 15 | test: 16 | name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} 17 | runs-on: ${{ matrix.os }} 18 | timeout-minutes: 60 19 | permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created 20 | actions: write 21 | contents: read 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | version: 26 | - '1' 27 | os: 28 | - ubuntu-latest 29 | arch: 30 | - x64 31 | steps: 32 | - uses: actions/checkout@v6 33 | - uses: julia-actions/setup-julia@v2 34 | with: 35 | version: ${{ matrix.version }} 36 | arch: ${{ matrix.arch }} 37 | - uses: julia-actions/cache@v2 38 | - uses: julia-actions/julia-buildpkg@v1 39 | - uses: julia-actions/julia-runtest@v1 40 | -------------------------------------------------------------------------------- /.github/workflows/FormatCheck.yml: -------------------------------------------------------------------------------- 1 | name: format-check 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'master' 7 | - 'release-' 8 | tags: '*' 9 | pull_request: 10 | 11 | jobs: 12 | build: 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | julia-version: [1] 17 | julia-arch: [x86] 18 | os: [ubuntu-latest] 19 | steps: 20 | - uses: julia-actions/setup-julia@latest 21 | with: 22 | version: ${{ matrix.julia-version }} 23 | 24 | - uses: actions/checkout@v6 25 | - name: Install JuliaFormatter and format 26 | # This 
will use the latest version by default but you can set the version like so: 27 | # 28 | # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))' 29 | run: | 30 | julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' 31 | julia -e 'using JuliaFormatter; format(".", verbose=true)' 32 | - name: Format check 33 | run: | 34 | julia -e ' 35 | out = Cmd(`git diff --name-only`) |> read |> String 36 | if out == "" 37 | exit(0) 38 | else 39 | @error "Some files have not been formatted !!!" 40 | write(stdout, out) 41 | exit(1) 42 | end' 43 | -------------------------------------------------------------------------------- /docs/Project.toml: -------------------------------------------------------------------------------- 1 | [deps] 2 | Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" 3 | DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365" 4 | SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe" 5 | DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" 6 | Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" 7 | Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" 8 | DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" 9 | DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" 10 | IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" 11 | LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" 12 | LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" 13 | Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" 14 | PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" 15 | RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" 16 | StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" 17 | SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" 18 | Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" 19 | 20 | [compat] 21 | DataStructures = "0.18" 22 | Dictionaries = "0.4" 23 | Distributions = "0.25" 24 | DomainSets = "0.7" 25 | DSP = "0.7" 26 | IfElse = "0.1" 27 | Manifolds = "0.9" 28 | LinearAlgebra = "1.10" 29 | LogExpFunctions = 
"0.3" 30 | PDMats = "0.11" 31 | RecursiveArrayTools = "3" 32 | StatsBase = "0.34" 33 | Symbolics = "6" 34 | SymbolicUtils = "3" -------------------------------------------------------------------------------- /Project.toml: -------------------------------------------------------------------------------- 1 | name = "SymbolicAnalysis" 2 | uuid = "4297ee4d-0239-47d8-ba5d-195ecdf594fe" 3 | authors = ["Vaibhav Dixit ", "Shashi Gowda "] 4 | version = "0.3.2" 5 | 6 | [deps] 7 | DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" 8 | DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" 9 | Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" 10 | Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" 11 | DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" 12 | IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" 13 | LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" 14 | LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" 15 | Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" 16 | PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" 17 | RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" 18 | StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" 19 | SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" 20 | Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" 21 | 22 | [compat] 23 | DSP = "0.7" 24 | DataStructures = "0.18, 0.19" 25 | Dictionaries = "0.4" 26 | Distributions = "0.25" 27 | DomainSets = "0.7" 28 | IfElse = "0.1" 29 | LinearAlgebra = "1.10" 30 | LogExpFunctions = "0.3" 31 | Manifolds = "0.9, 0.10" 32 | PDMats = "0.11" 33 | RecursiveArrayTools = "3" 34 | StatsBase = "0.34" 35 | Symbolics = "6" 36 | SymbolicUtils = "3.1.2" 37 | julia = "1.10" 38 | 39 | -------------------------------------------------------------------------------- /.typos.toml: -------------------------------------------------------------------------------- 1 | [default.extend-words] 2 | # Julia-specific functions 3 | indexin = "indexin" 4 | findfirst = "findfirst" 5 | findlast = 
"findlast" 6 | eachindex = "eachindex" 7 | setp = "setp" 8 | getp = "getp" 9 | setu = "setu" 10 | getu = "getu" 11 | 12 | # Mathematical/scientific terms 13 | jacobian = "jacobian" 14 | hessian = "hessian" 15 | eigenvalue = "eigenvalue" 16 | eigenvector = "eigenvector" 17 | discretization = "discretization" 18 | linearization = "linearization" 19 | parameterized = "parameterized" 20 | discretized = "discretized" 21 | vectorized = "vectorized" 22 | 23 | # Common variable patterns in Julia/SciML 24 | ists = "ists" 25 | ispcs = "ispcs" 26 | osys = "osys" 27 | rsys = "rsys" 28 | usys = "usys" 29 | fsys = "fsys" 30 | eqs = "eqs" 31 | rhs = "rhs" 32 | lhs = "lhs" 33 | ode = "ode" 34 | pde = "pde" 35 | sde = "sde" 36 | dde = "dde" 37 | bvp = "bvp" 38 | ivp = "ivp" 39 | 40 | # Common abbreviations 41 | tol = "tol" 42 | rtol = "rtol" 43 | atol = "atol" 44 | idx = "idx" 45 | jdx = "jdx" 46 | prev = "prev" 47 | curr = "curr" 48 | init = "init" 49 | tmp = "tmp" 50 | vec = "vec" 51 | arr = "arr" 52 | dt = "dt" 53 | du = "du" 54 | dx = "dx" 55 | dy = "dy" 56 | dz = "dz" 57 | 58 | # Algorithm/type suffixes 59 | alg = "alg" 60 | prob = "prob" 61 | sol = "sol" 62 | cb = "cb" 63 | opts = "opts" 64 | args = "args" 65 | kwargs = "kwargs" 66 | 67 | # Scientific abbreviations 68 | ND = "ND" 69 | nd = "nd" 70 | MTK = "MTK" 71 | ODE = "ODE" 72 | PDE = "PDE" 73 | SDE = "SDE" 74 | -------------------------------------------------------------------------------- /docs/src/index.md: -------------------------------------------------------------------------------- 1 | # SymbolicAnalysis.jl 2 | 3 | Symbolics-based function property propagation for optimization 4 | 5 | SymbolicAnalysis is a package for implementing the Disciplined Programming approach to optimization. Testing convexity structure in nonlinear programs relies on verifying the convexity of objectives and constraints. 
[Disciplined Convex Programming (DCP)](https://dcp.stanford.edu/) is a framework for automating this verification task for a wide range of convex functions that can be decomposed into basic convex functions (atoms) using convexity-preserving compositions and transformations (rules). 6 | 7 | This package uses the expression graph rewriting and metadata propagation provided by Symbolics.jl to analyze the relevant properties, currently limited to Euclidean convexity and geodesic convexity on the Symmetric Positive Definite and Lorentz manifolds. It provides an easy-to-extend implementation of "atoms", which are functions with known properties. This allows users to add atoms to the library more easily than in the previous implementations [CVXPY](https://www.cvxpy.org/index.html) and [Convex.jl](https://github.com/jump-dev/Convex.jl). 8 | 9 | ## Installation 10 | 11 | To install this package, run the following in the Julia REPL: 12 | 13 | ```julia 14 | using Pkg 15 | Pkg.add("SymbolicAnalysis") 16 | ``` 17 | 18 | ## Usage 19 | 20 | The main interface to this package is the `analyze` function. 21 | 22 | ```@autodocs 23 | Modules=[SymbolicAnalysis] 24 | Pages=["SymbolicAnalysis.jl"] 25 | ``` 26 | -------------------------------------------------------------------------------- /docs/src/.vitepress/config.mts: -------------------------------------------------------------------------------- 1 | import { defineConfig } from 'vitepress' 2 | import { tabsMarkdownPlugin } from 'vitepress-plugin-tabs' 3 | import mathjax3 from "markdown-it-mathjax3"; 4 | import footnote from "markdown-it-footnote"; 5 | 6 | // https://vitepress.dev/reference/site-config 7 | export default defineConfig({ 8 | base: 'REPLACE_ME_DOCUMENTER_VITEPRESS',// TODO: replace this in makedocs!
9 | title: 'REPLACE_ME_DOCUMENTER_VITEPRESS', 10 | description: "A VitePress Site", 11 | lastUpdated: true, 12 | cleanUrls: true, 13 | outDir: 'REPLACE_ME_DOCUMENTER_VITEPRESS', // This is required for MarkdownVitepress to work correctly... 14 | head: [['link', { rel: 'icon', href: 'REPLACE_ME_DOCUMENTER_VITEPRESS_FAVICON' }]], 15 | ignoreDeadLinks: true, 16 | 17 | markdown: { 18 | math: true, 19 | config(md) { 20 | md.use(tabsMarkdownPlugin), 21 | md.use(mathjax3), 22 | md.use(footnote) 23 | }, 24 | theme: { 25 | light: "github-light", 26 | dark: "github-dark"} 27 | }, 28 | themeConfig: { 29 | outline: 'deep', 30 | logo: 'REPLACE_ME_DOCUMENTER_VITEPRESS', 31 | search: { 32 | provider: 'local', 33 | options: { 34 | detailedView: true 35 | } 36 | }, 37 | nav: 'REPLACE_ME_DOCUMENTER_VITEPRESS', 38 | sidebar: 'REPLACE_ME_DOCUMENTER_VITEPRESS', 39 | editLink: 'REPLACE_ME_DOCUMENTER_VITEPRESS', 40 | socialLinks: [ 41 | { icon: 'github', link: 'REPLACE_ME_DOCUMENTER_VITEPRESS' } 42 | ], 43 | footer: { 44 | message: 'Made with DocumenterVitepress.jl
', 45 | copyright: `© Copyright ${new Date().getUTCFullYear()}.` 46 | } 47 | } 48 | }) 49 | -------------------------------------------------------------------------------- /.github/workflows/Documenter.yml: -------------------------------------------------------------------------------- 1 | # Sample workflow for building and deploying a VitePress site to GitHub Pages 2 | # 3 | name: Documenter 4 | 5 | on: 6 | # Runs on pushes targeting the `master` branch. Change this to `main` if you're 7 | # using the `main` branch as the default branch. 8 | push: 9 | branches: 10 | - main 11 | tags: ['*'] 12 | pull_request: 13 | 14 | # Allows you to run this workflow manually from the Actions tab 15 | workflow_dispatch: 16 | 17 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 18 | permissions: 19 | contents: write 20 | pages: write 21 | id-token: write 22 | statuses: write 23 | 24 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 25 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 
26 | concurrency: 27 | group: pages 28 | cancel-in-progress: false 29 | 30 | jobs: 31 | # Build job 32 | build: 33 | runs-on: ubuntu-latest 34 | steps: 35 | - name: Checkout 36 | uses: actions/checkout@v6 37 | - name: Setup Julia 38 | uses: julia-actions/setup-julia@v2 39 | - name: Pull Julia cache 40 | uses: julia-actions/cache@v2 41 | - name: Install documentation dependencies 42 | run: julia --project=docs -e 'using Pkg; Pkg.develop(PackageSpec(path = pwd())); Pkg.instantiate(); Pkg.precompile(); Pkg.status()' 43 | #- name: Creating new mds from src 44 | - name: Build and deploy docs 45 | uses: julia-actions/julia-docdeploy@v1 46 | env: 47 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token 48 | DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key 49 | GKSwstype: "100" # for Plots.jl plots (if you have them) 50 | JULIA_DEBUG: "Documenter" 51 | DATADEPS_ALWAYS_ACCEPT: true -------------------------------------------------------------------------------- /src/SymbolicAnalysis.jl: -------------------------------------------------------------------------------- 1 | module SymbolicAnalysis 2 | 3 | using DomainSets 4 | using LinearAlgebra 5 | using LogExpFunctions 6 | using StatsBase 7 | using Distributions 8 | using DSP, DataStructures 9 | 10 | using Symbolics 11 | import Symbolics: Symbolic, issym, Term 12 | using SymbolicUtils: iscall 13 | using SymbolicUtils.Rewriters 14 | SymbolicUtils.inspect_metadata[] = true 15 | 16 | struct VarDomain end 17 | 18 | include("rules.jl") 19 | include("atoms.jl") 20 | include("gdcp/gdcp_rules.jl") 21 | include("gdcp/spd.jl") 22 | include("gdcp/lorentz.jl") 23 | include("canon.jl") 24 | 25 | struct AnalysisResult 26 | curvature::SymbolicAnalysis.Curvature 27 | sign::SymbolicAnalysis.Sign 28 | gcurvature::Union{SymbolicAnalysis.GCurvature, Nothing} 29 | end 30 | 31 | """ 32 | analyze(ex) 33 | analyze(ex, M) 34 | 35 | Analyze the expression `ex` and return the
curvature and sign of the expression. If a manifold `M` from [Manifolds.jl](https://juliamanifolds.github.io/Manifolds.jl/stable/) is provided, also return the geodesic curvature of the expression. 36 | Currently supports the `SymmetricPositiveDefinite` and `Lorentz` manifolds. 37 | 38 | The returned `AnalysisResult` contains the following fields: 39 | 40 | - `curvature::SymbolicAnalysis.Curvature`: The curvature of the expression. 41 | - `sign::SymbolicAnalysis.Sign`: The sign of the expression. 42 | - `gcurvature::Union{SymbolicAnalysis.GCurvature,Nothing}`: The geodesic curvature of the expression if `M` is provided. Otherwise, `nothing`. 43 | """ 44 | function analyze(ex, M::Union{AbstractManifold, Nothing} = nothing) 45 | # Euclidean pass: unwrap, canonicalize, then attach sign and curvature metadata (in that order). 46 | expr = propagate_curvature(propagate_sign(canonize(unwrap(ex)))) 47 | # Without a manifold there is no geodesic curvature to report. 48 | isnothing(M) && return AnalysisResult(getcurvature(expr), getsign(expr), nothing) 49 | @assert M isa SymmetricPositiveDefinite || M isa Lorentz "Only SymmetricPositiveDefinite and Lorentz manifolds are currently supported" 50 | # Geodesic pass on top of the Euclidean metadata. 51 | expr = propagate_gcurvature(expr, M) 52 | return AnalysisResult(getcurvature(expr), getsign(expr), getgcurvature(expr)) 53 | end 54 | 55 | export analyze 56 | 57 | end 58 | -------------------------------------------------------------------------------- /test/test.jl: -------------------------------------------------------------------------------- 1 | using SymbolicAnalysis 2 | using Symbolics, SymbolicAnalysis.LogExpFunctions 3 | using Symbolics: unwrap 4 | using LinearAlgebra, Test 5 | 6 | @variables x y 7 | y = setmetadata( 8 | y, 9 | SymbolicAnalysis.VarDomain, 10 | Symbolics.DomainSets.HalfLine{Number, :open}() 11 | ) 12 | ex1 = exp(y) - log(y) |> unwrap 13 | ex1 = propagate_curvature(propagate_sign(ex1)) 14 | 15 | @test getcurvature(ex1) == SymbolicAnalysis.Convex 16 | @test getsign(ex1) == SymbolicAnalysis.AnySign 17 | 18 | ex2 = -sqrt(x^2) |> unwrap 19 | ex2 =
propagate_curvature(propagate_sign(ex2)) 20 | 21 | @test getcurvature(ex2) == SymbolicAnalysis.UnknownCurvature 22 | @test getsign(ex2) == SymbolicAnalysis.Negative 23 | 24 | ex = -1 * LogExpFunctions.xlogx(x) |> unwrap 25 | ex = propagate_curvature(propagate_sign(ex)) 26 | @test getcurvature(ex) == SymbolicAnalysis.Concave 27 | @test getsign(ex) == SymbolicAnalysis.AnySign 28 | 29 | ex = 2 * abs(x) - 1 |> unwrap 30 | ex = propagate_curvature(propagate_sign(ex)) 31 | @test getcurvature(ex) == SymbolicAnalysis.Convex 32 | @test getsign(ex) == SymbolicAnalysis.AnySign 33 | 34 | # x = setmetadata(x, SymbolicAnalysis.Sign, SymbolicAnalysis.Positive) 35 | ex = abs(x)^2 |> unwrap 36 | ex = propagate_curvature(propagate_sign(ex)) 37 | @test getcurvature(ex) == SymbolicAnalysis.Convex 38 | @test getsign(ex) == SymbolicAnalysis.Positive 39 | 40 | ex = abs(x)^2 + abs(x)^3 |> unwrap 41 | ex = propagate_curvature(propagate_sign(ex)) 42 | @test getcurvature(ex) == SymbolicAnalysis.Convex 43 | @test getsign(ex) == SymbolicAnalysis.Positive 44 | 45 | @variables x[1:3] y 46 | ex = x .- y |> unwrap 47 | ex = propagate_curvature(propagate_sign(ex)) 48 | @test_broken getcurvature(ex) == SymbolicAnalysis.Affine 49 | @test getsign(ex) == SymbolicAnalysis.AnySign 50 | 51 | ex = exp.(x) |> unwrap 52 | ex = propagate_curvature(propagate_sign(ex)) 53 | @test getcurvature(ex) == SymbolicAnalysis.Convex 54 | @test getsign(ex) == SymbolicAnalysis.Positive 55 | 56 | ##vector * scalar gets simplified 57 | 58 | @variables x y z 59 | obj = x^2 + y^2 + z^2 |> unwrap 60 | 61 | ex = propagate_curvature(propagate_sign(obj)) 62 | @test_broken getcurvature(ex) == SymbolicAnalysis.Convex 63 | @test getsign(ex) == SymbolicAnalysis.Positive 64 | 65 | cons = [x + y + z ~ 10 66 | log1p(x)^2 - log1p(z) ≲ 0] 67 | 68 | ex = propagate_curvature(propagate_sign(cons[1].lhs |> unwrap)) 69 | @test getcurvature(ex) == SymbolicAnalysis.Affine 70 | 71 | ex = propagate_curvature(propagate_sign(cons[2].lhs)) 72 | @test
getcurvature(ex) == SymbolicAnalysis.Convex 73 | 74 | @variables x y z 75 | 76 | ex = SymbolicAnalysis.quad_over_lin(x - y, 1 - max(x, y)) |> unwrap 77 | ex = propagate_curvature(propagate_sign(ex)) 78 | @test getcurvature(ex) == SymbolicAnalysis.Convex 79 | -------------------------------------------------------------------------------- /docs/src/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Here are some examples demonstrating the use of the `analyze` function from the `SymbolicAnalysis` package. 4 | 5 | ## Basic Expression Analysis 6 | 7 | ```@example euclidean1 8 | using SymbolicAnalysis, Symbolics, Symbolics.DomainSets 9 | 10 | @variables x 11 | ex1 = exp(x) - log(x) 12 | result = analyze(ex1) 13 | @show result.curvature 14 | ``` 15 | 16 | This example analyzes a simple expression `exp(x) - log(x)`, determining that it's convex and can have any sign. 17 | 18 | ## Analysis on Manifolds 19 | 20 | We can perform DGCP analysis on the Symmetric Positive Definite (SPD) manifold by passing a manifold from [Manifolds.jl](https://juliamanifolds.github.io/Manifolds.jl/stable/) to the `analyze` function. We consider the Karcher mean problem which involves finding the geometric mean of SPD matrices: 21 | 22 | ```@example manifold1 23 | using SymbolicAnalysis, Symbolics, Manifolds, LinearAlgebra 24 | 25 | @variables X[1:5, 1:5] 26 | 27 | M = SymmetricPositiveDefinite(5) 28 | 29 | As = [rand(5, 5) for i in 1:5] 30 | As = [As[i] * As[i]' for i in 1:5] # Make them SPD 31 | 32 | ex2 = sum(Manifolds.distance(M, As[i], X)^2 for i in 1:5) 33 | result = analyze(ex2, M) 34 | @show result.curvature 35 | @show result.gcurvature 36 | ``` 37 | 38 | This analysis shows that the Karcher mean objective function is geodesically convex on the SPD manifold.
39 | 40 | ### Domain aware analysis 41 | 42 | We can also assert the domain of the variable by assigning `VarDomain` metadata that takes a `Domain` from the [DomainSets.jl](https://juliaapproximation.github.io/DomainSets.jl/dev/) package. 43 | 44 | ```@example euclidean1 45 | @variables x y 46 | 47 | x = setmetadata( 48 | x, 49 | SymbolicAnalysis.VarDomain, 50 | OpenInterval(0, 1) 51 | ) 52 | 53 | y = setmetadata( 54 | y, 55 | SymbolicAnalysis.VarDomain, 56 | OpenInterval(0, 1) 57 | ) 58 | 59 | ex = SymbolicAnalysis.quad_over_lin(x - y, 1 - max(x, y)) 60 | result = analyze(ex) 61 | @show result.curvature 62 | ``` 63 | 64 | This example analyzes a quadratic expression over a linear expression, showing that it's convex. 65 | 66 | ## Analysis on the Lorentz Manifold 67 | 68 | We can also perform DGCP analysis on the Lorentz manifold, which is a model of hyperbolic space: 69 | 70 | ```@example lorentz1 71 | using SymbolicAnalysis, Symbolics, Manifolds, LinearAlgebra 72 | 73 | # Create a Lorentz manifold of dimension 2 (3D ambient space) 74 | M = Lorentz(2) 75 | 76 | # Define symbolic variables and fixed points 77 | @variables p[1:3] 78 | q = [0.0, 0.0, 1.0] # A point on the Lorentz model 79 | 80 | # Create a composite function from Lorentz atoms 81 | ex = 2.0 * Manifolds.distance(M, q, p) + 82 | SymbolicAnalysis.lorentz_log_barrier(p) 83 | 84 | # Analyze the expression 85 | result = analyze(ex, M) 86 | @show result.gcurvature 87 | ``` 88 | 89 | This example shows that the sum of the Lorentz distance function and the log-barrier function is geodesically convex on the Lorentz manifold. 
90 | -------------------------------------------------------------------------------- /test/lorentz.jl: -------------------------------------------------------------------------------- 1 | using Manifolds, Symbolics, SymbolicAnalysis, LinearAlgebra 2 | using Test 3 | using Symbolics: unwrap 4 | using SymbolicAnalysis: propagate_sign, propagate_curvature, propagate_gcurvature 5 | 6 | @testset "Lorentz Manifold" begin 7 | # Create a Lorentz manifold of dimension 2 (3D ambient space) 8 | M = Lorentz(2) 9 | 10 | # Define symbolic variables 11 | @variables p[1:3] 12 | 13 | # Test lorentz_distance 14 | q = [0.0, 0.0, 1.0] # A point on the Lorentz model 15 | ex = Manifolds.distance(M, q, p) |> unwrap 16 | ex = propagate_sign(ex) 17 | ex = propagate_gcurvature(ex, M) 18 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 19 | 20 | # Test analyze function 21 | analyze_res = analyze(ex, M) 22 | @test analyze_res.gcurvature == SymbolicAnalysis.GConvex 23 | 24 | # Test lorentz_log_barrier 25 | ex = SymbolicAnalysis.lorentz_log_barrier(p) |> unwrap 26 | ex = propagate_sign(ex) 27 | ex = propagate_gcurvature(ex, M) 28 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 29 | 30 | # Test lorentz_homogeneous_quadratic 31 | A = [2.0 0.0 0.0; 0.0 2.0 0.0; 0.0 0.0 1.0] # Positive definite matrix 32 | ex = SymbolicAnalysis.lorentz_homogeneous_quadratic(A, p) |> unwrap 33 | ex = propagate_sign(ex) 34 | ex = propagate_gcurvature(ex, M) 35 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 36 | 37 | # Test lorentz_homogeneous_diagonal 38 | a = [2.0, 2.0, 1.0] # min(a[1:2]) + a[3] = 2 + 1 = 3 ≥ 0 39 | ex = SymbolicAnalysis.lorentz_homogeneous_diagonal(a, p) |> unwrap 40 | ex = propagate_sign(ex) 41 | ex = propagate_gcurvature(ex, M) 42 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 43 | 44 | # Test lorentz_least_squares 45 | # Create a valid X and y that satisfy the geodesic convexity conditions: 46 | # 
∑^d_i=1(X'y)^2_i ≤ (X'y)^2_{d+1} and (X'y)_{d+1} ≤ 0 47 | X = [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0] 48 | y = [0.0, 0.0, -1.0] # X'y = [0.0, 0.0, -1.0], which satisfies both conditions 49 | ex = SymbolicAnalysis.lorentz_least_squares(X, y, p) |> unwrap 50 | ex = propagate_sign(ex) 51 | ex = propagate_gcurvature(ex, M) 52 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 53 | 54 | @testset "Least Squares Problem" begin 55 | # Define variables for symbolic testing 56 | @variables p[1:3] 57 | M = Manifolds.Lorentz(2) # 2D Lorentz model (3D ambient) 58 | 59 | # Create a valid test case with data that satisfies geodesic convexity conditions 60 | # We need X and y such that: 61 | # 1. ∑_{i=1}^d (X'y)²_i ≤ (X'y)²_{d+1} 62 | # 2. (X'y)_{d+1} ≤ 0 63 | X = [1.0 0.0 2.0; 0.0 1.0 3.0; 2.0 2.0 10.0] 64 | y = [1.0, 2.0, -5.0] 65 | 66 | # Verify conditions explicitly for the test 67 | Xty = X' * y 68 | @test sum(Xty[1:2] .^ 2) <= Xty[3]^2 # Condition 1 69 | @test Xty[3] <= 0 # Condition 2 70 | 71 | # Compose the least squares problem from atoms 72 | A = X' * X # Positive semidefinite, automatically ∂L-copositive 73 | b = -2 * X' * y # Must be in Lorentz cone 74 | c = y' * y 75 | 76 | # Create expression using lorentz_nonhomogeneous_quadratic 77 | expr = SymbolicAnalysis.lorentz_nonhomogeneous_quadratic(A, b, c, p) 78 | 79 | # Verify geodesic convexity through DGCP framework 80 | expr = propagate_sign(expr) 81 | expr = propagate_gcurvature(expr, M) 82 | @test SymbolicAnalysis.getgcurvature(expr) == SymbolicAnalysis.GConvex 83 | 84 | # Verify that the composition matches the direct expansion 85 | # direct_expr = c - 2 * p' * X' * y + p' * X' * X * p 86 | # @test isequal(simplify(expr), simplify(direct_expr)) 87 | end 88 | # Test composition of functions 89 | ex = 2.0 * Manifolds.distance(M, q, p) + SymbolicAnalysis.lorentz_log_barrier(p) |> 90 | unwrap 91 | ex = propagate_sign(ex) 92 | ex = propagate_gcurvature(ex, M) 93 | @test 
SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 94 | 95 | # Test lorentz_transform (should preserve geodesic convexity) 96 | # Create a Lorentz boost in the x-direction 97 | cosh_phi = 1.2 98 | sinh_phi = sqrt(cosh_phi^2 - 1) 99 | O = [1.0 0.0 0.0; 0.0 cosh_phi -sinh_phi; 0.0 -sinh_phi cosh_phi] 100 | 101 | # Create a compound expression with the transform 102 | q_transformed = O * q 103 | ex = Manifolds.distance(M, q_transformed, p) |> unwrap 104 | ex = propagate_sign(ex) 105 | ex = propagate_gcurvature(ex, M) 106 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 107 | end 108 | -------------------------------------------------------------------------------- /docs/src/.vitepress/theme/style.css: -------------------------------------------------------------------------------- 1 | @import url(https://fonts.googleapis.com/css?family=Space+Mono:regular,italic,700,700italic); 2 | @import url(https://fonts.googleapis.com/css?family=Space+Grotesk:regular,italic,700,700italic); 3 | 4 | /* Customize default theme styling by overriding CSS variables: 5 | https://github.com/vuejs/vitepress/blob/main/src/client/theme-default/styles/vars.css 6 | */ 7 | 8 | /* Layouts */ 9 | 10 | /* 11 | :root { 12 | --vp-layout-max-width: 1440px; 13 | } */ 14 | 15 | .VPHero .clip { 16 | white-space: pre; 17 | max-width: 500px; 18 | } 19 | 20 | /* Fonts */ 21 | 22 | :root { 23 | /* Typography */ 24 | --vp-font-family-base: "Barlow", "Inter var experimental", "Inter var", 25 | -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, 26 | Cantarell, "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif; 27 | 28 | /* Code Snippet font */ 29 | --vp-font-family-mono: "Space Mono", Menlo, Monaco, Consolas, "Courier New", 30 | monospace; 31 | } 32 | 33 | .mono { 34 | /* 35 | Disable contextual alternates (kind of like ligatures but different) in monospace, 36 | which turns `/>` to an up arrow and `|>` (the Julia pipe symbol) to an up arrow as well. 
37 | This is pretty bad for Julia folks reading even though copy+paste retains the same text. 38 | */ 39 | font-feature-settings: 'calt' 0; 40 | } 41 | 42 | /* Colors */ 43 | 44 | :root { 45 | --julia-blue: #4063D8; 46 | --julia-purple: #9558B2; 47 | --julia-red: #CB3C33; 48 | --julia-green: #389826; 49 | 50 | --vp-c-brand: #389826; 51 | --vp-c-brand-light: #3dd027; 52 | --vp-c-brand-lighter: #9499ff; 53 | --vp-c-brand-lightest: #bcc0ff; 54 | --vp-c-brand-dark: #535bf2; 55 | --vp-c-brand-darker: #454ce1; 56 | --vp-c-brand-dimm: #212425; 57 | } 58 | 59 | /* Component: Button */ 60 | 61 | :root { 62 | --vp-button-brand-border: var(--vp-c-brand-light); 63 | --vp-button-brand-text: var(--vp-c-white); 64 | --vp-button-brand-bg: var(--vp-c-brand); 65 | --vp-button-brand-hover-border: var(--vp-c-brand-light); 66 | --vp-button-brand-hover-text: var(--vp-c-white); 67 | --vp-button-brand-hover-bg: var(--vp-c-brand-light); 68 | --vp-button-brand-active-border: var(--vp-c-brand-light); 69 | --vp-button-brand-active-text: var(--vp-c-white); 70 | --vp-button-brand-active-bg: var(--vp-button-brand-bg); 71 | } 72 | 73 | /* Component: Home */ 74 | 75 | :root { 76 | --vp-home-hero-name-color: transparent; 77 | --vp-home-hero-name-background: -webkit-linear-gradient( 78 | 120deg, 79 | #9558B2 30%, 80 | #CB3C33 81 | ); 82 | 83 | --vp-home-hero-image-background-image: linear-gradient( 84 | -45deg, 85 | #9558B2 30%, 86 | #389826 30%, 87 | #CB3C33 88 | ); 89 | --vp-home-hero-image-filter: blur(40px); 90 | } 91 | 92 | @media (min-width: 640px) { 93 | :root { 94 | --vp-home-hero-image-filter: blur(56px); 95 | } 96 | } 97 | 98 | @media (min-width: 960px) { 99 | :root { 100 | --vp-home-hero-image-filter: blur(72px); 101 | } 102 | } 103 | 104 | /* Component: Custom Block */ 105 | 106 | :root.dark { 107 | --vp-custom-block-tip-border: var(--vp-c-brand); 108 | --vp-custom-block-tip-text: var(--vp-c-brand-lightest); 109 | --vp-custom-block-tip-bg: var(--vp-c-brand-dimm); 110 | 111 | /* // Tweak 
the color palette for blacks and dark grays */ 112 | --vp-c-black: hsl(220 20% 9%); 113 | --vp-c-black-pure: hsl(220, 24%, 4%); 114 | --vp-c-black-soft: hsl(220 16% 13%); 115 | --vp-c-black-mute: hsl(220 14% 17%); 116 | --vp-c-gray: hsl(220 8% 56%); 117 | --vp-c-gray-dark-1: hsl(220 10% 39%); 118 | --vp-c-gray-dark-2: hsl(220 12% 28%); 119 | --vp-c-gray-dark-3: hsl(220 12% 23%); 120 | --vp-c-gray-dark-4: hsl(220 14% 17%); 121 | --vp-c-gray-dark-5: hsl(220 16% 13%); 122 | 123 | /* // Backgrounds */ 124 | /* --vp-c-bg: hsl(240, 2%, 11%); */ 125 | --vp-custom-block-info-bg: hsl(220 14% 17%); 126 | /* --vp-c-gutter: hsl(220 20% 9%); 127 | 128 | --vp-c-bg-alt: hsl(220 20% 9%); 129 | --vp-c-bg-soft: hsl(220 14% 17%); 130 | --vp-c-bg-mute: hsl(220 12% 23%); 131 | */ 132 | } 133 | 134 | /* Component: Algolia */ 135 | 136 | .DocSearch { 137 | --docsearch-primary-color: var(--vp-c-brand) !important; 138 | } 139 | 140 | /* Component: MathJax */ 141 | 142 | mjx-container > svg { 143 | display: block; 144 | margin: auto; 145 | } 146 | 147 | mjx-container { 148 | padding: 0.5rem 0; 149 | } 150 | 151 | mjx-container { 152 | display: inline-block; 153 | margin: auto 2px -2px; 154 | } 155 | 156 | mjx-container > svg { 157 | margin: auto; 158 | display: inline-block; 159 | } 160 | 161 | /** 162 | * Colors links 163 | * -------------------------------------------------------------------------- */ 164 | 165 | :root { 166 | --vp-c-brand-1: #CB3C33; 167 | --vp-c-brand-2: #CB3C33; 168 | --vp-c-brand-3: #CB3C33; 169 | --vp-c-sponsor: #ca2971; 170 | --vitest-c-sponsor-hover: #c13071; 171 | } 172 | 173 | .dark { 174 | --vp-c-brand-1: #91dd33; 175 | --vp-c-brand-2: #91dd33; 176 | --vp-c-brand-3: #91dd33; 177 | --vp-c-sponsor: #91dd33; 178 | --vitest-c-sponsor-hover: #e51370; 179 | } -------------------------------------------------------------------------------- /src/gdcp/lorentz.jl: -------------------------------------------------------------------------------- 1 | # DGCP Atoms for Lorentz 
Model (Hyperbolic Space) 2 | # 3 | # This file implements geodesically convex atoms for the Lorentz model, 4 | # a Cartan-Hadamard manifold of constant negative curvature. 5 | # Based on the results from Ferreira, Németh, and Zhu (2022, 2023). 6 | 7 | using Manifolds 8 | using LinearAlgebra 9 | using Symbolics: Symbolic, @register_symbolic, unwrap, variables 10 | 11 | @register_symbolic Manifolds.distance( 12 | M::Manifolds.Lorentz, 13 | p::AbstractVector, 14 | q::Union{Symbolics.Arr, AbstractVector} 15 | ) false 16 | add_gdcprule(Manifolds.distance, Manifolds.Lorentz, Positive, GConvex, GAnyMono) 17 | 18 | """ 19 | lorentz_log_barrier(p) 20 | 21 | Computes the log-barrier function for the Lorentz model: `-log(-1 - ⟨a, p⟩_L)` with the fixed vector `a = (0, ..., 0, 1)`. Since `⟨a, p⟩_L = -p_{d+1}`, this equals `-log(p_{d+1} - 1)`. 22 | 23 | # Arguments 24 | 25 | - `p`: A point on the Lorentz manifold; its last entry is the time-like coordinate (the one with the minus sign in the Lorentzian inner product). 26 | 27 | """ 28 | function lorentz_log_barrier(p::AbstractVector) 29 | # Lorentzian inner product: ⟨a, p⟩_L = a_1*p_1 + ... + a_d*p_d - a_{d+1}*p_{d+1}; with a = (0, ..., 0, 1) only the last term survives. 30 | inner_prod = -p[end] # ⟨a, p⟩_L computed directly, so no global `a` is needed 31 | return -log(-1 - inner_prod) 32 | end 33 | 34 | @register_symbolic lorentz_log_barrier(p::Union{Symbolics.Arr, AbstractVector}) 35 | add_gdcprule(lorentz_log_barrier, Manifolds.Lorentz, AnySign, GConvex, GIncreasing) # -log(p_{d+1} - 1) is negative once p_{d+1} > 2, so the sign is not fixed 36 | 37 | """ 38 | lorentz_homogeneous_quadratic(A::AbstractMatrix, p::AbstractVector) 39 | 40 | Computes the homogeneous quadratic function f(p) = p'Ap on the Lorentz model. 41 | For geodesic convexity, A must satisfy one of the conditions in Theorem 21. 42 | 43 | # Arguments 44 | 45 | - `A::AbstractMatrix`: A symmetric matrix in R^((d+1)×(d+1)). 46 | - `p::AbstractVector`: A point on the Lorentz manifold.
47 | """ 48 | function lorentz_homogeneous_quadratic(A::AbstractMatrix, p::AbstractVector) 49 | d = size(A, 1) - 1 50 | 51 | # Extract the components from matrix A 52 | A_bar = A[1:d, 1:d] 53 | a_vec = A[1:d, d + 1] 54 | sigma = A[d + 1, d + 1] 55 | 56 | # Compute the minimum eigenvalue of A_bar 57 | lambda_min = minimum(eigvals(A_bar)) 58 | 59 | # Check conditions from Theorem 21 60 | condition1 = isapprox(norm(a_vec), 0, atol = 1e-10) && sigma >= -lambda_min 61 | condition2 = sigma + lambda_min > 2 * sqrt(dot(a_vec, a_vec)) 62 | 63 | if !(condition1 || condition2) 64 | throw(ArgumentError("Matrix A does not satisfy geodesic convexity conditions")) 65 | end 66 | 67 | return p' * A * p 68 | end 69 | 70 | @register_symbolic lorentz_homogeneous_quadratic( 71 | A::AbstractMatrix, 72 | p::Union{Symbolics.Arr, AbstractVector} 73 | ) 74 | add_gdcprule(lorentz_homogeneous_quadratic, Manifolds.Lorentz, Positive, GConvex, GAnyMono) 75 | 76 | """ 77 | lorentz_homogeneous_diagonal(a::AbstractVector, p::AbstractVector) 78 | 79 | Computes the homogeneous diagonal quadratic function `∑(a_i * p_i^2)`. 80 | For geodesic convexity, min(a_1,...,a_d) + a_{d+1} ≥ 0. 81 | 82 | # Arguments 83 | 84 | - `a::AbstractVector`: A (d+1)-vector where min(a_1,...,a_d) + a_{d+1} ≥ 0. 85 | - `p::AbstractVector`: A point on the Lorentz manifold. 
86 | """ 87 | function lorentz_homogeneous_diagonal(a::AbstractVector, p::AbstractVector) 88 | if length(a) != length(p) 89 | throw(DimensionMismatch("Vectors must have same length")) 90 | end 91 | 92 | if minimum(a[1:(end - 1)]) + a[end] < 0 93 | throw( 94 | ArgumentError( 95 | "For geodesic convexity, min(a[1:end-1]) + a[end] ≥ 0 is required", 96 | ), 97 | ) 98 | end 99 | 100 | return sum(a .* p .^ 2) 101 | end 102 | 103 | @register_symbolic lorentz_homogeneous_diagonal( 104 | a::AbstractVector, 105 | p::Union{Symbolics.Arr, AbstractVector} 106 | ) 107 | add_gdcprule(lorentz_homogeneous_diagonal, Manifolds.Lorentz, Positive, GConvex, GAnyMono) 108 | 109 | """ 110 | lorentz_nonhomogeneous_quadratic(A::AbstractMatrix, b::AbstractVector, c::Real, p::AbstractVector) 111 | 112 | Computes the non-homogeneous quadratic function f(p) = p'Ap + b'p + c on the Lorentz model. 113 | For geodesic convexity, p'Ap must be geodesically convex and b must be in the Lorentz cone L. 114 | 115 | # Arguments 116 | 117 | - `A::AbstractMatrix`: A symmetric matrix in R^((d+1)×(d+1)). 118 | - `b::AbstractVector`: A vector in R^(d+1) which must be in the Lorentz cone. 119 | - `c::Real`: A constant term. 120 | - `p::AbstractVector`: A point on the Lorentz manifold. 
121 | """ 122 | function lorentz_nonhomogeneous_quadratic( 123 | A::AbstractMatrix, 124 | b::AbstractVector, 125 | c::Real, 126 | p::AbstractVector 127 | ) 128 | # Check that b is in the Lorentz cone: ‖b[1:end-1]‖ ≤ b[end] 129 | b_head = b[1:(end - 1)] 130 | b_tail = b[end] 131 | 132 | if !(norm(b_head)^2 <= b_tail^2 && b_tail >= 0) 133 | throw(ArgumentError("Vector b must be in the Lorentz cone for geodesic convexity")) 134 | end 135 | 136 | # This call will check if A satisfies the geodesic convexity conditions 137 | homogeneous_part = lorentz_homogeneous_quadratic(A, p) 138 | # b'p is formed as a 1×1 matrix-vector product so the same code works for numeric and symbolic `p` 139 | affine_part = (Matrix(b') * p) 140 | 141 | return homogeneous_part + affine_part[1] + c # f(p) = p'Ap + b'p + c 142 | end 143 | 144 | @register_symbolic lorentz_nonhomogeneous_quadratic( 145 | A::AbstractMatrix, 146 | b::AbstractVector, 147 | c::Real, 148 | p::Vector{Num} 149 | ) 150 | add_gdcprule(lorentz_nonhomogeneous_quadratic, Manifolds.Lorentz, AnySign, GConvex, AnyMono) 151 | 152 | """ 153 | lorentz_least_squares(X::AbstractMatrix, y::AbstractVector, p::AbstractVector) 154 | 155 | Computes the least squares function `‖y - Xp‖²_2 = y'y - 2y'Xp + p'X'Xp` for the Lorentz model. 156 | 157 | # Arguments 158 | 159 | - `X::AbstractMatrix`: A matrix in R^(n×(d+1)). 160 | - `y::AbstractVector`: A vector in R^n. 161 | - `p::AbstractVector`: A point on the Lorentz manifold.
162 | """ 163 | function lorentz_least_squares(X::AbstractMatrix, y::AbstractVector, p::AbstractVector) 164 | A = X' * X # Homogeneous quadratic coefficient 165 | b = -2 * X' * y # Linear coefficient 166 | c = y' * y # Constant term 167 | 168 | # This call will check the geodesic convexity conditions for both 169 | # the homogeneous part (via lorentz_homogeneous_quadratic) and the linear term 170 | return lorentz_nonhomogeneous_quadratic(A, b, c, p) 171 | end 172 | 173 | @register_symbolic lorentz_least_squares(X::Matrix{Num}, y::Vector{Num}, p::Vector{Num}) 174 | add_gdcprule(lorentz_least_squares, Manifolds.Lorentz, Positive, GConvex, AnyMono) 175 | 176 | """ 177 | lorentz_transform(O::AbstractMatrix, p::AbstractVector) 178 | 179 | Applies a Lorentz transform to a point on the Lorentz manifold. 180 | The matrix O must be an element of the orthochronous Lorentz group O⁺(1,d). 181 | 182 | # Arguments 183 | 184 | - `O::AbstractMatrix`: An element of the orthochronous Lorentz group. 185 | - `p::AbstractVector`: A point on the Lorentz manifold. 
186 | """ 187 | function lorentz_transform(O::AbstractMatrix, p::AbstractVector) 188 | d = length(p) - 1 189 | J = Diagonal([ones(d)..., -1]) 190 | 191 | # Check if O is in the Lorentz group 192 | if !isapprox(O' * J * O, J, rtol = 1e-10) 193 | throw(ArgumentError("Matrix is not in the Lorentz group")) 194 | end 195 | 196 | # Check if O preserves the positive time direction (orthochronous) 197 | if (O * [zeros(d)..., 1])[end] <= 0 198 | throw(ArgumentError("Matrix does not preserve the positive time direction")) 199 | end 200 | 201 | return O * p 202 | end 203 | 204 | @register_symbolic lorentz_transform( 205 | O::AbstractMatrix, 206 | p::Union{Symbolics.Arr, AbstractVector} 207 | ) 208 | # Not adding a rule since this preserves geodesic convexity but doesn't have a specific curvature 209 | 210 | # Export functions 211 | export lorentz_log_barrier, lorentz_homogeneous_quadratic 212 | export lorentz_homogeneous_diagonal, lorentz_least_squares, lorentz_transform 213 | -------------------------------------------------------------------------------- /src/gdcp/gdcp_rules.jl: -------------------------------------------------------------------------------- 1 | using Manifolds 2 | using Symbolics: @register_symbolic, unwrap 3 | using LinearAlgebra 4 | 5 | # @enum GSign GPositive GNegative GAnySign 6 | @enum GCurvature GConvex GConcave GLinear GUnknownCurvature 7 | @enum GMonotonicity GIncreasing GDecreasing GAnyMono 8 | 9 | const gdcprules_dict = Dict() 10 | 11 | function add_gdcprule(f, manifold, sign, curvature, monotonicity) 12 | if !(monotonicity isa Tuple) 13 | monotonicity = (monotonicity,) 14 | end 15 | gdcprules_dict[f] = makegrule(manifold, sign, curvature, monotonicity) 16 | end 17 | function makegrule(manifold, sign, curvature, monotonicity) 18 | (manifold = manifold, sign = sign, gcurvature = curvature, gmonotonicity = monotonicity) 19 | end 20 | 21 | hasgdcprule(f::Function) = haskey(gdcprules_dict, f) 22 | hasgdcprule(f) = false 23 | gdcprule(f, args...) 
= gdcprules_dict[f], args 24 | 25 | setgcurvature(ex::Union{Symbolic, Num}, curv) = setmetadata(ex, GCurvature, curv) 26 | setgcurvature(ex, curv) = ex 27 | getgcurvature(ex::Union{Symbolic, Num}) = getmetadata(ex, GCurvature) 28 | getgcurvature(ex) = GLinear 29 | hasgcurvature(ex::Union{Symbolic, Num}) = hasmetadata(ex, GCurvature) 30 | hasgcurvature(ex) = ex isa Real 31 | 32 | function mul_gcurvature(args) 33 | non_constants = findall(x -> issym(x) || iscall(x), args) 34 | constants = findall(x -> !issym(x) && !iscall(x), args) 35 | if length(non_constants) > 1 36 | # DGCP only handles products with at most one non-constant factor; an explicit check is 37 | # used instead of `@assert`, which may be disabled at some optimization levels. 38 | @warn "DGCP does not support multiple non-constant arguments in multiplication" 39 | return GUnknownCurvature 40 | end 41 | if !isempty(non_constants) 42 | expr = args[first(non_constants)] 43 | curv = find_gcurvature(expr) 44 | return if !isempty(constants) && prod(args[constants]) < 0 # guard: prod of an empty Any[] throws 45 | # flip: a negative constant factor swaps convexity and concavity 46 | curv == GConvex ? GConcave : curv == GConcave ? GConvex : curv 47 | else 48 | curv 49 | end 50 | end 51 | return GLinear 52 | end 53 | 54 | function add_gcurvature(args) 55 | curvs = find_gcurvature.(args) 56 | all(==(GLinear), curvs) && return GLinear 57 | all(x -> x == GConvex || x == GLinear, curvs) && return GConvex 58 | all(x -> x == GConcave || x == GLinear, curvs) && return GConcave 59 | return GUnknownCurvature 60 | end 61 | 62 | function find_gcurvature(ex) 63 | if hasgcurvature(ex) 64 | return getgcurvature(ex) 65 | end 66 | if iscall(ex) 67 | f, args = operation(ex), arguments(ex) 68 | knowngcurv = false 69 | f_curvature = f_monotonicity = nothing # set by whichever rule below matches; `nothing` sends an unmatched call to the final GUnknownCurvature branch instead of an UndefVarError 70 | if hasgdcprule(f) && !any(iscall.(args)) 71 | rule, args = gdcprule(f, args...)
72 | f_curvature = rule.gcurvature 73 | f_monotonicity = rule.gmonotonicity 74 | knowngcurv = true 75 | elseif f == LinearAlgebra.logdet 76 | if operation(args[1]) == conjugation || 77 | operation(args[1]) == LinearAlgebra.diag || 78 | Symbol(operation(args[1])) == :+ || 79 | operation(args[1]) == affine_map || 80 | operation(args[1]) == hadamard_product 81 | return GConvex 82 | end 83 | elseif f == log && 84 | iscall(args[1]) && 85 | (operation(args[1]) == LinearAlgebra.tr || operation(args[1]) == quad_form) 86 | return GConvex 87 | elseif (f == schatten_norm || f == eigsummax) && operation(args[1]) == log 88 | return GConvex 89 | elseif f == sum_log_eigmax && hasdcprule(args[1]) 90 | if dcprule(operation(args[1])) == Convex 91 | return GConvex 92 | else 93 | return GUnknownCurvature 94 | end 95 | elseif f == affine_map 96 | if args[1] == tr || args[1] == conjugation || args[1] == diag 97 | return GConvex 98 | else 99 | return GUnknownCurvature 100 | end 101 | elseif hasgdcprule(f) && any(iscall.(args)) 102 | for i in eachindex(args) 103 | if iscall(args[i]) 104 | if operation(args[i]) == inv 105 | rule, args = gdcprule(f, args...) 106 | f_curvature = rule.gcurvature 107 | f_monotonicity = if rule.gmonotonicity == GIncreasing 108 | GDecreasing 109 | elseif rule.gmonotonicity == GDecreasing 110 | GIncreasing 111 | else 112 | GAnyMono 113 | end 114 | knowngcurv = true 115 | elseif operation(args[i]) == broadcast 116 | rule, args = gdcprule(f, args...) 117 | f_curvature = rule.gcurvature 118 | f_monotonicity = rule.gmonotonicity 119 | knowngcurv = true 120 | elseif operation(args[i]) == affine_map 121 | rule, args = gdcprule(f, args...) 
122 | f_curvature = rule.gcurvature 123 | f_monotonicity = rule.gmonotonicity 124 | knowngcurv = true 125 | end 126 | end 127 | end 128 | elseif Symbol(f) == :* 129 | if args[1] isa Number && args[1] > 0 130 | return find_gcurvature(args[2]) 131 | elseif args[1] isa Number && args[1] < 0 132 | argscurv = find_gcurvature(args[2]) 133 | if argscurv == GConvex 134 | return GConcave 135 | elseif argscurv == GConcave 136 | return GConvex 137 | else 138 | return argscurv # GLinear/GUnknownCurvature are unchanged by a negative factor; must return here because `*` sets no curvature rule for the checks below 139 | end 140 | else 141 | @warn "Disciplined Programming does not support multiple non-constant arguments in multiplication" 142 | return GUnknownCurvature 143 | end 144 | end 145 | 146 | if !(knowngcurv) && hasdcprule(f) 147 | rule, args = dcprule(f, args...) 148 | f_curvature = rule.curvature 149 | f_monotonicity = rule.monotonicity 150 | end 151 | 152 | if f_curvature == Convex || f_curvature == Affine 153 | if all(enumerate(args)) do (i, arg) 154 | arg_curv = find_gcurvature(arg) 155 | m = get_arg_property(f_monotonicity, i, args) 156 | # m: monotonicity of f with respect to its i-th argument 157 | if arg_curv == GConvex 158 | m == Increasing 159 | elseif arg_curv == GConcave 160 | m == Decreasing 161 | else 162 | arg_curv == GLinear 163 | end 164 | end 165 | return GConvex 166 | end 167 | elseif f_curvature == Concave || f_curvature == Affine 168 | if all(enumerate(args)) do (i, arg) 169 | arg_curv = find_gcurvature(arg) 170 | m = get_arg_property(f_monotonicity, i, args) # same lookup as the convex branch; direct indexing fails for the single-entry tuples built by add_gdcprule when i > 1 171 | if arg_curv == GConcave 172 | m == Increasing 173 | elseif arg_curv == GConvex 174 | m == Decreasing 175 | else 176 | arg_curv == GLinear 177 | end 178 | end 179 | return GConcave 180 | end 181 | elseif f_curvature == Affine 182 | if all(enumerate(args)) do (i, arg) 183 | arg_curv = find_gcurvature(arg) 184 | arg_curv == GLinear 185 | end 186 | return GLinear 187 | end 188 | elseif f_curvature isa GCurvature 189 | return f_curvature 190 | else 191 | return GUnknownCurvature 192 | end 193 | elseif hasfield(typeof(ex), :val) && operation(ex.val) in keys(gdcprules_dict) 194 | f, args =
operation(ex.val), arguments(ex.val) 195 | rule, args = gdcprule(f, args...) 196 | return rule.gcurvature 197 | else 198 | return GLinear 199 | end 200 | return GUnknownCurvature 201 | end 202 | 203 | function propagate_gcurvature(ex, M::AbstractManifold) 204 | r = [@rule *(~~x) => setgcurvature(~MATCH, mul_gcurvature(~~x)) 205 | @rule +(~~x) => setgcurvature(~MATCH, add_gcurvature(~~x)) 206 | @rule ~x => setgcurvature(~x, find_gcurvature(~x))] 207 | ex = Postwalk(Chain(r))(ex) 208 | ex = Prewalk(Chain(r))(ex) 209 | return ex 210 | end 211 | -------------------------------------------------------------------------------- /test/dgp.jl: -------------------------------------------------------------------------------- 1 | using Manifolds, Symbolics, SymbolicAnalysis, LinearAlgebra 2 | using LinearAlgebra, PDMats 3 | using Symbolics: unwrap 4 | using Test, Zygote, ForwardDiff 5 | using SymbolicAnalysis: propagate_sign, propagate_curvature, propagate_gcurvature 6 | 7 | @variables X[1:5, 1:5] 8 | 9 | M = Manifolds.SymmetricPositiveDefinite(5) 10 | 11 | A = rand(5, 5) 12 | A = A * A' 13 | 14 | ex = SymbolicAnalysis.logdet(SymbolicAnalysis.conjugation(inv(X), A)) |> unwrap 15 | ex = propagate_sign(ex) 16 | ex = propagate_curvature(ex) 17 | ex = propagate_gcurvature(ex, M) 18 | SymbolicAnalysis.getcurvature(ex) 19 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 20 | 21 | ex = SymbolicAnalysis.logdet(tr(inv(X))) |> unwrap 22 | ex = propagate_sign(ex) 23 | ex = propagate_curvature(ex) 24 | ex = propagate_gcurvature(ex, M) 25 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 26 | SymbolicAnalysis.getcurvature(ex) 27 | 28 | @variables Sigma[1:5, 1:5] 29 | xs = [rand(5) for i in 1:2] 30 | ex = sum(SymbolicAnalysis.log_quad_form(x, inv(Sigma)) for x in xs) + 31 | 1 / 5 * logdet(Sigma) |> Symbolics.unwrap 32 | analyze_res = SymbolicAnalysis.analyze(ex, M) 33 | @test analyze_res.gcurvature == SymbolicAnalysis.GConvex 34 | 35 | 
##Brascamplieb Problem 36 | M = SymmetricPositiveDefinite(5) 37 | objective_expr = logdet(SymbolicAnalysis.conjugation(X, A)) - logdet(X) |> unwrap 38 | objective_expr = SymbolicAnalysis.propagate_sign(objective_expr) 39 | analyze_res = analyze(objective_expr, M) 40 | @test analyze_res.gcurvature == SymbolicAnalysis.GConvex 41 | 42 | objective_expr = SymbolicAnalysis.propagate_gcurvature(objective_expr, M) 43 | @test SymbolicAnalysis.getgcurvature(objective_expr) == SymbolicAnalysis.GConvex 44 | 45 | ex = SymbolicAnalysis.tr(SymbolicAnalysis.conjugation(X, A)) |> unwrap 46 | ex = propagate_sign(ex) 47 | ex = propagate_curvature(ex) 48 | ex = propagate_gcurvature(ex, M) 49 | 50 | @test analyze(ex, M).gcurvature == SymbolicAnalysis.GConvex 51 | 52 | # using Convex 53 | 54 | # X = Convex.Variable(5, 5) 55 | # Y = Convex.Variable(5, 5) 56 | # ex = sqrt(X*Y) 57 | # vexity(ex) 58 | 59 | ## Karcher Mean 60 | As = [rand(5, 5) for i in 1:5] 61 | As = [As[i] * As[i]' for i in 1:5] 62 | 63 | ex = SymbolicAnalysis.sdivergence(X, As[1]) |> unwrap 64 | ex = SymbolicAnalysis.propagate_sign(ex) 65 | ex = SymbolicAnalysis.propagate_gcurvature(ex, M) 66 | 67 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 68 | 69 | ex = sum(SymbolicAnalysis.sdivergence(X, As[i]) for i in 1:5) |> Symbolics.unwrap 70 | ex = SymbolicAnalysis.propagate_sign(ex) 71 | ex = SymbolicAnalysis.propagate_gcurvature(ex, M) 72 | 73 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 74 | 75 | ex = Manifolds.distance(M, As[1], X)^2 |> Symbolics.unwrap 76 | ex = SymbolicAnalysis.propagate_sign(ex) 77 | ex = SymbolicAnalysis.propagate_gcurvature(ex, M) 78 | 79 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 80 | 81 | M = SymmetricPositiveDefinite(5) 82 | objective_expr = sum(Manifolds.distance(M, As[i], X)^2 for i in 1:5) |> Symbolics.unwrap 83 | analyze_res = analyze(objective_expr, M) 84 | @test analyze_res.gcurvature == SymbolicAnalysis.GConvex 85 
| 86 | @variables Y[1:5, 1:5] 87 | ex = sqrt(X * Y) 88 | analyze_res = analyze(ex, M) 89 | @test analyze_res.gcurvature == SymbolicAnalysis.GUnknownCurvature 90 | 91 | # ex = exp(X*Y) |> unwrap 92 | # ex = SymbolicAnalysis.propagate_sign(ex) 93 | # @test_throws SymbolicUtils.RuleRewriteError SymbolicAnalysis.propagate_gcurvature(ex) 94 | 95 | # using Manopt, Manifolds, Random, LinearAlgebra, ManifoldDiff 96 | # using ManifoldDiff: grad_distance, prox_distance 97 | # Random.seed!(42); 98 | 99 | # m = 100 100 | # σ = 0.005 101 | # q = Matrix{Float64}(I, 5, 5) .+ 2.0 102 | # data2 = [exp(M, q, σ * rand(M; vector_at=q)) for i in 1:m]; 103 | 104 | # f(M, x) = sum(distance(M, x, data2[i])^2 for i in 1:m) 105 | # f(x) = sum(distance(M, x, data2[i])^2 for i in 1:m) 106 | 107 | # using FiniteDifferences 108 | 109 | # r_backend = ManifoldDiff.RiemannianProjectionBackend( 110 | # ManifoldDiff.FiniteDifferencesBackend() 111 | # ) 112 | # gradf1_FD(M, p) = ManifoldDiff.gradient(M, f, p, r_backend) 113 | 114 | # m1 = gradient_descent(M, f, gradf1_FD, data2[1]; maxiter=1000) 115 | 116 | # ################################ 117 | using Optimization, 118 | OptimizationManopt, Symbolics, Manifolds, Random, LinearAlgebra, SymbolicAnalysis 119 | 120 | M = SymmetricPositiveDefinite(5) 121 | m = 100 122 | σ = 0.005 123 | q = Matrix{Float64}(LinearAlgebra.I(5)) .+ 2.0 124 | 125 | data2 = [exp(M, q, σ * rand(M; vector_at = q)) for i in 1:m]; 126 | 127 | f(x, p = nothing) = sum(SymbolicAnalysis.distance(M, data2[i], x)^2 for i in 1:5) 128 | optf = OptimizationFunction(f, Optimization.AutoZygote()) 129 | prob = OptimizationProblem(optf, data2[1]; manifold = M, structural_analysis = true) 130 | 131 | opt = OptimizationManopt.GradientDescentOptimizer() 132 | @time sol = solve(prob, opt, maxiters = 100) 133 | @test sol.objective < 1e-2 134 | 135 | M = SymmetricPositiveDefinite(5) 136 | xs = [rand(5) for i in 1:5] 137 | 138 | function f(S, p = nothing) 139 | 1 / length(xs) * 
sum(SymbolicAnalysis.log_quad_form(x, S) for x in xs) + 140 | 1 / 5 * logdet(inv(S)) 141 | end 142 | 143 | optf = OptimizationFunction(f, Optimization.AutoZygote()) 144 | prob = OptimizationProblem( 145 | optf, 146 | Array{Float64}(LinearAlgebra.I(5)); 147 | manifold = M, 148 | structural_analysis = true 149 | ) 150 | 151 | opt = OptimizationManopt.GradientDescentOptimizer() 152 | sol = solve(prob, opt, maxiters = 10) 153 | 154 | A = randn(5, 5) #initialize random matrix 155 | A = A * A' #make it a SPD matrix 156 | 157 | function matsqrt(X, p = nothing) #setup objective function 158 | return SymbolicAnalysis.sdivergence(X, A) + 159 | SymbolicAnalysis.sdivergence(X, Matrix{Float64}(LinearAlgebra.I(5))) 160 | end 161 | 162 | optf = OptimizationFunction(matsqrt, Optimization.AutoZygote()) #setup oracles 163 | prob = OptimizationProblem(optf, A / 2, manifold = M, structural_analysis = true) #setup problem with manifold and initial point 164 | 165 | sol = solve(prob, GradientDescentOptimizer(), maxiters = 1000) #solve the problem 166 | @test sqrt(A) ≈ sol.minimizer rtol = 1e-3 167 | 168 | ex = matsqrt(X) |> unwrap 169 | ex = SymbolicAnalysis.propagate_sign(ex) 170 | ex = SymbolicAnalysis.propagate_gcurvature(ex, M) 171 | 172 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 173 | 174 | ##Diagonal loading 175 | @variables X[1:5, 1:5] 176 | 177 | ex = tr(inv(X)) + logdet(X) |> unwrap 178 | @test analyze(ex, M).gcurvature == SymbolicAnalysis.GConvex 179 | 180 | γ = 1 / 2 181 | ex = (tr(X + γ * I(5)))^(2) |> unwrap 182 | 183 | @test analyze(ex, M).gcurvature == SymbolicAnalysis.GConvex 184 | 185 | d = 10 186 | n = 50 187 | @variables X[1:d, 1:d] 188 | 189 | @variables x[1:5] X[1:5, 1:5] 190 | ex = SymbolicAnalysis.log_quad_form(x, inv(X)) |> unwrap 191 | ex = SymbolicAnalysis.propagate_sign(ex) 192 | ex = SymbolicAnalysis.propagate_gcurvature(ex, M) 193 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 194 | 195 | ys = [rand(5) for 
i in 1:5] 196 | ex = SymbolicAnalysis.log_quad_form(ys, X) |> unwrap 197 | ex = SymbolicAnalysis.propagate_sign(ex) 198 | ex = SymbolicAnalysis.propagate_gcurvature(ex, M) 199 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 200 | 201 | ex = SymbolicAnalysis.log_quad_form(ys, inv(X)) |> unwrap 202 | ex = SymbolicAnalysis.propagate_sign(ex) 203 | ex = SymbolicAnalysis.propagate_gcurvature(ex, M) 204 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 205 | 206 | ex = SymbolicAnalysis.log_quad_form(ys, X) |> unwrap 207 | ex = SymbolicAnalysis.propagate_sign(ex) 208 | ex = SymbolicAnalysis.propagate_gcurvature(ex, M) 209 | @test SymbolicAnalysis.getgcurvature(ex) == SymbolicAnalysis.GConvex 210 | 211 | ex = sum(SymbolicAnalysis.eigsummax(log(X), 2)) |> unwrap 212 | anres = analyze(ex, M) 213 | @test anres.gcurvature == SymbolicAnalysis.GConvex 214 | 215 | ex = sum(SymbolicAnalysis.schatten_norm(log(X), 3)) |> unwrap 216 | anres = analyze(ex, M) 217 | @test anres.gcurvature == SymbolicAnalysis.GConvex 218 | 219 | ex = exp(SymbolicAnalysis.eigsummax(log(X), 2)) |> unwrap 220 | anres = analyze(ex, M) 221 | @test anres.gcurvature == SymbolicAnalysis.GConvex 222 | 223 | ex = SymbolicAnalysis.sum_log_eigmax(X, 2) |> unwrap 224 | anres = analyze(ex, M) 225 | @test anres.gcurvature == SymbolicAnalysis.GConvex 226 | 227 | ex = SymbolicAnalysis.sum_log_eigmax(exp, X, 2) |> unwrap 228 | anres = analyze(ex, M) 229 | @test anres.gcurvature == SymbolicAnalysis.GConvex 230 | 231 | B = rand(5, 5) 232 | B = B * B' 233 | Ys = [rand(5, 5) for i in 1:5] 234 | Ys = [Y * Y' for Y in Ys] 235 | ex = tr(SymbolicAnalysis.affine_map(SymbolicAnalysis.conjugation, X, B, Ys[1])) |> unwrap 236 | anres = analyze(ex, M) 237 | @test anres.gcurvature == SymbolicAnalysis.GConvex 238 | 239 | ex = SymbolicAnalysis.hadamard_product(X, B) |> unwrap 240 | anres = analyze(ex, M) 241 | @test anres.gcurvature == SymbolicAnalysis.GConvex 242 | 243 | A = rand(5, 5) 244 | A 
= A * A' 245 | ex = logdet(SymbolicAnalysis.affine_map(SymbolicAnalysis.hadamard_product, X, A, B)) |> 246 | unwrap 247 | anres = analyze(ex, M) 248 | @test anres.gcurvature == SymbolicAnalysis.GConvex 249 | -------------------------------------------------------------------------------- /src/gdcp/spd.jl: -------------------------------------------------------------------------------- 1 | ### DGCP Atoms 2 | # NOTE(review): add_gdcprule presumably takes (f, manifold, sign, geodesic curvature, monotonicity), mirroring add_dcprule in rules.jl - confirm in gdcp_rules.jl 3 | @register_symbolic LinearAlgebra.logdet(X::Matrix{Num}) 4 | add_gdcprule( 5 | LinearAlgebra.logdet, 6 | SymmetricPositiveDefinite, 7 | Positive, 8 | GLinear, 9 | GIncreasing 10 | ) 11 | # logdet is registered as geodesically linear (GLinear) and increasing on the SPD manifold 12 | """ 13 | conjugation(X, B) 14 | 15 | Conjugation of a matrix `X` by a matrix `B` is defined as `B'X*B`. 16 | 17 | # Arguments 18 | 19 | - `X::Matrix`: A symmetric positive definite matrix. 20 | - `B::Matrix`: A matrix. 21 | """ 22 | function conjugation(X, B) 23 | return B' * X * B 24 | end 25 | # symbolic form: B' * X * B has size (size(B, 2), size(B, 2)) 26 | @register_array_symbolic conjugation(X::Union{Symbolics.Arr, Matrix{Num}}, B::Matrix) begin 27 | size = (size(B, 2), size(B, 2)) 28 | end 29 | 30 | add_gdcprule(conjugation, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 31 | # register `tr` for symbolic matrices so tr(X) is kept as a symbolic call 32 | @register_symbolic LinearAlgebra.tr(X::Union{Symbolics.Arr, Matrix{Num}}) 33 | add_gdcprule(LinearAlgebra.tr, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 34 | 35 | add_gdcprule(sum, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 36 | 37 | add_gdcprule(adjoint, SymmetricPositiveDefinite, Positive, GLinear, GIncreasing) 38 | 39 | """ 40 | scalar_mat(X, k=size(X, 1)) 41 | 42 | Scalar matrix of a symmetric positive definite matrix `X` is defined as `tr(X)*I(k)`. 43 | 44 | # Arguments 45 | 46 | - `X::Matrix`: A symmetric positive definite matrix. 47 | - `k::Int`: The size of the identity matrix.
48 | """ 49 | function scalar_mat(X, k = size(X, 1)) 50 | return tr(X) * I(k) # trace of X times the k x k identity 51 | end 52 | 53 | @register_symbolic scalar_mat(X::Union{Symbolics.Arr, Matrix{Num}}, k::Int) 54 | 55 | add_gdcprule(scalar_mat, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 56 | 57 | add_gdcprule(LinearAlgebra.diag, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) # NOTE(review): a rule for `diag` is added again further down this file - confirm the duplicate is intended 58 | # commented-out pinching atom (not registered) 59 | # """ 60 | # pinching(X, Ps) 61 | 62 | # Pinching of a symmetric positive definite matrix `X` by a set of symmetric positive definite matrices `Ps` is defined as `sum(Ps[i]*X*Ps[i])`. 63 | 64 | # # Arguments 65 | # - `X::Matrix`: A symmetric positive definite matrix. 66 | # - `Ps::Vector`: A vector of symmetric positive definite matrices. 67 | # """ 68 | # function pinching(X, Ps) 69 | # return sum(Ps[i]*X*Ps[i] for i in eachindex(Ps); dims = 1) 70 | # end 71 | 72 | # @register_symbolic pinching(X::Matrix{Num}, Ps::Vector{Union{Symbolics.Arr, Matrix{Num}}}) 73 | 74 | # add_gdcprule(pinching, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 75 | 76 | """ 77 | sdivergence(X, Y) 78 | 79 | Symmetric divergence of two symmetric positive definite matrices `X` and `Y` is defined as `logdet((X+Y)/2) - 1/2*logdet(X*Y)`. 80 | 81 | # Arguments 82 | 83 | - `X::Matrix`: A symmetric positive definite matrix. 84 | - `Y::Matrix`: A symmetric positive definite matrix.
85 | """ 86 | function sdivergence(X, Y) 87 | return logdet((X + Y) / 2) - 1 / 2 * logdet(X * Y) # for SPD X, Y: logdet(X * Y) == logdet(X) + logdet(Y) 88 | end 89 | 90 | @register_symbolic sdivergence(X::Matrix{Num}, Y::Matrix) 91 | add_gdcprule(sdivergence, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 92 | # geodesic distance on the SPD manifold from Manifolds.jl, registered with monotonicity GAnyMono 93 | @register_symbolic Manifolds.distance( 94 | M::Manifolds.SymmetricPositiveDefinite, 95 | X::AbstractMatrix, 96 | Y::Union{Symbolics.Arr, Matrix{Num}} 97 | ) 98 | add_gdcprule(Manifolds.distance, SymmetricPositiveDefinite, Positive, GConvex, GAnyMono) 99 | # the matrix exp/sqrt rules below are commented out 100 | # @register_symbolic LinearAlgebra.exp(X::Union{Symbolics.Arr, Matrix{Num}}) 101 | # add_gdcprule(exp, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 102 | 103 | # add_gdcprule(sqrt, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 104 | 105 | add_gdcprule( 106 | SymbolicAnalysis.quad_form, 107 | SymmetricPositiveDefinite, 108 | Positive, 109 | GConvex, 110 | GIncreasing 111 | ) 112 | 113 | add_gdcprule( 114 | LinearAlgebra.eigmax, 115 | SymmetricPositiveDefinite, 116 | Positive, 117 | GConvex, 118 | GIncreasing 119 | ) 120 | 121 | """ 122 | log_quad_form(y, X) 123 | log_quad_form(ys, X) 124 | 125 | Log of the quadratic form of a symmetric positive definite matrix `X` and a vector `y` is defined as `log(y'*X*y)` or for a vector of vectors `ys` as `log(sum(y'*X*y for y in ys))`. 126 | 127 | # Arguments 128 | 129 | - `y::Vector`: A vector of `Number`s or a `Vector` of `Vector`s. 130 | - `X::Matrix`: A symmetric positive definite matrix.
131 | """ 132 | function log_quad_form(y::Vector{<:Number}, X::Matrix) 133 | return log(y' * X * y) 134 | end 135 | 136 | function log_quad_form(ys::Vector{<:Vector}, X::Matrix) 137 | return log(sum(y' * X * y for y in ys)) 138 | end 139 | 140 | @register_symbolic log_quad_form(y::Vector, X::Matrix{Num}) 141 | add_gdcprule(log_quad_form, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 142 | 143 | add_gdcprule(inv, SymmetricPositiveDefinite, Positive, GConvex, GDecreasing) 144 | 145 | add_gdcprule(diag, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 146 | # register the matrix logarithm so log(X) of a symbolic matrix stays a symbolic array of the same size 147 | @register_array_symbolic Base.log(X::Matrix{Num}) begin 148 | size = (size(X, 1), size(X, 2)) 149 | end 150 | 151 | add_gdcprule(eigsummax, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 152 | 153 | """ 154 | schatten_norm(X, p=2) 155 | 156 | Schatten `p`-norm of a symmetric positive definite matrix `X`, computed as the `p`-norm of its eigenvalues. 157 | 158 | # Arguments 159 | 160 | - `X::Matrix`: A symmetric positive definite matrix. 161 | - `p::Int`: The order `p` of the norm. 162 | """ 163 | function schatten_norm(X::AbstractMatrix, p::Int = 2) 164 | return norm(eigvals(X), p) 165 | end 166 | 167 | @register_symbolic schatten_norm(X::Matrix{Num}, p::Int) 168 | add_gdcprule(schatten_norm, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 169 | 170 | """ 171 | sum_log_eigmax(X, k) 172 | sum_log_eigmax(f, X, k) 173 | 174 | Sum of the logs of the `k` largest eigenvalues of a symmetric positive definite matrix `X`. If a function `f` is provided, 175 | the sum is over `f` applied to the log of each of those eigenvalues. 176 | 177 | # Arguments 178 | 179 | - `f::Function`: A function. 180 | - `X::Matrix`: A symmetric positive definite matrix. 181 | - `k::Int`: The number of largest eigenvalues to consider.
182 | """ 183 | function sum_log_eigmax(f::Function, X::AbstractMatrix, k::Int) 184 | nrows = size(X, 1) 185 | eigs = eigvals(Symmetric(X), (nrows - k + 1):nrows) # the index-range eigvals method only exists for Symmetric/Hermitian; values are ascending, so this selects the k largest 186 | return sum(f.(log.(eigs))) 187 | end 188 | 189 | @register_symbolic sum_log_eigmax(f::Function, X::Matrix{Num}, k::Int) 190 | 191 | function sum_log_eigmax(X::AbstractMatrix, k::Int) 192 | nrows = size(X, 1) 193 | eigs = eigvals(Symmetric(X), (nrows - k + 1):nrows) # Symmetric wrapper needed for the index-range method 194 | return sum(log.(eigs)) 195 | end 196 | 197 | @register_symbolic sum_log_eigmax(X::Matrix{Num}, k::Int) false 198 | add_gdcprule(sum_log_eigmax, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 199 | 200 | """ 201 | affine_map(f, X, B, Y) 202 | affine_map(f, X, B, Ys) 203 | 204 | Affine map, i.e., `B + f(X, Y)` or `B + sum(f(X, Y) for Y in Ys)` for a function `f` where `f` is a positive linear operator. 205 | 206 | # Arguments 207 | 208 | - `f::Function`: One of the following functions: `conjugation`, `diag`, `tr` and `hadamard_product`. 209 | - `X::Matrix`: A symmetric positive definite matrix. 210 | - `B::Matrix`: A matrix. 211 | - `Y::Matrix`: A matrix. 212 | - `Ys::Vector{<:Matrix}`: A vector of matrices.
213 | """ 214 | function affine_map(f::typeof(conjugation), X::Matrix, B::Matrix, Y::Matrix) 215 | if !(LinearAlgebra.isposdef(B)) || !(eigvals(Symmetric(B), 1:1)[1] >= 0.0) # NOTE(review): isposdef already demands strict definiteness, stricter than the "semi-definite" in the message - confirm intent 216 | throw(DomainError(B, "B must be positive semi-definite.")) 217 | end 218 | return B + conjugation(X, Y) 219 | end 220 | 221 | function affine_map(f::typeof(conjugation), X::Matrix, B::Matrix, Ys::Vector{<:Matrix}) 222 | if !(LinearAlgebra.isposdef(B)) || !(eigvals(Symmetric(B), 1:1)[1] >= 0.0) 223 | throw(DomainError(B, "B must be positive semi-definite.")) 224 | end 225 | return B + sum(conjugation(X, Y) for Y in Ys) 226 | end 227 | 228 | @register_array_symbolic affine_map( 229 | conjf::typeof(conjugation), 230 | X::Matrix{Num}, 231 | B::Matrix, 232 | Y::Union{Matrix, Vector{<:Matrix}} 233 | ) begin 234 | size = (size(B, 1), size(B, 2)) 235 | end 236 | # diag and tr are lifted to matrix-valued maps (Diagonal(diag(X)) and tr(X) * I) so the result has size(B), matching the registration below 237 | function affine_map(f::Union{typeof(diag), typeof(tr)}, X::AbstractMatrix, B::AbstractMatrix) 238 | if !(LinearAlgebra.isposdef(B)) || !(eigvals(Symmetric(B), 1:1)[1] >= 0.0) 239 | throw(DomainError(B, "B must be positive semi-definite.")) 240 | end 241 | return f === tr ? B + tr(X) * I : B + LinearAlgebra.Diagonal(diag(X)) 242 | end 243 | 244 | @register_array_symbolic affine_map( 245 | diagtrf::Union{typeof(diag), typeof(tr)}, 246 | X::Matrix{Num}, 247 | B::Matrix 248 | ) begin 249 | size = (size(B, 1), size(B, 2)) 250 | end false 251 | 252 | add_gdcprule(affine_map, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 253 | 254 | """ 255 | hadamard_product(X, B) 256 | 257 | Hadamard product or element-wise multiplication of a symmetric positive definite matrix `X` by a positive semi-definite matrix `B`. 258 | 259 | # Arguments 260 | 261 | - `X::Matrix`: A symmetric positive definite matrix. 262 | - `B::Matrix`: A positive semi-definite matrix.
263 | """ 264 | function hadamard_product(X::AbstractMatrix, B::AbstractMatrix) 265 | if !(LinearAlgebra.isposdef(B)) || !(eigvals(Symmetric(B), 1:1)[1] >= 0.0) || 266 | any(r -> all(iszero, r), eachrow(B)) # either failure is fatal: B must be PSD and no row of B may be entirely zero 267 | throw(DomainError(B, "B must be positive semi-definite and have no zero rows.")) 268 | end 269 | return B .* X 270 | end 271 | 272 | @register_array_symbolic hadamard_product(X::Matrix{Num}, B::Matrix) begin 273 | size = (size(B, 1), size(B, 2)) 274 | end 275 | 276 | add_gdcprule(hadamard_product, SymmetricPositiveDefinite, Positive, GConvex, GIncreasing) 277 | 278 | function affine_map(f::typeof(hadamard_product), X::Matrix, Y::Matrix, B::Matrix) 279 | if !(LinearAlgebra.isposdef(B)) || !(eigvals(Symmetric(B), 1:1)[1] >= 0.0) 280 | throw(DomainError(B, "B must be positive semi-definite.")) 281 | end 282 | return B + hadamard_product(X, Y) # Y is validated inside hadamard_product; B is the additive offset 283 | end 284 | 285 | @register_array_symbolic affine_map( 286 | hadamard_product::typeof(hadamard_product), 287 | X::Matrix{Num}, 288 | Y::Matrix, 289 | B::Matrix 290 | ) begin 291 | size = (size(B, 1), size(B, 2)) 292 | end false 293 | -------------------------------------------------------------------------------- /src/rules.jl: -------------------------------------------------------------------------------- 1 | 2 | @enum Sign Positive Negative AnySign 3 | @enum Curvature Convex Concave Affine UnknownCurvature 4 | @enum Monotonicity Increasing Decreasing AnyMono 5 | # a Domain whose membership test is an arbitrary predicate (Base.in forwards to it below) 6 | struct CustomDomain{T} <: Domain{T} 7 | in::Function 8 | end 9 | 10 | Base.in(x, c::CustomDomain) = c.in(x) 11 | # array_domain(d): arrays whose elements all lie in d; array_domain(d, N) additionally requires ndims == N 12 | function array_domain(element_domain) 13 | CustomDomain{AbstractArray}() do xs 14 | all(in(element_domain), xs) 15 | end 16 | end 17 | 18 | function array_domain(element_domain, N) 19 | CustomDomain{AbstractArray{<:Any, N}}() do xs 20 | ndims(xs) == N && all(in(element_domain), xs) 21 | end 22 | end 23 | 24 | function symmetric_domain() 25 | CustomDomain{AbstractArray{<:Any, 2}}(issymmetric) 26 | end 27 | 28 | function
semidefinite_domain() 29 | CustomDomain{AbstractArray{<:Any, 2}}(isposdef) #not semi so needs to change 30 | end 31 | 32 | function negsemidefinite_domain() 33 | CustomDomain{AbstractArray{<:Any, 2}}(isposdef ∘ -) #not semi so needs to change 34 | end 35 | 36 | function definite_domain() 37 | CustomDomain{AbstractArray{<:Any, 2}}(isposdef) 38 | end 39 | 40 | function negdefinite_domain() 41 | CustomDomain{AbstractArray{<:Any, 2}}(isposdef ∘ -) 42 | end 43 | # Base.in forwards to this predicate, so it must return a Bool: `isa` rather than typeassert (which returns the value itself or throws) 44 | function function_domain() 45 | CustomDomain{Function}(x -> x isa Function) 46 | end 47 | # monotonicity chosen from the argument's sign: Increasing if Positive, Decreasing if Negative, AnyMono if unknown 48 | function increasing_if_positive(x) 49 | sign = getsign(x) 50 | sign == AnySign ? AnyMono : sign == Positive ? Increasing : Decreasing 51 | end 52 | # registry: function => rule NamedTuple, or a Vector of rules when several are added for the same function 53 | const dcprules_dict = Dict() 54 | 55 | function add_dcprule(f, domain, sign, curvature, monotonicity) 56 | if !(monotonicity isa Tuple) 57 | monotonicity = (monotonicity,) 58 | end 59 | if f in keys(dcprules_dict) 60 | dcprules_dict[f] = vcat(dcprules_dict[f], makerule(domain, sign, curvature, monotonicity)) 61 | else 62 | dcprules_dict[f] = makerule(domain, sign, curvature, monotonicity) 63 | end 64 | end 65 | 66 | function makerule(domain, sign, curvature, monotonicity) 67 | (; domain = domain, sign = sign, curvature = curvature, monotonicity = monotonicity) 68 | end 69 | 70 | hasdcprule(f::Function) = haskey(dcprules_dict, f) 71 | hasdcprule(f) = false 72 | 73 | Symbolics.hasmetadata(::Union{Real, AbstractArray{<:Real}}, args...) = false 74 | 75 | function dcprule(f, args...)
76 | if all(hasmetadata.(args, Ref(VarDomain))) 77 | argsdomain = getmetadata.(args, Ref(VarDomain)) 78 | else 79 | if dcprules_dict[f] isa Vector 80 | return dcprules_dict[f][1], args 81 | else 82 | return dcprules_dict[f], args 83 | end 84 | end 85 | # with domain metadata available, return the first registered rule whose domain(s) contain every argument's domain 86 | if dcprules_dict[f] isa Vector 87 | for i in 1:length(dcprules_dict[f]) 88 | if (dcprules_dict[f][i].domain isa Domain) && 89 | all(issubset.(argsdomain, Ref(dcprules_dict[f][i].domain))) 90 | return dcprules_dict[f][i], args 91 | elseif !(dcprules_dict[f][i].domain isa Domain) && 92 | all(issubset.(argsdomain, dcprules_dict[f][i].domain)) 93 | return dcprules_dict[f][i], args 94 | end 95 | end # every candidate rule was tried without a match; only now report the failure 96 | throw( 97 | ArgumentError( 98 | "No DCP rule found for $f with arguments $args with domain $argsdomain", 99 | ), 100 | ) 101 | 102 | elseif (dcprules_dict[f].domain isa Domain) && 103 | all(issubset.(argsdomain, Ref(dcprules_dict[f].domain))) 104 | return dcprules_dict[f], args 105 | elseif dcprules_dict[f].domain isa Tuple && 106 | all(issubset.(argsdomain, dcprules_dict[f].domain)) 107 | return dcprules_dict[f], args 108 | else 109 | throw(ArgumentError("No DCP rule found for $f with arguments $args")) 110 | end 111 | end 112 | 113 | ### Sign ### 114 | setsign(ex::Union{Num, Symbolic}, sign) = setmetadata(ex, Sign, sign) 115 | setsign(ex, sign) = ex 116 | 117 | function getsign(ex::Union{Num, Symbolic}) 118 | if hasmetadata(ex, Sign) 119 | return getmetadata(ex, Sign) 120 | end 121 | return AnySign 122 | end 123 | 124 | getsign(ex::Union{AbstractFloat, Integer}) = ex < 0 ?
Negative : Positive 125 | # an array's sign is Negative/Positive only if all of its elements share that sign 126 | function getsign(ex::AbstractArray) 127 | if all(x -> getsign(x) == Negative, ex) 128 | return Negative 129 | elseif all(x -> getsign(x) == Positive, ex) 130 | return Positive 131 | else 132 | AnySign 133 | end 134 | end 135 | 136 | hassign(ex::Union{Num, Symbolic}) = hasmetadata(ex, Sign) 137 | hassign(ex) = ex isa Real 138 | 139 | hassign(ex::typeof(Base.broadcast)) = true 140 | getsign(ex::typeof(Base.broadcast)) = Positive 141 | 142 | Symbolics.arguments(x::Number) = x 143 | # sign of a sum: propagates signs into call arguments (mutating `args`), then Negative/Positive only if every term agrees 144 | function add_sign(args) 145 | if hassign(args) 146 | return getsign(args) 147 | end 148 | for i in eachindex(args) 149 | if iscall(args[i]) 150 | args[i] = propagate_sign(args[i]) 151 | end 152 | end 153 | signs = reduce(vcat, getsign.(args)) 154 | if any(==(AnySign), signs) 155 | AnySign 156 | elseif all(==(Negative), signs) 157 | Negative 158 | elseif all(==(Positive), signs) 159 | Positive 160 | else 161 | AnySign 162 | end 163 | end 164 | # sign of a product: AnySign if any factor is unknown, otherwise Negative iff an odd number of factors are Negative 165 | function mul_sign(args) 166 | signs = getsign.(args) 167 | if any(==(AnySign), signs) 168 | AnySign 169 | elseif isodd(count(==(Negative), signs)) 170 | Negative 171 | else 172 | Positive 173 | end 174 | end 175 | 176 | function propagate_sign(ex) 177 | # Step 1: set the sign of all variables to be AnySign 178 | rs = [@rule ~x::issym => setsign(~x, AnySign) where {hassign(~x)} 179 | @rule ~x::iscall => setsign(~x, AnySign) where {hassign(~x)} 180 | @rule ~x::issym => setsign(~x, (dcprule(~x))[1].sign) where {hasdcprule(~x)} 181 | @rule ~x::issym => setsign(~x, (gdcprule(~x))[1].sign) where {hasgdcprule(~x)} 182 | @rule ~x::iscall => setsign( 183 | ~x, 184 | (dcprule(operation(~x), arguments(~x)...)[1].sign) 185 | ) where {hasdcprule(operation(~x))} 186 | @rule ~x::iscall => setsign( 187 | ~x, 188 | (gdcprule(operation(~x), arguments(~x)...)[1].sign) 189 | ) where {hasgdcprule(operation(~x))} 190 | @rule *(~~x) => setsign(~MATCH, mul_sign(~~x)) 191 | @rule +(~~x) => setsign(~MATCH, add_sign(~~x))] 192 | rc =
Chain(rs) 193 | ex = Postwalk(rc)(ex) 194 | ex = Prewalk(rc)(ex) 195 | return ex 196 | end 197 | 198 | ### Curvature ### 199 | 200 | setcurvature(ex::Union{Num, Symbolic}, curv) = setmetadata(ex, Curvature, curv) 201 | setcurvature(ex, curv) = ex 202 | getcurvature(ex::Union{Num, Symbolic}) = getmetadata(ex, Curvature) 203 | getcurvature(ex) = Affine 204 | hascurvature(ex::Union{Num, Symbolic}) = hasmetadata(ex, Curvature) 205 | hascurvature(ex) = ex isa Real 206 | 207 | function mul_curvature(args) 208 | # DCP allows at most one non-constant factor; every other factor must be a constant 209 | non_constants = findall(x -> issym(x) || iscall(x), args) 210 | constants = findall(x -> !issym(x) && !iscall(x), args) 211 | if length(non_constants) > 1 212 | # explicit branch rather than try/@assert/catch: 213 | # @assert may be disabled at some optimization levels and is not meant for control flow 214 | @warn "DCP does not support multiple non-constant arguments in multiplication" 215 | return UnknownCurvature 216 | end 217 | 218 | if !isempty(non_constants) 219 | expr = args[first(non_constants)] 220 | curv = find_curvature(expr) 221 | return if getsign(prod(args[constants]; init = 1)) == Negative 222 | # flip 223 | curv == Convex ? Concave : curv == Concave ?
Convex : curv 224 | else 225 | curv 226 | end 227 | end 228 | return Affine 229 | end 230 | # curvature of a sum: affine if all terms are affine; convex (concave) if every term is convex (concave) or affine 231 | function add_curvature(args) 232 | curvs = find_curvature.(args) 233 | all(==(Affine), curvs) && return Affine 234 | all(x -> x == Convex || x == Affine, curvs) && return Convex 235 | all(x -> x == Concave || x == Affine, curvs) && return Concave 236 | return UnknownCurvature 237 | end 238 | 239 | function propagate_curvature(ex) 240 | rs = [@rule *(~~x) => setcurvature(~MATCH, mul_curvature(~~x)) 241 | @rule +(~~x) => setcurvature(~MATCH, add_curvature(~~x)) 242 | @rule ~x => setcurvature(~x, find_curvature(~x))] 243 | rc = Chain(rs) 244 | ex = Postwalk(rc)(ex) 245 | ex = Prewalk(rc)(ex) 246 | # SymbolicUtils.inspect(ex, metadata = true) 247 | return ex 248 | end 249 | 250 | function get_arg_property(monotonicity, i, args) 251 | # unwrap per-argument tuples while an i-th entry exists 252 | while monotonicity isa Tuple && i <= length(monotonicity) 253 | monotonicity = monotonicity[i] 254 | end 255 | # a function (e.g. increasing_if_positive) computes the property from the argument itself 256 | if monotonicity isa Function 257 | return monotonicity(args[i]) 258 | end 259 | return monotonicity 260 | end 261 | 262 | function find_curvature(ex) 263 | if hascurvature(ex) 264 | return getcurvature(ex) 265 | end 266 | 267 | if iscall(ex) 268 | f, args = operation(ex), arguments(ex) 269 | # @show f 270 | if hasdcprule(f) 271 | rule, args = dcprule(f, args...)
272 | elseif Symbol(f) == :* 273 | if args[1] isa Number && args[1] > 0 274 | return find_curvature(args[2]) 275 | elseif args[1] isa Number && args[1] < 0 276 | argscurv = find_curvature(args[2]) 277 | if argscurv == Convex 278 | return Concave 279 | elseif argscurv == Concave 280 | return Convex 281 | else 282 | return argscurv 283 | end 284 | else 285 | @warn "DCP does not support multiple non-constant arguments in multiplication" 286 | return UnknownCurvature 287 | end 288 | else 289 | return UnknownCurvature 290 | end 291 | f_curvature = rule.curvature 292 | f_monotonicity = rule.monotonicity 293 | 294 | if f_curvature == Affine 295 | if all(enumerate(args)) do (i, arg) 296 | arg_curv = find_curvature(arg) 297 | arg_curv == Affine 298 | end 299 | return Affine 300 | end 301 | end 302 | if f_curvature == Convex || f_curvature == Affine # affine atoms may also compose convexly 303 | if all(enumerate(args)) do (i, arg) 304 | arg_curv = find_curvature(arg) 305 | m = get_arg_property(f_monotonicity, i, args) 306 | if arg_curv == Convex 307 | m == Increasing 308 | elseif arg_curv == Concave 309 | m == Decreasing 310 | else 311 | arg_curv == Affine 312 | end 313 | end 314 | return Convex 315 | end 316 | end 317 | if f_curvature == Concave || f_curvature == Affine # ... or concavely 318 | if all(enumerate(args)) do (i, arg) 319 | arg_curv = find_curvature(arg) 320 | m = get_arg_property(f_monotonicity, i, args) # handles tuples shorter than args and monotonicity functions 321 | if arg_curv == Concave 322 | m == Increasing 323 | elseif arg_curv == Convex 324 | m == Decreasing 325 | else 326 | arg_curv == Affine 327 | end 328 | end 329 | return Concave 330 | end 331 | end 332 | # no composition rule certifies the curvature 333 | return UnknownCurvature 334 | else 335 | return Affine 336 | end 337 | end 338 | -------------------------------------------------------------------------------- /docs/src/atoms.md: -------------------------------------------------------------------------------- 1 | # Atoms for DCP and DGCP 2 | 3 | This page is intended to be a reference for the atoms that are currently implemented in this
package with their respective properties. As much as possible atoms are created with functions from base, standard libraries and popular packages, but we also inherit a few functions from the CVX family of packages such as `quad_form`, `quad_over_lin` etc. and also introduce some new functions in this package. Description of all such special functions implemented in this package is available in the [Special functions](@ref) section of the documentation. 4 | 5 | ## DCP Atoms 6 | 7 | | Atom | Domain | Sign | Curvature | Monotonicity | 8 | |:------------------------- |:------------------------------------------------------------------------------ |:--------- |:--------- |:------------------------------------------- | 9 | | dot | (array_domain(ℝ), array_domain(ℝ)) | AnySign | Affine | Increasing | 10 | | dotsort | (array_domain(ℝ, 1), array_domain(ℝ, 1)) | AnySign | Convex | (AnyMono, increasing_if_positive ∘ minimum) | 11 | | StatsBase.geomean | array_domain(HalfLine{Real,:open}(), 1) | Positive | Concave | Increasing | 12 | | StatsBase.harmmean | array_domain(HalfLine{Real,:open}(), 1) | Positive | Concave | Increasing | 13 | | invprod | array_domain(HalfLine{Real,:open}()) | Positive | Convex | Decreasing | 14 | | eigmax | symmetric_domain() | AnySign | Convex | AnyMono | 15 | | eigmin | symmetric_domain() | AnySign | Concave | AnyMono | 16 | | eigsummax | (array_domain(ℝ, 2), ℝ) | AnySign | Convex | AnyMono | 17 | | eigsummin | (array_domain(ℝ, 2), ℝ) | AnySign | Concave | AnyMono | 18 | | logdet | semidefinite_domain() | AnySign | Concave | AnyMono | 19 | | LogExpFunctions.logsumexp | array_domain(ℝ, 2) | AnySign | Convex | Increasing | 20 | | matrix_frac | (array_domain(ℝ, 1), definite_domain()) | AnySign | Convex | AnyMono | 21 | | maximum | array_domain(ℝ) | AnySign | Convex | Increasing | 22 | | minimum | array_domain(ℝ) | AnySign | Concave | Increasing | 23 | | norm | (array_domain(ℝ), Interval{:closed, :open}(1, Inf)) | Positive | Convex | 
increasing_if_positive | 24 | | norm | (array_domain(ℝ), Interval{:closed, :open}(0, 1)) | Positive | Convex | increasing_if_positive | 25 | | perspective(f, x, s) | (function_domain(), ℝ, Positive) | Same as f | Same as f | AnyMono | 26 | | quad_form | (array_domain(ℝ, 1), semidefinite_domain()) | Positive | Convex | (increasing_if_positive, Increasing) | 27 | | quad_over_lin | (array_domain(ℝ), HalfLine{Real,:open}()) | Positive | Convex | (increasing_if_positive, Decreasing) | 28 | | quad_over_lin | (ℝ, HalfLine{Real,:open}()) | Positive | Convex | (increasing_if_positive, Decreasing) | 29 | | sum | array_domain(ℝ, 2) | AnySign | Affine | Increasing | 30 | | sum_largest | (array_domain(ℝ, 2), ℤ) | AnySign | Convex | Increasing | 31 | | sum_smallest | (array_domain(ℝ, 2), ℤ) | AnySign | Concave | Increasing | 32 | | tr | array_domain(ℝ, 2) | AnySign | Affine | Increasing | 33 | | trinv | definite_domain() | Positive | Convex | AnyMono | 34 | | tv | array_domain(ℝ, 1) | Positive | Convex | AnyMono | 35 | | tv | array_domain(array_domain(ℝ, 2), 1) | Positive | Convex | AnyMono | 36 | | abs | ℂ | Positive | Convex | increasing_if_positive | 37 | | conj | ℂ | AnySign | Affine | AnyMono | 38 | | exp | ℝ | Positive | Convex | Increasing | 39 | | xlogx | ℝ | AnySign | Convex | AnyMono | 40 | | huber | (ℝ, HalfLine()) | Positive | Convex | increasing_if_positive | 41 | | imag | ℂ | AnySign | Affine | AnyMono | 42 | | inv | HalfLine{Real,:open}() | Positive | Convex | Decreasing | 43 | | log | HalfLine{Real,:open}() | AnySign | Concave | Increasing | 44 | | log | array_domain(ℝ, 2) | Positive | Concave | Increasing | 45 | | inv | semidefinite_domain() | AnySign | Convex | Decreasing | 46 | | sqrt | semidefinite_domain() | Positive | Concave | Increasing | 47 | | kldivergence | (array_domain(HalfLine{Real,:open}, 1), array_domain(HalfLine{Real,:open}, 1)) | Positive | Convex | AnyMono | 48 | | lognormcdf | ℝ | Negative | Concave | Increasing | 49 | | log1p | 
Interval{:open,:open}(-1, Inf) | Negative | Concave | Increasing | 50 | | logistic | ℝ | Positive | Convex | Increasing | 51 | | max | (ℝ, ℝ) | AnySign | Convex | Increasing | 52 | | min | (ℝ, ℝ) | AnySign | Concave | Increasing | 53 | | ^(x, i) | See below | See below | See below | See below | 54 | | real | ℂ | AnySign | Affine | Increasing | 55 | | rel_entr | (HalfLine{Real,:open}(), HalfLine{Real,:open}()) | AnySign | Convex | (AnyMono, Decreasing) | 56 | | sqrt | HalfLine() | Positive | Concave | Increasing | 57 | | xexpx | HalfLine | Positive | Convex | Increasing | 58 | | conv | (array_domain(ℝ, 1), array_domain(ℝ, 1)) | AnySign | Affine | AnyMono | 59 | | cumsum | array_domain(ℝ) | AnySign | Affine | Increasing | 60 | | diagm | array_domain(ℝ, 1) | AnySign | Affine | Increasing | 61 | | diag | array_domain(ℝ, 2) | AnySign | Affine | Increasing | 62 | | diff | array_domain(ℝ) | AnySign | Affine | Increasing | 63 | | kron | (array_domain(ℝ, 2), array_domain(ℝ, 2)) | AnySign | Affine | Increasing | 64 | 65 | ### Special Cases for ^(x, i) 66 | 67 | | Condition on i | Domain | Sign | Curvature | Monotonicity | 68 | |:----------------- |:--------------------------- |:-------- |:--------- |:---------------------- | 69 | | i = 1 | ℝ | AnySign | Affine | Increasing | 70 | | i is even integer | ℝ | Positive | Convex | increasing_if_positive | 71 | | i is odd integer | HalfLine() | Positive | Convex | Increasing | 72 | | i ≥ 1 | HalfLine() | Positive | Convex | Increasing | 73 | | 0 < i < 1 | HalfLine() | Positive | Concave | Increasing | 74 | | i < 0 | HalfLine{Float64,:closed}() | Positive | Convex | Increasing | 75 | 76 | ## DGCP Atoms (Symmetric Positive Definite) 77 | 78 | | Atom | Sign | Geodesic Curvature | Monotonicity | 79 | |:-------------------------- |:-------- |:------------------ |:------------ | 80 | | LinearAlgebra.logdet | Positive | GLinear | GIncreasing | 81 | | conjugation | Positive | GConvex | GIncreasing | 82 | | LinearAlgebra.tr | Positive | 
GConvex | GIncreasing | 83 | | sum | Positive | GConvex | GIncreasing | 84 | | adjoint | Positive | GLinear | GIncreasing | 85 | | scalar_mat | Positive | GConvex | GIncreasing | 86 | | LinearAlgebra.diag | Positive | GConvex | GIncreasing | 87 | | sdivergence | Positive | GConvex | GIncreasing | 88 | | Manifolds.distance | Positive | GConvex | GAnyMono | 89 | | SymbolicAnalysis.quad_form | Positive | GConvex | GIncreasing | 90 | | LinearAlgebra.eigmax | Positive | GConvex | GIncreasing | 91 | | log_quad_form | Positive | GConvex | GIncreasing | 92 | | inv | Positive | GConvex | GDecreasing | 93 | | diag | Positive | GConvex | GIncreasing | 94 | | eigsummax | Positive | GConvex | GIncreasing | 95 | | schatten_norm | Positive | GConvex | GIncreasing | 96 | | sum_log_eigmax | Positive | GConvex | GIncreasing | 97 | | affine_map | Positive | GConvex | GIncreasing | 98 | | hadamard_product | Positive | GConvex | GIncreasing | 99 | 100 | ## DGCP Atoms (Lorentz Model) 101 | 102 | | Atom | Sign | Geodesic Curvature | Monotonicity | 103 | |:----------------------------- |:-------- |:------------------ |:------------ | 104 | | lorentz_distance | Positive | GConvex | GAnyMono | 105 | | lorentz_log_barrier | Positive | GConvex | GIncreasing | 106 | | lorentz_homogeneous_quadratic | Positive | GConvex | GAnyMono | 107 | | lorentz_homogeneous_diagonal | Positive | GConvex | GAnyMono | 108 | | lorentz_least_squares | Positive | GConvex | GAnyMono | 109 | | lorentz_transform | - | - | - | 110 | 111 | Note: `lorentz_transform` does not have specific geodesic curvature properties by itself, but it preserves geodesic convexity when applied to geodesically convex functions. 
112 | -------------------------------------------------------------------------------- /src/atoms.jl: -------------------------------------------------------------------------------- 1 | ### DCP atom rules 2 | 3 | add_dcprule(+, RealLine(), AnySign, Affine, Increasing) 4 | add_dcprule(-, RealLine(), AnySign, Affine, Decreasing) 5 | 6 | add_dcprule(Base.Ref, RealLine(), AnySign, Affine, AnyMono) 7 | 8 | add_dcprule( 9 | dot, 10 | (array_domain(RealLine()), array_domain(RealLine())), 11 | AnySign, 12 | Affine, 13 | Increasing 14 | ) 15 | 16 | """ 17 | dotsort(x, y) 18 | 19 | Sorts `x` and `y` in ascending order and returns the dot product of the sorted vectors; throws a `DimensionMismatch` if their lengths differ. 20 | 21 | # Arguments 22 | 23 | - `x::AbstractVector`: A vector. 24 | - `y::AbstractVector`: A vector. 25 | """ 26 | function dotsort(x::AbstractVector, y::AbstractVector) 27 | if length(x) != length(y) 28 | throw(DimensionMismatch("AbstractVectors must have same length")) 29 | end 30 | return dot(sort(x), sort(y)) # sort each vector as a whole before pairing entries 31 | end 32 | Symbolics.@register_symbolic dotsort(x::AbstractVector, y::AbstractVector) 33 | add_dcprule( 34 | dotsort, 35 | (array_domain(RealLine(), 1), array_domain(RealLine(), 1)), 36 | AnySign, 37 | Convex, 38 | (AnyMono, increasing_if_positive ∘ minimum) 39 | ) 40 | 41 | add_dcprule( 42 | StatsBase.geomean, 43 | array_domain(HalfLine{Real, :open}(), 1), 44 | Positive, 45 | Concave, 46 | Increasing 47 | ) 48 | add_dcprule( 49 | StatsBase.harmmean, 50 | array_domain(HalfLine{Real, :open}(), 1), 51 | Positive, 52 | Concave, 53 | Increasing 54 | ) 55 | 56 | """ 57 | invprod(x::AbstractVector) 58 | 59 | Returns the inverse of the product of the elements of `x`; throws a `DivideError` if any element is zero. 60 | 61 | # Arguments 62 | 63 | - `x::AbstractVector`: A vector.
64 | """ 65 | function invprod(x::AbstractVector) 66 | if any(iszero, x) # a single zero element makes the product zero 67 | throw(DivideError()) 68 | end 69 | inv(prod(x)) 70 | end 71 | Symbolics.@register_symbolic invprod(x::AbstractVector) 72 | 73 | add_dcprule(invprod, array_domain(HalfLine{Real, :open}()), Positive, Convex, Decreasing) 74 | 75 | add_dcprule(eigmax, symmetric_domain(), AnySign, Convex, AnyMono) 76 | 77 | add_dcprule(eigmin, symmetric_domain(), AnySign, Concave, AnyMono) 78 | 79 | """ 80 | eigsummax(m::Symmetric, k) 81 | 82 | Returns the sum of the `k` largest eigenvalues of `m`. 83 | 84 | # Arguments 85 | 86 | - `m::Symmetric`: A symmetric matrix. 87 | - `k::Int`: The number of largest eigenvalues to sum. 88 | """ 89 | function eigsummax(m::Symmetric, k::Int) 90 | if k < 1 || k > size(m, 1) 91 | throw(DomainError(k, "k must be between 1 and size(m, 1)")) 92 | end 93 | nrows = size(m, 1) 94 | return sum(eigvals(m, (nrows - k + 1):nrows)) 95 | end 96 | Symbolics.@register_symbolic eigsummax(m::Matrix, k::Int) 97 | add_dcprule(eigsummax, (array_domain(RealLine(), 2), RealLine()), AnySign, Convex, AnyMono) 98 | 99 | """ 100 | eigsummin(m::Symmetric, k) 101 | 102 | Returns the sum of the `k` smallest eigenvalues of `m`. 103 | 104 | # Arguments 105 | 106 | - `m::Symmetric`: A symmetric matrix. 107 | - `k::Int`: The number of smallest eigenvalues to sum. 108 | """ 109 | function eigsummin(m::Symmetric, k::Int) 110 | if k < 1 || k > size(m, 1) 111 | throw(DomainError(k, "k must be between 1 and size(m, 1)")) 112 | end 113 | return sum(eigvals(m, 1:k)) 114 | end 115 | Symbolics.@register_symbolic eigsummin(m::Matrix, k::Int) 116 | add_dcprule(eigsummin, (array_domain(RealLine(), 2), RealLine()), AnySign, Concave, AnyMono) 117 | 118 | add_dcprule(logdet, semidefinite_domain(), AnySign, Concave, AnyMono) 119 | 120 | add_dcprule( 121 | LogExpFunctions.logsumexp, 122 | array_domain(RealLine(), 2), 123 | AnySign, 124 | Convex, 125 | Increasing 126 | ) 127 | 128 | """ 129 | matrix_frac(x::AbstractVector, P::AbstractMatrix) 130 | 131 | Returns the quadratic form `x' * P^{-1} * x`. 132 | 133 | # Arguments 134 | 135 | - `x::AbstractVector`: A vector. 136 | - `P::AbstractMatrix`: A matrix. 137 | """ 138 | function matrix_frac(x::AbstractVector, P::AbstractMatrix) 139 | if length(x) != size(P, 1) 140 | throw(DimensionMismatch("x and P must have same length")) 141 | end 142 | return x' * inv(P) * x 143 | end 144 | Symbolics.@register_symbolic matrix_frac(x::AbstractVector, P::AbstractMatrix) 145 | add_dcprule( 146 | matrix_frac, 147 | (array_domain(RealLine(), 1), definite_domain()), 148 | AnySign, 149 | Convex, 150 | AnyMono 151 | ) 152 | 153 | add_dcprule(maximum, array_domain(RealLine()), AnySign, Convex, Increasing) 154 | 155 | add_dcprule(minimum, array_domain(RealLine()), AnySign, Concave, Increasing) 156 | 157 | # NOTE: the second norm rule below (p in [0, 1)) is incorrect: for p < 1, norm(x, p) is not convex 158 | add_dcprule( 159 | norm, 160 | (array_domain(RealLine()), Interval{:closed, :open}(1, Inf)), 161 | Positive, 162 | Convex, 163 | increasing_if_positive 164 | ) 165 | add_dcprule( 166 | norm, 167 | (array_domain(RealLine()), Interval{:closed, :open}(0, 1)), 168 | Positive, 169 | Convex, 170 | increasing_if_positive 171 | ) 172 | 173 | """ 174 | perspective(f::Function, x, s::Real) 175 | 176 | Returns the perspective function `s * f(x / s)`.
108 | """ 109 | function eigsummin(m::Symmetric, k::Int) 110 | if k < 1 || k > size(m, 1) 111 | throw(DomainError(k, "k must be between 1 and size(m, 1)")) 112 | end 113 | return sum(eigvals(m, 1:k)) 114 | end 115 | Symbolics.@register_symbolic eigsummin(m::Matrix, k::Int) 116 | add_dcprule(eigsummin, (array_domain(RealLine(), 2), RealLine()), AnySign, Concave, AnyMono) 117 | 118 | add_dcprule(logdet, semidefinite_domain(), AnySign, Concave, AnyMono) 119 | 120 | add_dcprule( 121 | LogExpFunctions.logsumexp, 122 | array_domain(RealLine(), 2), 123 | AnySign, 124 | Convex, 125 | Increasing 126 | ) 127 | 128 | """ 129 | matrix_frac(x::AbstractVector, P::AbstractMatrix) 130 | 131 | Returns the quadratic form `x' * P^{-1} * x`. 132 | 133 | # Arguments 134 | 135 | - `x::AbstractVector`: A vector. 136 | - `P::AbstractMatrix`: A matrix. 137 | """ 138 | function matrix_frac(x::AbstractVector, P::AbstractMatrix) 139 | if length(x) != size(P, 1) 140 | throw(DimensionMismatch("x and P must have same length")) 141 | end 142 | return x' * inv(P) * x 143 | end 144 | Symbolics.@register_symbolic AbstractMatrix_frac(x::AbstractVector, P::AbstractMatrix) 145 | add_dcprule( 146 | matrix_frac, 147 | (array_domain(RealLine(), 1), definite_domain()), 148 | AnySign, 149 | Convex, 150 | AnyMono 151 | ) 152 | 153 | add_dcprule(maximum, array_domain(RealLine()), AnySign, Convex, Increasing) 154 | 155 | add_dcprule(minimum, array_domain(RealLine()), AnySign, Concave, Increasing) 156 | 157 | #incorrect for p<1 158 | add_dcprule( 159 | norm, 160 | (array_domain(RealLine()), Interval{:closed, :open}(1, Inf)), 161 | Positive, 162 | Convex, 163 | increasing_if_positive 164 | ) 165 | add_dcprule( 166 | norm, 167 | (array_domain(RealLine()), Interval{:closed, :open}(0, 1)), 168 | Positive, 169 | Convex, 170 | increasing_if_positive 171 | ) 172 | 173 | """ 174 | perspective(f::Function, x, s::Real) 175 | 176 | Returns the perspective function `s * f(x / s)`. 
177 | 178 | # Arguments 179 | 180 | - `f::Function`: A function. 181 | - `x`: A Real. 182 | - `s::Real`: A positive Real. 183 | """ 184 | function perspective(f::Function, x, s::Real) 185 | if s < 0 186 | throw(DomainError(s, "s must be positive")) 187 | end 188 | if s == 0 189 | return zero(typeof(f(x))) 190 | end 191 | s * f(x / s) 192 | end 193 | Symbolics.@register_symbolic perspective(f::Function, x, s::Real) 194 | add_dcprule( 195 | perspective, 196 | (function_domain(), RealLine(), Positive), 197 | getsign, 198 | getcurvature, 199 | AnyMono 200 | ) 201 | 202 | """ 203 | quad_form(x::AbstractVector, P::AbstractMatrix) 204 | 205 | Returns the quadratic form `x' * P * x`. 206 | 207 | # Arguments 208 | 209 | - `x::AbstractVector`: A vector. 210 | - `P::AbstractMatrix`: A matrix. 211 | """ 212 | function quad_form(x::AbstractVector, P::AbstractMatrix) 213 | if length(x) != size(P, 1) 214 | throw(DimensionMismatch("x and P must have same length")) 215 | end 216 | return x' * P * x 217 | end 218 | Symbolics.@register_symbolic quad_form(x::AbstractVector, P::AbstractMatrix) 219 | add_dcprule( 220 | quad_form, 221 | (array_domain(RealLine(), 1), semidefinite_domain()), 222 | Positive, 223 | Convex, 224 | (increasing_if_positive, Increasing) 225 | ) 226 | 227 | function quad_over_lin(x::Vector{<:Real}, y::Real) 228 | if getsign(y) == Negative 229 | throw(DomainError(y, "y must be positive")) 230 | end 231 | return sum(x .^ 2) / y 232 | end 233 | 234 | Symbolics.@register_symbolic quad_over_lin(x::AbstractVector, y::Real) false 235 | 236 | """ 237 | quad_over_lin(x::Real, y::Real) 238 | 239 | Returns the quadratic over linear form `x^2 / y`. 240 | 241 | # Arguments 242 | 243 | - `x`: A Real or a vector. 244 | - `y::Real`: A positive Real. 
245 | """ 246 | function quad_over_lin(x::Real, y::Real) 247 | if getsign(y) == Negative 248 | throw(DomainError(y, "y must be positive")) 249 | end 250 | return x^2 / y 251 | end 252 | 253 | Symbolics.@register_symbolic quad_over_lin(x::Real, y::Real) 254 | 255 | add_dcprule( 256 | quad_over_lin, 257 | (array_domain(RealLine()), HalfLine{Real, :open}()), 258 | Positive, 259 | Convex, 260 | (increasing_if_positive, Decreasing) 261 | ) 262 | 263 | add_dcprule( 264 | quad_over_lin, 265 | (RealLine(), HalfLine{Real, :open}()), 266 | Positive, 267 | Convex, 268 | (increasing_if_positive, Decreasing) 269 | ) 270 | 271 | add_dcprule(sum, array_domain(RealLine(), 2), AnySign, Affine, Increasing) 272 | 273 | """ 274 | sum_largest(x::AbstractMatrix, k) 275 | 276 | Returns the sum of the `k` largest elements of `x`. 277 | 278 | # Arguments 279 | 280 | - `x::AbstractMatrix`: A matrix. 281 | - `k::Int`: The Real of largest elements to sum. 282 | """ 283 | function sum_largest(x::AbstractMatrix, k::Integer) 284 | return sum(sort(vec(x))[(end - k):end]) 285 | end 286 | Symbolics.@register_symbolic sum_largest(x::AbstractMatrix, k::Integer) 287 | add_dcprule(sum_largest, (array_domain(RealLine(), 2), ℤ), AnySign, Convex, Increasing) 288 | 289 | """ 290 | sum_smallest(x::AbstractMatrix, k) 291 | 292 | Returns the sum of the `k` smallest elements of `x`. 293 | 294 | # Arguments 295 | 296 | - `x::AbstractMatrix`: A matrix. 297 | - `k::Int`: The Real of smallest elements to sum. 298 | """ 299 | function sum_smallest(x::AbstractMatrix, k::Integer) 300 | return sum(sort(vec(x))[1:k]) 301 | end 302 | 303 | Symbolics.@register_symbolic sum_smallest(x::AbstractArray, k::Integer) 304 | add_dcprule(sum_smallest, (array_domain(RealLine(), 2), ℤ), AnySign, Concave, Increasing) 305 | 306 | add_dcprule(tr, array_domain(RealLine(), 2), AnySign, Affine, Increasing) 307 | 308 | """ 309 | trinv(x::AbstractMatrix) 310 | 311 | Returns the trace of the inverse of `x`. 
312 | 313 | # Arguments 314 | 315 | - `x::AbstractMatrix`: A matrix. 316 | """ 317 | function trinv(x::AbstractMatrix) 318 | return tr(inv(x)) 319 | end 320 | Symbolics.@register_symbolic trinv(x::AbstractMatrix) 321 | add_dcprule(trinv, definite_domain(), Positive, Convex, AnyMono) 322 | 323 | """ 324 | tv(x::AbstractVector{<:Real}) 325 | 326 | Returns the total variation of `x`, defined as `sum_i |x_{i+1} - x_i|`. 327 | 328 | # Arguments 329 | 330 | - `x::AbstractVector`: A vector. 331 | """ 332 | function tv(x::Vector{<:Real}) 333 | return sum(abs.(x[2:end] - x[1:(end - 1)])) 334 | end 335 | Symbolics.@register_symbolic tv(x::AbstractVector) false 336 | add_dcprule(tv, array_domain(RealLine(), 1), Positive, Convex, AnyMono) 337 | 338 | """ 339 | tv(x::AbstractVector{<:AbstractMatrix}) 340 | 341 | Returns the total variation of `x`, defined as `sum_{i,j} |x_{k+1}[i,j] - x_k[i,j]|`. 342 | 343 | # Arguments 344 | 345 | - `x::AbstractVector`: A vector of matrices. 346 | """ 347 | function tv(x::AbstractVector{<:AbstractMatrix}) 348 | return sum(map(1:(size(x, 1) - 1)) do i 349 | map(1:(size(x, 2) - 1)) do j 350 | norm([x[k][i + 1, j] - x[k][i, j] for k in eachindex(x)]) 351 | end 352 | end) 353 | end 354 | add_dcprule(tv, array_domain(array_domain(RealLine(), 2), 1), Positive, Convex, AnyMono) 355 | 356 | add_dcprule(abs, ℂ, Positive, Convex, increasing_if_positive) 357 | 358 | add_dcprule(conj, ℂ, AnySign, Affine, AnyMono) 359 | 360 | add_dcprule(exp, RealLine(), Positive, Convex, Increasing) 361 | 362 | Symbolics.@register_symbolic LogExpFunctions.xlogx(x::Real) 363 | add_dcprule(xlogx, RealLine(), AnySign, Convex, AnyMono) 364 | 365 | """ 366 | huber(x, M=1) 367 | 368 | Returns the Huber loss function of `x` with threshold `M`. 369 | 370 | # Arguments 371 | 372 | - `x::Real`: A Real. 373 | - `M::Real`: The threshold. 
374 | """ 375 | function huber(x::Real, M::Real = 1) 376 | if M < 0 377 | throw(DomainError(M, "M must be positive")) 378 | end 379 | 380 | if abs(x) <= M 381 | return x^2 382 | else 383 | return 2 * M * abs(x) - M^2 384 | end 385 | end 386 | Symbolics.@register_symbolic huber(x::Real, M::Real) 387 | add_dcprule(huber, (RealLine(), HalfLine()), Positive, Convex, increasing_if_positive) 388 | 389 | add_dcprule(imag, ℂ, AnySign, Affine, AnyMono) 390 | 391 | add_dcprule(inv, HalfLine{Real, :open}(), Positive, Convex, Decreasing) 392 | add_dcprule(log, HalfLine{Real, :open}(), AnySign, Concave, Increasing) 393 | 394 | @register_symbolic Base.log(A::Symbolics.Arr) 395 | add_dcprule(log, array_domain(RealLine(), 2), Positive, Concave, Increasing) 396 | 397 | @register_symbolic LinearAlgebra.inv(A::Symbolics.Arr) 398 | add_dcprule(inv, semidefinite_domain(), AnySign, Convex, Decreasing) 399 | 400 | @register_symbolic LinearAlgebra.sqrt(A::Symbolics.Arr) 401 | add_dcprule(sqrt, semidefinite_domain(), Positive, Concave, Increasing) 402 | 403 | add_dcprule( 404 | kldivergence, 405 | (array_domain(HalfLine{Real, :open}, 1), array_domain(HalfLine{Real, :open}, 1)), 406 | Positive, 407 | Convex, 408 | AnyMono 409 | ) 410 | 411 | """ 412 | lognormcdf(x::Real) 413 | 414 | Returns the log of the normal cumulative distribution function of `x`. 415 | 416 | # Arguments 417 | 418 | - `x::Real`: A Real. 
419 | """ 420 | function lognormcdf(x::Real) 421 | return logcdf(Normal, x) 422 | end 423 | Symbolics.@register_symbolic lognormcdf(x::Real) 424 | add_dcprule(lognormcdf, RealLine(), Negative, Concave, Increasing) 425 | 426 | add_dcprule(log1p, Interval{:open, :open}(-1, Inf), Negative, Concave, Increasing) 427 | 428 | add_dcprule(logistic, RealLine(), Positive, Convex, Increasing) 429 | 430 | add_dcprule(max, (RealLine(), RealLine()), AnySign, Convex, Increasing) 431 | add_dcprule(min, (RealLine(), RealLine()), AnySign, Concave, Increasing) 432 | 433 | # special cases which depend on arguments: 434 | function dcprule(::typeof(^), x::Symbolic, i) 435 | args = (x, i) 436 | if isone(i) 437 | return makerule(RealLine(), AnySign, Affine, Increasing), args 438 | elseif isinteger(i) && iseven(i) 439 | return makerule(RealLine(), Positive, Convex, increasing_if_positive), args 440 | elseif isinteger(i) && isodd(i) 441 | return makerule(HalfLine(), Positive, Convex, Increasing), args 442 | elseif i >= 1 443 | return makerule(HalfLine(), Positive, Convex, Increasing), args 444 | elseif i > 0 && i < 1 445 | return makerule(HalfLine(), Positive, Concave, Increasing), args 446 | elseif i < 0 447 | return makerule(HalfLine{Float64, :closed}(), Positive, Convex, Increasing), args 448 | end 449 | end 450 | dcprule(::typeof(Base.literal_pow), f, x...) = dcprule(^, x...) 
hasdcprule(::typeof(^)) = true

add_dcprule(real, ℂ, AnySign, Affine, Increasing)

"""
    rel_entr(x::Real, y::Real)

Returns the elementwise relative entropy `x * log(x / y)`, using the
convention `rel_entr(0, y) == 0`.

# Arguments

  - `x::Real`: A nonnegative Real.
  - `y::Real`: A nonnegative Real.

# Throws

  - `DomainError` if `x < 0` or `y < 0`.
"""
function rel_entr(x::Real, y::Real)
    if x < 0 || y < 0
        throw(DomainError((x, y), "x and y must be positive"))
    end
    if x == 0
        # Return a floating-point zero so the result type matches the
        # `x * log(x / y)` branch (the literal `0` was type-unstable).
        return zero(float(promote_type(typeof(x), typeof(y))))
    end
    return x * log(x / y)
end
Symbolics.@register_symbolic rel_entr(x::Real, y::Real)
add_dcprule(
    rel_entr,
    (HalfLine{Real, :open}(), HalfLine{Real, :open}()),
    AnySign,
    Convex,
    (AnyMono, Decreasing)
)

add_dcprule(sqrt, HalfLine(), Positive, Concave, Increasing)

# The domain must be an instance (`HalfLine()`), not the `HalfLine` type.
add_dcprule(xexpx, HalfLine(), Positive, Convex, Increasing)

add_dcprule(
    conv,
    (array_domain(RealLine(), 1), array_domain(RealLine(), 1)),
    AnySign,
    Affine,
    AnyMono
)

add_dcprule(cumsum, array_domain(RealLine()), AnySign, Affine, Increasing)

add_dcprule(diagm, array_domain(RealLine(), 1), AnySign, Affine, Increasing)

add_dcprule(diag, array_domain(RealLine(), 2), AnySign, Affine, Increasing)

add_dcprule(diff, array_domain(RealLine()), AnySign, Affine, Increasing)

add_dcprule(hcat, array_domain(array_domain(RealLine(), 1), 1), AnySign, Affine, Increasing)

add_dcprule(
    kron,
    (array_domain(RealLine(), 2), array_domain(RealLine(), 2)),
    AnySign,
    Affine,
    Increasing
)

add_dcprule(reshape, array_domain(RealLine(), 2), AnySign, Affine, Increasing)

add_dcprule(triu, array_domain(RealLine(), 2), AnySign, Affine, Increasing)

add_dcprule(vec, array_domain(RealLine(), 2), AnySign, Affine, Increasing)

add_dcprule(vcat, array_domain(array_domain(RealLine(), 1), 1), AnySign, Affine, Increasing)

# A broadcast call `f.(x...)` is analyzed with the rule of the broadcast
# function `f` itself.
function dcprule(::typeof(broadcast), f, x...)
    return dcprule(f, x...)
end
hasdcprule(::typeof(broadcast)) = true

# add_dcprule(broadcast, (function_domain, array_domain(RealLine())), AnySign, Affine, (AnyMono, AnyMono))

add_dcprule(LinearAlgebra.adjoint, array_domain(RealLine(), 1), AnySign, Affine, Increasing)
add_dcprule(Base.getindex, array_domain(RealLine(), 1), AnySign, Affine, AnyMono)