├── .github └── workflows │ └── ci-build.yml ├── .gitignore ├── .gitmodules ├── README.md ├── build-jaxlib.ps1 ├── build-requirements.txt ├── functions.ps1 ├── update_index.py └── windows_configure.bazelrc /.github/workflows/ci-build.yml: -------------------------------------------------------------------------------- 1 | name: build whl and uploads 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | 7 | # Allows you to run this workflow manually from the Actions tab 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build-cuda: 12 | strategy: 13 | matrix: 14 | cuda-version: ["11.8", "12.1"] 15 | runs-on: windows-2022 16 | env: 17 | AZURE_STORAGE_CONNECTION_STRING: ${{ secrets.AZURE_STORAGE_CONNECTION_STRING }} 18 | BAZEL_PATH: "D:\\bazel.exe" 19 | TEMP: C:\\Users\\runneradmin\\Temp 20 | TMP: C:\\Users\\runneradmin\\Temp 21 | PYTHONUNBUFFERED: '1' 22 | 23 | steps: 24 | - name: Show user home 25 | run: ls ~ 26 | - name: Show cpu info 27 | run: Get-CimInstance Win32_Processor 28 | - name: Limit cpu 29 | run: | 30 | $p = Get-CimInstance Win32_Processor 31 | if ($p.Name -match "E5-") { throw "CPU is too old!" 
} 32 | - name: Show memory info 33 | run: Get-CimInstance Win32_PhysicalMemory | Format-Table Tag, DeviceLocator, Capacity, Speed 34 | - name: Configure pagefile 35 | uses: al-cheb/configure-pagefile-action@v1.2 36 | with: 37 | minimum-size: 8GB 38 | maximum-size: 32GB 39 | disk-root: "C:" 40 | # - name: Show disk info 41 | # run: Get-Volume -DriveLetter CD | Sort-Object DriveLetter 42 | - name: Workaround https://github.com/bazelbuild/bazel/issues/18592 43 | run: rm -Recurse -Force "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\vcpkg" -ErrorAction Continue 44 | 45 | - uses: actions/checkout@v4 46 | with: 47 | submodules: true 48 | 49 | - name: Download Bazelisk # TLS verification stays on (no -k); --fail keeps an HTTP error page from being saved as bazel.exe 50 | run: curl --fail -L https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-windows-amd64.exe -o $env:BAZEL_PATH 51 | 52 | - name: Install CUDA ${{ matrix.cuda-version }} 53 | run: | 54 | curl --fail -L https://whls.blob.core.windows.net/ci-files/v${{ matrix.cuda-version }}.7z -o cuda.7z 55 | 7z x cuda.7z -o'D:/CUDA' 56 | rm cuda.7z 57 | ls D:/CUDA/v${{ matrix.cuda-version }} 58 | 59 | - uses: actions/cache@v4 # cache v1/v2 have been shut down by GitHub 60 | with: 61 | path: ~\AppData\Local\pip\Cache 62 | key: ${{ runner.os }}-pip-${{ hashFiles('build-requirements.txt') }} 63 | restore-keys: | 64 | ${{ runner.os }}-pip- 65 | 66 | #=============# 67 | # Python 3.11 # 68 | #=============# 69 | - name: py311 70 | uses: actions/setup-python@v5 71 | with: 72 | python-version: "3.11" 73 | - name: py311 pip install 74 | run: pip install -r "$env:GITHUB_WORKSPACE/build-requirements.txt" 75 | - name: py311 build whl and upload 76 | run: | 77 | rm -Recurse -Force "$env:GITHUB_WORKSPACE/jax/bazel-dist" -ErrorAction Continue 78 | cd "$env:GITHUB_WORKSPACE/jax" 79 | ../build-jaxlib.ps1 cuda -bazel_path $env:BAZEL_PATH -vs_version 2022 -cuda_version '${{ matrix.cuda-version }}' -cuda_prefix 'D:/CUDA' -symlink_python 80 | az storage blob upload-batch --overwrite -d unstable -s "$env:GITHUB_WORKSPACE/jax/bazel-dist" --pattern '*.whl' 81 
| - uses: actions/upload-artifact@v4 # v3 is retired; v4 rejects duplicate names, so every upload below is uniquely named 82 | with: 83 | name: whls-cuda${{ matrix.cuda-version }}-py311 84 | path: jax/bazel-dist/**/*.whl 85 | 86 | #=============# 87 | # Python 3.10 # 88 | #=============# 89 | - name: py310 90 | uses: actions/setup-python@v5 91 | with: 92 | python-version: "3.10" 93 | - name: py310 pip install 94 | run: pip install -r "$env:GITHUB_WORKSPACE/build-requirements.txt" 95 | - name: py310 build whl and upload 96 | run: | 97 | rm -Recurse -Force "$env:GITHUB_WORKSPACE/jax/bazel-dist" -ErrorAction Continue 98 | cd "$env:GITHUB_WORKSPACE/jax" 99 | ../build-jaxlib.ps1 cuda -bazel_path $env:BAZEL_PATH -vs_version 2022 -cuda_version '${{ matrix.cuda-version }}' -cuda_prefix 'D:/CUDA' -symlink_python 100 | az storage blob upload-batch --overwrite -d unstable -s "$env:GITHUB_WORKSPACE/jax/bazel-dist" --pattern '*.whl' 101 | - uses: actions/upload-artifact@v4 102 | with: 103 | name: whls-cuda${{ matrix.cuda-version }}-py310 104 | path: jax/bazel-dist/**/*.whl 105 | 106 | #============# 107 | # Python 3.9 # 108 | #============# 109 | - name: py39 110 | uses: actions/setup-python@v5 111 | with: 112 | python-version: "3.9" 113 | - name: py39 pip install 114 | run: pip install -r "$env:GITHUB_WORKSPACE/build-requirements.txt" 115 | - name: py39 build whl and upload 116 | run: | 117 | rm -Recurse -Force "$env:GITHUB_WORKSPACE/jax/bazel-dist" -ErrorAction Continue 118 | cd "$env:GITHUB_WORKSPACE/jax" 119 | ../build-jaxlib.ps1 cuda -bazel_path $env:BAZEL_PATH -vs_version 2022 -cuda_version '${{ matrix.cuda-version }}' -cuda_prefix 'D:/CUDA' -symlink_python 120 | az storage blob upload-batch --overwrite -d unstable -s "$env:GITHUB_WORKSPACE/jax/bazel-dist" --pattern '*.whl' 121 | - uses: actions/upload-artifact@v4 122 | with: 123 | name: whls-cuda${{ matrix.cuda-version }}-py39 124 | path: jax/bazel-dist/**/*.whl 125 | 126 | update-index: 127 | if: ${{ ! 
cancelled() }} 128 | needs: 129 | - build-cuda 130 | runs-on: windows-2022 131 | env: 132 | AZURE_STORAGE_CONNECTION_STRING: ${{ secrets.AZURE_STORAGE_CONNECTION_STRING }} 133 | steps: 134 | - uses: actions/checkout@v4 135 | - name: py311 136 | uses: actions/setup-python@v5 137 | with: 138 | python-version: "3.11" 139 | - name: update index.html 140 | run: | 141 | cd "$env:GITHUB_WORKSPACE/" 142 | python ./update_index.py --url_mode absolute unstable > index.html 143 | az storage blob upload --overwrite -c unstable -f index.html -n index.html --content-type='text/html' 144 | az storage blob upload --overwrite -c '$web' -f index.html -n 'unstable/index.html' --content-type='text/html' 145 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | env.ps1 2 | *.html 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "jax"] 2 | path = jax 3 | url = https://github.com/cloudhan/jax.git 4 | [submodule "xla"] 5 | path = xla 6 | url = https://github.com/cloudhan/xla.git 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JAX ❤️ 🪟 2 | 3 | alpha state... 4 | 5 | A community-supported Windows build for jax. 6 | 7 | Currently, only CPU and CUDA 11.1 are supported. For CUDA 11.x, please install the `cuda`/`cuda11_cudnn82` package. 8 | 9 | # Unstable builds 10 | 11 | Each `jax` build pins a concrete `jaxlib` package version in its `setup.py`. To install an unstable 12 | build, you must first ensure the required `jaxlib` package exists in the package 13 | index.
Check it out at https://whls.blob.core.windows.net/unstable/index.html 14 | 15 | You can install `jax` via pip (CPU-only or CUDA), install `jax` from source, or download the desired wheel manually. 16 | 17 | ## Install CPU-only version via `pip` 18 | 19 | **Starting from 0.4.13, the CPU build was removed. Please use the official CPU build from PyPI directly.** The command below only applies to older releases such as 0.3.14. 20 | 21 | ``` 22 | pip install "jax[cpu]===0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver 23 | ``` 24 | 25 | ## Install `cuda111` version via `pip` 26 | 27 | ``` 28 | pip install jax[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver 29 | ``` 30 | 31 | ## Install from `jax` source 32 | 33 | ``` 34 | pip install -e .[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver 35 | ``` 36 | 37 | ## The manual solution 38 | 39 | Select the version of `jaxlib` that you want to install and download its wheel, then install it and `jax` manually. 40 | 41 | ```powershell 42 | # download jaxlib from https://whls.blob.core.windows.net/unstable/index.html 43 | pip install <path-to-downloaded-jaxlib-wheel> 44 | pip install jax 45 | ``` 46 | 47 | 48 | # Stable builds 49 | 50 |
To be added 51 |

52 | 53 | Check it out at https://whls.blob.core.windows.net/releases/index.html 54 | 55 |

56 | 57 | 58 | # Additional notes 59 | 60 | For `--use-deprecated legacy-resolver`, refer to 61 | [pip #9186](https://github.com/pypa/pip/issues/9186) and 62 | [pip #9203](https://github.com/pypa/pip/issues/9203). 63 | -------------------------------------------------------------------------------- /build-jaxlib.ps1: -------------------------------------------------------------------------------- 1 | param( 2 | [Parameter(Position=0, Mandatory = $true)] 3 | [ValidateSet('cpu', 'cuda')] 4 | [String]$build_type, 5 | 6 | [Parameter(Mandatory = $false)] 7 | [String]$bazel_path = "bazel", 8 | 9 | [Parameter(Mandatory = $false)] 10 | [int]$bazel_jobs = -1, 11 | 12 | [Parameter(Mandatory = $false)] 13 | [String]$conda_env = "", 14 | 15 | [Parameter(Mandatory = $false)] 16 | [ValidateSet('12.1', '11.8')] 17 | [String]$cuda_version = "12.1", 18 | 19 | [Parameter(Mandatory = $false)] 20 | [String]$cuda_prefix = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA", 21 | 22 | [Parameter(Mandatory = $false)] 23 | [String]$bazel_output_root = "C:/bazel_output_root", 24 | 25 | [Parameter(Mandatory = $false)] 26 | [ValidateSet("2022", "2019")] 27 | [String]$vs_version = "", 28 | 29 | [Parameter(Mandatory = $false)] 30 | [String]$bazel_vc_full_version = "", 31 | 32 | [Parameter(Mandatory = $false)] 33 | [String]$xla_submodule = (Join-Path $PSScriptRoot xla), 34 | 35 | [Parameter(Mandatory = $false)] 36 | [String]$triton_submodule = (Join-Path $PSScriptRoot triton), 37 | 38 | # For CI to avoid full rebuild when changing python version 39 | [switch]$symlink_python 40 | ) 41 | 42 | . 
(Join-Path (Split-Path $MyInvocation.MyCommand.Path) functions.ps1) 43 | $ErrorActionPreference = "Stop" 44 | 45 | # path for patch.exe and realpath.exe 46 | $msys2_path = "C:\msys64\usr\bin" 47 | 48 | switch ($cuda_version) { 49 | '12.1' { 50 | $cudnn_version = '8.9.1' 51 | } 52 | '11.8' { 53 | $cudnn_version = '8.6.0' 54 | } 55 | } 56 | 57 | $cuda_path = "$cuda_prefix/v$cuda_version" 58 | $cudnn_path = $cuda_path 59 | 60 | if ($xla_submodule -ne (Join-Path $PSScriptRoot xla)) { 61 | $xla_submodule = Resolve-Path $xla_submodule 62 | } 63 | 64 | if ($triton_submodule -ne (Join-Path $PSScriptRoot triton)) { 65 | $triton_submodule = Resolve-Path $triton_submodule 66 | } 67 | 68 | [System.Collections.ArrayList]$new_path = ` 69 | 'C:\tools', ` 70 | 'C:\Program Files\Git\cmd', ` 71 | 'C:\Windows\System32', ` 72 | 'C:\Windows', ` 73 | 'C:\Windows\System32\Wbem', ` 74 | 'C:\Windows\System32\WindowsPowerShell\v1.0' 75 | 76 | Push-Environment 77 | Push-Location 78 | 79 | try { 80 | if ($cuda_path -eq $cudnn_path) { 81 | $env:TF_CUDA_PATHS="$cuda_path" 82 | } 83 | else { 84 | $env:TF_CUDA_PATHS="$cuda_path,$cudnn_path" 85 | } 86 | 87 | # insert your path here 88 | $new_path.Insert(0, "$msys2_path") 89 | 90 | # bring github actions python into path 91 | if ($env:pythonLocation) { 92 | $new_path.Insert(0, "$env:pythonLocation") 93 | $new_path.Insert(0, "$env:pythonLocation/Scripts") 94 | } 95 | 96 | $env:PATH = $new_path -join ";" 97 | 98 | if ($vs_version -ne "") { 99 | Set-VSEnv $vs_version 100 | } 101 | if ($bazel_vc_full_version -ne "") { 102 | $env:BAZEL_VC_FULL_VERSION = $bazel_vc_full_version 103 | } 104 | 105 | # bring conda python into environment, this supersede MSYS2's python and 106 | # maybe VS's python 107 | if ($conda_env -ne "") { 108 | conda activate $conda_env 109 | } 110 | 111 | $bazelrc_user = Join-Path $PWD.ProviderPath .bazelrc.user; [System.IO.File]::WriteAllText($bazelrc_user, "try-import %workspace%/../windows_configure.bazelrc`n") # .NET writes UTF-8 without BOM; '>' would emit UTF-16LE under Windows PowerShell 5.1, which bazel cannot parse 112 | 113 | if ($bazel_jobs -gt 0) { 114 | 
[System.IO.File]::AppendAllText($bazelrc_user, "build --jobs=${bazel_jobs}`n") 115 | } 116 | 
117 | if (Test-Path $xla_submodule) { 118 | Write-Host -ForegroundColor Yellow "Use xla submodule " $xla_submodule 119 | [System.IO.File]::AppendAllText($bazelrc_user, 'build:windows --override_repository=xla=' + $xla_submodule.Replace("\", "/") + "`n") 120 | } 121 | 122 | if (Test-Path $triton_submodule) { 123 | Write-Host -ForegroundColor Yellow "Use triton submodule " $triton_submodule 124 | [System.IO.File]::AppendAllText($bazelrc_user, 'build:windows --override_repository=triton=' + $triton_submodule.Replace("\", "/") + "`n") 125 | } 126 | 127 | $python_bin_path = "" 128 | if ($symlink_python) { 129 | $python_symlinked_home = Join-Path $PSScriptRoot python_symlinked 130 | Remove-Item $python_symlinked_home -Force -ErrorAction 0 131 | New-Item -Type SymbolicLink $python_symlinked_home -Target (Split-Path (Get-Command python).Source) -Force 132 | $env:PATH = "$python_symlinked_home;$env:PATH" # PATH was already built from $new_path above, so inserting into $new_path here had no effect 133 | 134 | $python_bin_path = Join-Path $python_symlinked_home python.exe 135 | 136 | # We use it to trigger the repository rule when python is changed 137 | $python_lib_path = (Get-Item $python_symlinked_home).Target.Replace("\", "/") 138 | Write-Host -ForegroundColor Yellow "Use PYTHON_LIB_PATH " $python_lib_path 139 | [System.IO.File]::AppendAllText($bazelrc_user, 'build:windows --repo_env PYTHON_LIB_PATH="' + $python_lib_path + '"' + "`n") 140 | } 141 | 142 | # NOTE: In case it is needed to debug a build failure, run `bazel --output_user_root=$bazel_output_root <command>` 143 | if ($build_type -eq 'cpu') { 144 | python .\build\build.py ` 145 | --python_bin_path="$python_bin_path" ` 146 | --noenable_cuda ` 147 | --bazel_path="$bazel_path" ` 148 | --bazel_startup_options="--output_user_root=$bazel_output_root" 149 | } elseif ($build_type -eq 'cuda') { 150 | python .\build\build.py ` 151 | --python_bin_path="$python_bin_path" ` 152 | --enable_cuda ` 153 | --cuda_version="$cuda_version" ` 154 | --cuda_path="$cuda_path" ` 155 | --cudnn_version="$cudnn_version" ` 156 | --cudnn_path="$cudnn_path" ` 157 | --bazel_path="$bazel_path" ` 158 | 
--bazel_startup_options="--output_user_root=$bazel_output_root" 159 | } 160 | 161 | if ($LASTEXITCODE -ne 0) { 162 | throw "last command exit with $LASTEXITCODE" 163 | } 164 | 165 | if ((ls dist).Count -ne 1) { 166 | throw "number of whl files != 1" 167 | } 168 | 169 | $name = (ls dist)[0].Name 170 | 171 | if ($build_type -eq 'cpu') { 172 | mkdir "bazel-dist/cpu" -ErrorAction 0 173 | mv -Force "dist/$name" "bazel-dist/cpu/$name" 174 | Write-Host -ForegroundColor Yellow "Moved dist/$name to bazel-dist/cpu/$name" 175 | } elseif ($build_type -eq 'cuda') { 176 | $cuda_ver = [System.Version]$cuda_version 177 | $cudnn_ver = [System.Version]$cudnn_version 178 | $cuda_dir = "cuda$($cuda_ver.Major)$($cuda_ver.Minor)" 179 | $cuda_cudnn_tag = "cuda$($cuda_ver.Major).cudnn$($cudnn_ver.Major)$($cudnn_ver.Minor)" 180 | $new_name = $name.Insert($name.IndexOf("-", $name.IndexOf("-") + 1), "+$cuda_cudnn_tag") 181 | 182 | mkdir "bazel-dist/$cuda_dir" -ErrorAction 0 183 | mv -Force "dist/$name" "bazel-dist/$cuda_dir/$new_name" 184 | Write-Host -ForegroundColor Yellow "Moved dist/$name to bazel-dist/$cuda_dir/$new_name" 185 | } 186 | } 187 | finally { 188 | Pop-Location 189 | Pop-Environment 190 | } 191 | -------------------------------------------------------------------------------- /build-requirements.txt: -------------------------------------------------------------------------------- 1 | # https://github.com/pypa/setuptools/pull/3690 2 | setuptools>=65.6.1 3 | 4 | numpy==1.23.5 5 | scipy==1.9.3 6 | wheel 7 | six 8 | auditwheel 9 | build 10 | -------------------------------------------------------------------------------- /functions.ps1: -------------------------------------------------------------------------------- 1 | # copy pasted from https://gist.github.com/cloudhan/97db3c1e57895a09a80ec1f30c471cb3 2 | function Set-EnvFromCmdSet { 3 | [CmdletBinding()] 4 | param( 5 | [Parameter(ValueFromPipeline)] 6 | [string]$CmdSetResult 7 | ) 8 | process { 9 | if ($CmdSetResult -Match "=") 
{ 10 | $i = $CmdSetResult.IndexOf("=") 11 | $k = $CmdSetResult.Substring(0, $i) 12 | $v = $CmdSetResult.Substring($i + 1) 13 | Set-Item -Force -Path "Env:\$k" -Value "$v" 14 | } 15 | } 16 | } 17 | 18 | function Set-VSEnv { 19 | param ( 20 | [parameter(Mandatory = $false)] 21 | [ValidateSet(2022, 2019, 2017)] 22 | [int]$Version = 2022, 23 | 24 | [parameter(Mandatory = $false)] 25 | [ValidateSet("all", "x86", "x64")] 26 | [String]$Arch = "x64" 27 | ) 28 | 29 | $vs_where = 'C:\Program Files (x86)\Microsoft Visual Studio\Installer\vswhere.exe' 30 | 31 | $version_range = switch ($Version) { 32 | 2022 { '[17,18)' } 33 | 2019 { '[16,17)' } 34 | 2017 { '[15,16)' } 35 | } 36 | $info = &$vs_where -version $version_range -products '*' -format json | ConvertFrom-Json # without -products vswhere only searches Community/Professional/Enterprise, missing Build Tools installs 37 | $vs_env = @{ 38 | install_path = if ($null -ne $info) { @($info)[0].installationPath } else { $null } 39 | all = 'Common7\Tools\VsDevCmd.bat' 40 | x64 = 'VC\Auxiliary\Build\vcvars64.bat' 41 | x86 = 'VC\Auxiliary\Build\vcvars32.bat' 42 | } 43 | if ( $null -eq $vs_env.install_path) { 44 | Write-Host -ForegroundColor Red "Visual Studio $Version is not installed." 45 | return 46 | } 47 | 48 | $path = Join-Path $vs_env.install_path $vs_env.$Arch 49 | 50 | C:/Windows/System32/cmd.exe /c "`"$path`" & set" | Set-EnvFromCmdSet 51 | Set-Item -Force -Path "Env:\BAZEL_VC" -Value "$env:VCINSTALLDIR" 52 | Write-Host -ForegroundColor Green "Visual Studio $Version $Arch Command Prompt variables set." 
53 | } 54 | 55 | 56 | class EnvironmentStack { 57 | static [System.Collections.Stack] $stack = [System.Collections.Stack]::new() 58 | } 59 | 60 | function Push-Environment { 61 | [EnvironmentStack]::stack.push([Environment]::GetEnvironmentVariables()) 62 | } 63 | 64 | function Pop-Environment { 65 | $saved = [EnvironmentStack]::stack.pop() 66 | # Drop variables created after the matching Push-Environment (e.g. TF_CUDA_PATHS, BAZEL_VC, vcvars output); restoring saved values alone would leak them into the caller. 67 | @(Get-ChildItem Env:) | Where-Object { -not $saved.ContainsKey($_.Name) } | 68 | ForEach-Object { 69 | Remove-Item -Force -LiteralPath ("Env:\" + $_.Name) 70 | } 71 | # Restore every saved value; this also re-creates anything removed above because of a case-only name mismatch. 72 | $saved.GetEnumerator() | 73 | ForEach-Object { 74 | Set-Item -Force -Path ("Env:\" + $_.Key) -Value $_.Value 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /update_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | import subprocess 6 | 7 | HEADER = """ 8 | <!DOCTYPE html> 9 | <html> 10 | <head><meta charset="utf-8"></head> 11 | <body> 12 | """.strip("\n") 13 | 14 | FOOTER = "</body>\n</html>" 15 | 16 | ABSOLUTE_URL_PATTERN = "https://whls.blob.core.windows.net/{container}/{key}" 17 | 18 | RELATIVE_URL_PATTERN = "./{key}" 19 | 20 | 21 | if not os.environ.get("AZURE_STORAGE_CONNECTION_STRING"): 22 | sys.stderr.write("Environment variable AZURE_STORAGE_CONNECTION_STRING is not properly set.\n") 23 | sys.stderr.write("Abort!\n") 24 | sys.exit(1) 25 | 26 | 27 | parser = argparse.ArgumentParser(description="Generates and/or update package indexes html file for pip.") 28 | parser.add_argument("container", choices=["unstable", "releases"]) 29 | parser.add_argument("--url_mode", default="relative", choices=["absolute", "relative"]) 30 | args = parser.parse_args() 31 | 32 | 33 | def entry(key): 34 | if args.url_mode == "relative": 35 | url_pattern = RELATIVE_URL_PATTERN 36 | else: 37 | url_pattern = ABSOLUTE_URL_PATTERN 38 | 39 | link_href = url_pattern.format(key=key, container=args.container) 40 | link_title = link_href 41 | return f'<a href="{link_href}" title="{link_title}">{key}</a><br/>' 42 | 43 | 44 | def get_entries(): 45 | entries = [] 46 | proc_args = ["az.cmd", "storage", "blob", "list", "--output=json", "--num-results=*", f"--container-name={args.container}"] # "--num-results=*" lifts az's default 5000-blob cap so large containers are fully indexed 47 | sys.stderr.write(f"Running command: {' '.join(proc_args)}\n") 48 | json_str = subprocess.check_output(proc_args) 49 | j = json.loads(json_str) 50 | for item in j: 51 | key = item["name"] 52 | if key.endswith(".whl"): 53 | entries.append(entry(key)) 54 | return entries 55 | 56 | 57 | def gen_index_html(entries): 58 | html = [HEADER] 59 | html.extend(entries) 60 | html.append(FOOTER) 61 | return "\n".join(html) 62 | 63 | 64 | html = gen_index_html(get_entries()) 65 | print(html) 66 | -------------------------------------------------------------------------------- /windows_configure.bazelrc: -------------------------------------------------------------------------------- 1 | build:windows --features=compiler_param_file 2 | 3 | build:windows --copt=/d2ReducedOptimizeHugeFunctions 4 | build:windows --host_copt=/d2ReducedOptimizeHugeFunctions 5 | 6 | # /arch:AVX2 on the next line supersedes /arch:AVX (MSVC keeps the last /arch and warns D9025), so a separate /arch:AVX is not passed 7 | build:windows --copt=/arch:AVX2 8 | build:windows --copt=/DTF_COMPILE_LIBRARY 9 | --------------------------------------------------------------------------------