├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .python-version ├── README.md ├── pyproject.toml ├── src └── sparse_vggt │ ├── __init__.py │ ├── models │ ├── __init__.py │ ├── attention.py │ ├── pi3.py │ ├── utils.py │ └── vggt.py │ └── utils │ ├── hilbert.py │ ├── sparse_wrapper.py │ └── tokens.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[oc] 3 | build/ 4 | dist/ 5 | wheels/ 6 | *.egg-info 7 | .venv/ 8 | mise.toml 9 | dev.py 10 | dev.ipynb 11 | pyrightconfig.json 12 | py.typed 13 | results/ 14 | .vscode/ 15 | .cursor/ 16 | .ruff_cache/ 17 | slurm/ 18 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/vggt"] 2 | path = external/vggt 3 | url = https://github.com/facebookresearch/vggt.git 4 | [submodule "external/SpargeAttn"] 5 | path = external/SpargeAttn 6 | url = https://github.com/brianwang00001/SpargeAttn.git 7 | [submodule "external/Pi3"] 8 | path = external/Pi3 9 | url = https://github.com/yyfz/Pi3.git 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | - id: check-added-large-files 8 | 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | rev: v0.5.7 11 | hooks: 12 | - id: ruff 13 | args: ["--fix", "--ignore", "E402"] 14 | - id: ruff-format 15 | 16 | - repo: local 17 | hooks: 18 | - id: jupyter-nb-clear-output 19 | name: jupyter-nb-clear-output 20 | files: \.ipynb$ 21 | stages: [pre-commit] 22 | language: system 23 | entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace 24 | 
-------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Faster VGGT with Block-Sparse Global Attention 2 | 3 | [📄 Arxiv](https://arxiv.org/abs/2509.07120) | [🌐 Project Page](https://brianwang00001.github.io/sparse-vggt/) 4 | 5 | ## Quick Start 6 | Set up the environment: 7 | ```bash 8 | # Clone the repository 9 | git clone --recursive https://github.com/brianwang00001/sparse-vggt 10 | cd sparse-vggt 11 | 12 | # Install dependencies 13 | uv sync 14 | 15 | # Compile SpargeAttn 16 | # Requires CUDA to be installed (we used CUDA 12.8) 17 | uv pip install -e external/SpargeAttn/ --no-build-isolation 18 | ``` 19 | 20 | Try the sparse VGGT model: 21 | ```python 22 | import torch 23 | from vggt.models.vggt import VGGT 24 | from sparse_vggt.models.vggt import sparse_aggregator_from_vggt 25 | 26 | # Load the original VGGT model 27 | model = VGGT.from_pretrained("facebook/VGGT-1B") 28 | 29 | # Replace the aggregator with the sparse aggregator 30 | # Note: `aux_output_store` is a dictionary of auxiliary outputs from the global attention 31 | # You can use it to get the sparsity of the global attention 32 | sparse_aggregator, aux_output_store = sparse_aggregator_from_vggt( 33 | model.aggregator, 34 | sparse_ratio=0.1, # example config 35 | cdf_threshold=0.97, # example config 36 | ) 37 | model.aggregator = sparse_aggregator 38 | 39 | # Use the sparse model as usual 40 | model.cuda() 41 | model.eval() 42 | images = torch.randn(10, 3, 518, 378).cuda() 43 | with torch.no_grad(): 44 | with torch.autocast("cuda", dtype=torch.bfloat16): 45 | out = model(images) 46 | ``` 47 | 48 | The same works for Pi3: 49 | ```python 50 | import torch 51 | from pi3.models.pi3 import Pi3 52 | from
sparse_vggt.models.pi3 import sparse_model_from_pi3 53 | 54 | model = Pi3.from_pretrained("yyfz233/Pi3") 55 | model, aux_output_store = sparse_model_from_pi3(model, sparse_ratio=0.1, cdf_threshold=0.97) 56 | ``` 57 | 58 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "sparse-vggt" 3 | version = "0.1.0" 4 | description = "Official implementation of Faster VGGt with Block-Sparse Global Attention" 5 | readme = "README.md" 6 | authors = [ 7 | { name = "Brian Wang", email = "brianwang00001@gmail.com" } 8 | ] 9 | requires-python = ">=3.12" 10 | dependencies = [ 11 | "vggt", 12 | "einops>=0.8.1", 13 | "huggingface-hub>=0.34.4", 14 | "ninja>=1.13.0", 15 | "pillow>=10.3.0", 16 | "safetensors>=0.6.2", 17 | "torch>=2.8.0", 18 | "torchvision>=0.23.0", 19 | "numpy==1.26.4", 20 | "plyfile>=1.1.2", 21 | "ipykernel>=6.30.1", 22 | ] 23 | 24 | [tool.uv.sources] 25 | vggt = { path = "external/vggt", editable = true } 26 | 27 | [tool.uv] 28 | dev-dependencies = ["ruff", "pre-commit"] 29 | 30 | [tool.ruff] 31 | line-length = 100 32 | target-version = "py312" 33 | exclude = ["external"] 34 | 35 | [build-system] 36 | requires = ["uv_build>=0.8.14,<0.9.0"] 37 | build-backend = "uv_build" 38 | -------------------------------------------------------------------------------- /src/sparse_vggt/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.vggt import sparse_aggregator_from_vggt 2 | from .models.pi3 import sparse_model_from_pi3 3 | 4 | __all__ = ["sparse_aggregator_from_vggt", "sparse_model_from_pi3"] -------------------------------------------------------------------------------- /src/sparse_vggt/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vggt import sparse_aggregator_from_vggt 2 | 3 | __all__ = 
["sparse_aggregator_from_vggt"] 4 | -------------------------------------------------------------------------------- /src/sparse_vggt/models/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | from sparse_vggt.utils.sparse_wrapper import block_sparse_attn_cuda 8 | from sparse_vggt.utils.tokens import get_patch_tokens, get_special_tokens 9 | 10 | 11 | def predict_attention(query, key, ks_q=128, ks_k=64, pool_mode="avg"): 12 | """ 13 | Args: 14 | query: (B, nh, Tq, C) 15 | key: (B, nh, Tk, C) 16 | 17 | Return: 18 | pooled_prob: (B, nh, Tq, Tk) 19 | """ 20 | assert pool_mode in ["max", "avg"], f"{pool_mode=}" 21 | 22 | pooling_fn = { 23 | "max": F.max_pool1d, 24 | "avg": F.avg_pool1d, 25 | }[pool_mode] 26 | 27 | assert query.ndim == 4, f"{query.shape=}" 28 | assert key.ndim == 4, f"{key.shape=}" 29 | 30 | B, nh, Tq, C = query.shape 31 | _, _, Tk, _ = key.shape 32 | 33 | # Query Pooling 34 | query = rearrange(query, "B nh Tq C -> (B nh) C Tq") 35 | pooled_query = pooling_fn(query, kernel_size=ks_q, ceil_mode=True) 36 | pooled_query = rearrange(pooled_query, "(B nh) C Tq -> B nh Tq C", B=B, nh=nh) 37 | 38 | # Key Pooling 39 | key = rearrange(key, "B nh Tk C -> (B nh) C Tk") 40 | pooled_key = pooling_fn(key, kernel_size=ks_k, ceil_mode=True) 41 | pooled_key = rearrange(pooled_key, "(B nh) C Tk -> B nh Tk C", B=B, nh=nh) 42 | 43 | # Dot Product 44 | scale = 1 / math.sqrt(C) 45 | pooled_score = pooled_query @ pooled_key.transpose(-1, -2) * scale # (B, nh, Tq, Tk) 46 | pooled_prob = F.softmax(pooled_score, dim=-1) 47 | 48 | return pooled_prob 49 | 50 | 51 | def adaptive_sparse_attention_forward( 52 | self, 53 | x, 54 | pos, 55 | sparse_ratio: float | None = None, 56 | cdf_threshold: float | None = None, 57 | pool_mode: str = "avg", 58 | aux_sparsity_only: bool | None = None, 59 | aux_output_store: dict | None = None, 60 
| num_special_tokens: int = 5, 61 | num_heads: int = 16, 62 | ): 63 | """Adaptive Block Sparse Attention forward pass to replace the original attention forward function. 64 | 65 | Args: 66 | x: (B, N * P, hidden_dim) == (B, N * (H * W + S), hidden_dim) 67 | 68 | Return: 69 | x: (B, N * P, hidden_dim) 70 | """ 71 | 72 | B, NP, hidden_dim = x.shape 73 | S = num_special_tokens 74 | nh = num_heads 75 | hd = hidden_dim // num_heads 76 | 77 | # Infer H, W and N from pos 78 | H = int(pos[0].max(0).values[0]) 79 | W = int(pos[0].max(0).values[1]) 80 | N = NP // (H * W + S) 81 | P = H * W + S 82 | 83 | # Sanity check 84 | assert N * (H * W + S) == NP, f"{N=}, {H=}, {W=}, {S=}, {NP=}" 85 | 86 | qkv = self.qkv(x) 87 | three = 3 88 | qkv = rearrange(qkv, "B N (three nh hd)-> B N three nh hd", three=three, nh=nh, hd=hd) 89 | qkv = rearrange(qkv, "B N three nh hd -> three B nh N hd") 90 | 91 | q, k, v = qkv.unbind(0) # (B, num_heads, N * P, head_dim) 92 | q, k = self.q_norm(q), self.k_norm(k) 93 | 94 | if self.rope is not None: 95 | q = self.rope(q, pos) 96 | k = self.rope(k, pos) 97 | 98 | # separate patch and special tokens 99 | q_special = get_special_tokens(q, N, P, S) 100 | k_special = get_special_tokens(k, N, P, S) 101 | v_special = get_special_tokens(v, N, P, S) 102 | q_patch = get_patch_tokens(q, N, P, S) 103 | k_patch = get_patch_tokens(k, N, P, S) 104 | v_patch = get_patch_tokens(v, N, P, S) 105 | 106 | # special tokens attend to all tokens 107 | if q_special is not None: 108 | x_special = F.scaled_dot_product_attention(q_special, k, v) 109 | else: 110 | x_special = None 111 | # release memory 112 | del q, k, v, qkv 113 | 114 | # Append special key and values in the end 115 | if k_special is not None: 116 | key = torch.cat([k_patch, k_special], dim=-2) 117 | value = torch.cat([v_patch, v_special], dim=-2) 118 | else: 119 | key = k_patch 120 | value = v_patch 121 | 122 | if self.training: 123 | raise NotImplementedError("This is currently only training-free. 
Use .eval()") 124 | 125 | else: 126 | attn_pooled = predict_attention(query=q_patch, key=k_patch, pool_mode=pool_mode) 127 | # release memory 128 | del k_patch, v_patch 129 | 130 | # patch attention 131 | x_patch, sparsity = block_sparse_attn_cuda( 132 | query=q_patch, 133 | key=key, 134 | value=value, 135 | pooled_score=attn_pooled, 136 | sparse_ratio=sparse_ratio, 137 | cdf_threshold=cdf_threshold, 138 | return_sparsity=True, 139 | ) 140 | 141 | if aux_output_store is not None: 142 | aux_output_store["sparsity"] = sparsity 143 | 144 | if not aux_sparsity_only: 145 | aux_output_store.update( 146 | { 147 | "attn_pooled": attn_pooled, 148 | "query": q, 149 | "key": k, 150 | "shape": { 151 | "B": B, 152 | "N": N, 153 | "P": P, 154 | "head_dim": hd, 155 | "num_heads": nh, 156 | "H": H, 157 | "W": W, 158 | }, 159 | } 160 | ) 161 | 162 | x_patch = rearrange( 163 | x_patch, 164 | "B nh (N H W) hd -> B nh N (H W) hd", 165 | N=N, 166 | H=H, 167 | W=W, 168 | ) 169 | 170 | # combine patch and special tokens 171 | if x_special is not None: 172 | x = x_special.view(B, nh, N, S, hd) 173 | x = torch.cat([x, x_patch], dim=-2) 174 | else: 175 | x = x_patch 176 | 177 | x = x.view(B, nh, N * P, hd) 178 | x = x.transpose(1, 2).reshape(B, N * P, nh * hd) 179 | x = self.proj(x) 180 | x = self.proj_drop(x) 181 | return x 182 | -------------------------------------------------------------------------------- /src/sparse_vggt/models/pi3.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from functools import partial 3 | from pathlib import Path 4 | from types import MethodType 5 | from typing import Literal 6 | 7 | import torch 8 | 9 | from sparse_vggt.models.attention import adaptive_sparse_attention_forward 10 | from sparse_vggt.models.utils import print_sparse_info 11 | from sparse_vggt.utils.hilbert import hilbert_permute 12 | from sparse_vggt.utils.sparse_wrapper import check_sparse_mode 13 | 14 | _PROJECT_ROOT = 
Path(__file__).parent.parent.parent.parent 15 | sys.path.append(str(_PROJECT_ROOT / "external/Pi3")) 16 | from pi3.models.pi3 import Pi3 17 | 18 | 19 | def pi3_sparse_attention_forward( 20 | self, 21 | x, 22 | xpos, # replace pos with xpos 23 | sparse_ratio, 24 | cdf_threshold, 25 | pool_mode, 26 | aux_sparsity_only, 27 | aux_output_store, 28 | num_special_tokens, 29 | num_heads, 30 | ): 31 | """Wrapper for Pi3 attention because pi3 has `xpos` instead of `pos`""" 32 | return adaptive_sparse_attention_forward( 33 | self, 34 | x, 35 | xpos, 36 | sparse_ratio, 37 | cdf_threshold, 38 | pool_mode, 39 | aux_sparsity_only, 40 | aux_output_store, 41 | num_special_tokens, 42 | num_heads, 43 | ) 44 | 45 | 46 | def sparse_pi3_decode_forward(self, hidden, N, H, W, use_hilbert=True): 47 | BN, hw, _ = hidden.shape 48 | B = BN // N 49 | 50 | final_output = [] 51 | 52 | hidden = hidden.reshape(B * N, hw, -1) 53 | 54 | register_token = self.register_token.repeat(B, N, 1, 1).reshape( 55 | B * N, *self.register_token.shape[-2:] 56 | ) 57 | 58 | # Concatenate special tokens with patch tokens 59 | hidden = torch.cat([register_token, hidden], dim=1) 60 | hw = hidden.shape[1] 61 | 62 | # Convert HW to patch space 63 | H = H // self.patch_size 64 | W = W // self.patch_size 65 | pos = self.position_getter(B * N, H, W, hidden.device) 66 | 67 | if self.patch_start_idx > 0: 68 | # do not use position embedding for special tokens (camera and register tokens) 69 | # so set pos to 0 for the special tokens 70 | pos = pos + 1 71 | pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype) 72 | pos = torch.cat([pos_special, pos], dim=1) 73 | 74 | # add: hilbert permutation 75 | orig_pos = pos.clone() 76 | if use_hilbert: 77 | S = self.patch_start_idx 78 | hidden = hilbert_permute(hidden, H, W, S) # (B * N, H * W + S, C) 79 | orig_pos = pos.clone() 80 | pos = hilbert_permute(pos, H, W, S) # (B * N, H * W, 2) 81 | 82 | for i in range(len(self.decoder)): 83 | blk = 
self.decoder[i] 84 | 85 | if i % 2 == 0: 86 | pos = pos.reshape(B * N, hw, -1) 87 | hidden = hidden.reshape(B * N, hw, -1) 88 | else: 89 | pos = pos.reshape(B, N * hw, -1) 90 | hidden = hidden.reshape(B, N * hw, -1) 91 | 92 | hidden = blk(hidden, xpos=pos) 93 | 94 | if i + 1 in [len(self.decoder) - 1, len(self.decoder)]: 95 | # add: hilbert permutation 96 | if use_hilbert: 97 | S = self.patch_start_idx 98 | hidden_orig = hidden.reshape(B * N, hw, -1) 99 | hidden_orig = hilbert_permute(hidden_orig, H, W, S, reverse=True) 100 | final_output.append(hidden_orig) 101 | else: 102 | final_output.append(hidden.reshape(B * N, hw, -1)) 103 | 104 | return torch.cat([final_output[0], final_output[1]], dim=-1), orig_pos.reshape(B * N, hw, -1) 105 | 106 | 107 | def sparse_model_from_pi3( 108 | model: Pi3, 109 | use_hilbert: bool = False, 110 | sparse_ratio: float | None = None, 111 | cdf_threshold: float | None = None, 112 | pool_mode: Literal["max", "avg"] = "avg", 113 | aux_output: bool = False, 114 | aux_sparsity_only: bool = True, 115 | num_special_tokens: int = 5, 116 | verbose: bool = True, 117 | ): 118 | """Convert the original Pi3 model to a sparse model. 119 | 120 | Args: 121 | model: Original Pi3 model. 122 | use_hilbert: If True, use Hilbert permutation on patch tokens. 123 | sparse_ratio: Sparse ratio for the global attention. 124 | cdf_threshold: CDF threshold for the global attention. 125 | pool_mode: Avg or Max pooling for the global attention. 126 | aux_output: If True, store auxiliary output from the global attention. 127 | aux_sparsity_only: If True, only store sparsity from the global attention. 128 | 129 | Returns: 130 | Aggregator: Modified Aggregator 131 | aux_output_store (dict | None): Auxiliary output storage. 
132 | """ 133 | # Check sparse mode 134 | check_sparse_mode(sparse_ratio, cdf_threshold) 135 | 136 | # Replace model decode function 137 | decode_fwd = partial(sparse_pi3_decode_forward, use_hilbert=use_hilbert) 138 | model.decode = MethodType(decode_fwd, model) 139 | 140 | # Auxiliary output 141 | modified_layers = [i for i in range(len(model.decoder)) if i % 2 != 0] 142 | if aux_output: 143 | # Pointers to store auxiliary output 144 | aux_output_store = {i: {} for i in modified_layers} 145 | else: 146 | # No auxiliary output`` 147 | aux_output_store = {i: None for i in modified_layers} 148 | 149 | for i in range(len(model.decoder)): 150 | if i not in modified_layers: 151 | continue 152 | 153 | # Replace attention forward function 154 | attn_fwd = partial( 155 | pi3_sparse_attention_forward, 156 | sparse_ratio=sparse_ratio, 157 | cdf_threshold=cdf_threshold, 158 | pool_mode=pool_mode, 159 | aux_output_store=aux_output_store[i], 160 | aux_sparsity_only=aux_sparsity_only, 161 | num_special_tokens=num_special_tokens, 162 | num_heads=model.decoder[i].attn.num_heads, 163 | ) 164 | 165 | # Replace attention 166 | model.decoder[i].attn.forward = MethodType(attn_fwd, model.decoder[i].attn) 167 | 168 | # Print some info 169 | if verbose: 170 | print_sparse_info( 171 | sparse_ratio, 172 | cdf_threshold, 173 | pool_mode, 174 | use_hilbert, 175 | aux_output, 176 | aux_sparsity_only, 177 | ) 178 | return model, aux_output_store 179 | -------------------------------------------------------------------------------- /src/sparse_vggt/models/utils.py: -------------------------------------------------------------------------------- 1 | def print_sparse_info( 2 | sparse_ratio, 3 | cdf_threshold, 4 | pool_mode, 5 | use_hilbert, 6 | aux_output, 7 | aux_sparsity_only, 8 | ): 9 | # Print some info 10 | print("-" * 100) 11 | 12 | if sparse_ratio is not None: 13 | print(f"sparse_ratio: {sparse_ratio}") 14 | if cdf_threshold is not None: 15 | print(f"cdf_threshold: {cdf_threshold}") 16 | 17 
| if use_hilbert: 18 | print("Using Hilbert permutation") 19 | else: 20 | print("Not using Hilbert permutation") 21 | 22 | print(f"Pooling mode: {pool_mode}") 23 | 24 | if aux_output and aux_sparsity_only: 25 | print("Auxiliary output is enabled (sparsity only)") 26 | elif aux_output and not aux_sparsity_only: 27 | print( 28 | "\033[31mWARNING: Including all auxiliary outputs can be slow, only for analysis and debugging!\033[0m" 29 | ) 30 | else: 31 | print("Auxiliary output is disabled") 32 | 33 | print("-" * 100) 34 | -------------------------------------------------------------------------------- /src/sparse_vggt/models/vggt.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from types import MethodType 3 | from typing import Literal 4 | 5 | import torch 6 | from vggt.models.aggregator import Aggregator, slice_expand_and_flatten 7 | 8 | from sparse_vggt.models.attention import adaptive_sparse_attention_forward 9 | from sparse_vggt.models.utils import print_sparse_info 10 | from sparse_vggt.utils.hilbert import hilbert_permute 11 | from sparse_vggt.utils.sparse_wrapper import check_sparse_mode 12 | 13 | 14 | def sparse_vggt_aggregator_forward( 15 | self, 16 | images: torch.Tensor, 17 | use_hilbert: bool = True, 18 | intermediate_layer_idx: list[int] = [4, 11, 17, 23], # hardcoded for vggt 19 | ): 20 | """Adaptive Block Sparse Attention Aggregator forward pass to replace the original aggregator forward function. 21 | 22 | Args: 23 | images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. 24 | B: batch size, S: sequence length, 3: RGB channels, H: height, W: width 25 | use_hilbert: if True, use Hilbert permutation for patch tokens 26 | 27 | Returns: 28 | (list[torch.Tensor], int): 29 | The list of outputs from the attention blocks, 30 | and the patch_start_idx indicating where patch tokens begin. 
31 | """ 32 | B, N, C_in, H, W = images.shape 33 | 34 | if C_in != 3: 35 | raise ValueError(f"Expected 3 input channels, got {C_in}") 36 | 37 | # Normalize images and reshape for patch embed 38 | images = (images - self._resnet_mean) / self._resnet_std 39 | 40 | # Reshape to [B*N, C, H, W] for patch embedding 41 | images = images.view(B * N, C_in, H, W) 42 | patch_tokens = self.patch_embed(images) 43 | 44 | if isinstance(patch_tokens, dict): 45 | patch_tokens = patch_tokens["x_norm_patchtokens"] 46 | 47 | _, P, C = patch_tokens.shape 48 | 49 | # Expand camera and register tokens to match batch size and sequence length 50 | camera_token = slice_expand_and_flatten(self.camera_token, B, N) 51 | register_token = slice_expand_and_flatten(self.register_token, B, N) 52 | 53 | # Concatenate special tokens with patch tokens 54 | tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) 55 | 56 | # convert to patch dimension 57 | H = H // self.patch_size 58 | W = W // self.patch_size 59 | pos = None 60 | if self.rope is not None: 61 | pos = self.position_getter(B * N, H, W, device=images.device) 62 | 63 | if self.patch_start_idx > 0 and pos is not None: 64 | # do not use position embedding for special tokens (camera and register tokens) 65 | # so set pos to 0 for the special tokens 66 | pos = pos + 1 67 | pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(images.device).to(pos.dtype) 68 | pos = torch.cat([pos_special, pos], dim=1) 69 | 70 | # update P because we added special tokens 71 | _, P, C = tokens.shape 72 | 73 | frame_idx = 0 74 | global_idx = 0 75 | 76 | output_list = [] 77 | 78 | # add: hilbert permutation 79 | S = self.patch_start_idx 80 | if use_hilbert: 81 | tokens = hilbert_permute(tokens, H, W, S) # (B * N, H * W + S, C) 82 | pos = hilbert_permute(pos, H, W, S) # (B * N, H * W, 2) 83 | 84 | concat_inter = None 85 | frame_intermediates = [] 86 | global_intermediates = [] 87 | 88 | for layer_idx in range(self.aa_block_num): 89 | for 
attn_type in self.aa_order: 90 | if attn_type == "frame": 91 | tokens, frame_idx, frame_intermediates = self._process_frame_attention( 92 | tokens, B, N, P, C, frame_idx, pos=pos 93 | ) 94 | elif attn_type == "global": 95 | tokens, global_idx, global_intermediates = self._process_global_attention( 96 | tokens, B, N, P, C, global_idx, pos=pos 97 | ) 98 | else: 99 | raise ValueError(f"Unknown attention type: {attn_type}") 100 | if layer_idx in intermediate_layer_idx: 101 | for i in range(len(frame_intermediates)): 102 | # concat frame and global intermediates, [B x N x P x 2C] 103 | concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) 104 | output_list.append(concat_inter) 105 | else: 106 | output_list.append(None) # workaround for vggt's heads 107 | 108 | del concat_inter 109 | del frame_intermediates 110 | del global_intermediates 111 | 112 | if use_hilbert: 113 | for i, inter_x in enumerate(output_list): 114 | # (B, N, H * W + S, C) 115 | if inter_x is not None: 116 | output_list[i] = hilbert_permute(inter_x, H, W, S, reverse=True) 117 | 118 | return output_list, self.patch_start_idx 119 | 120 | 121 | def sparse_aggregator_from_vggt( 122 | aggregator: Aggregator, 123 | use_hilbert: bool = False, 124 | sparse_ratio: float | None = None, 125 | cdf_threshold: float | None = None, 126 | pool_mode: Literal["max", "avg"] = "avg", 127 | aux_output: bool = False, 128 | aux_sparsity_only: bool = True, 129 | num_special_tokens: int = 5, 130 | verbose: bool = True, 131 | ): 132 | """Convert the original VGGT aggregator to a sparse aggregator. 133 | 134 | Args: 135 | aggregator: Original VGGT aggregator. 136 | use_hilbert: If True, use Hilbert permutation on patch tokens. 137 | sparse_ratio: Sparse ratio for the global attention. 138 | cdf_threshold: CDF threshold for the global attention. 139 | pool_mode: Avg or Max pooling for the global attention. 140 | aux_output: If True, store auxiliary output from the global attention. 
141 | aux_sparsity_only: If True, only store sparsity from the global attention. 142 | 143 | Returns: 144 | Aggregator: Modified Aggregator 145 | aux_output_store (dict | None): Auxiliary output storage. 146 | """ 147 | # Check sparse mode 148 | check_sparse_mode(sparse_ratio, cdf_threshold) 149 | 150 | # Replace aggregator forward function 151 | aggregator_fwd = partial(sparse_vggt_aggregator_forward, use_hilbert=use_hilbert) 152 | aggregator.forward = MethodType(aggregator_fwd, aggregator) 153 | 154 | modified_layers = [i for i in range(24)] 155 | if aux_output: 156 | # Pointers to store auxiliary output 157 | aux_output_store = {i: {} for i in modified_layers} 158 | else: 159 | # No auxiliary output 160 | aux_output_store = {i: None for i in modified_layers} 161 | 162 | for i in range(len(aggregator.global_blocks)): 163 | if i not in modified_layers: 164 | continue 165 | 166 | # Replace attention forward function 167 | attn_fwd = partial( 168 | adaptive_sparse_attention_forward, 169 | sparse_ratio=sparse_ratio, 170 | cdf_threshold=cdf_threshold, 171 | pool_mode=pool_mode, 172 | aux_output_store=aux_output_store[i], 173 | aux_sparsity_only=aux_sparsity_only, 174 | num_special_tokens=num_special_tokens, 175 | ) 176 | aggregator.global_blocks[i].attn.forward = MethodType( 177 | attn_fwd, aggregator.global_blocks[i].attn 178 | ) 179 | 180 | if verbose: 181 | print_sparse_info( 182 | sparse_ratio, 183 | cdf_threshold, 184 | pool_mode, 185 | use_hilbert, 186 | aux_output, 187 | aux_sparsity_only, 188 | ) 189 | return aggregator, aux_output_store 190 | -------------------------------------------------------------------------------- /src/sparse_vggt/utils/hilbert.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import lru_cache 3 | from typing import List 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | 9 | def gilbert2d(width, height): 10 | """ 11 | Generalized Hilbert 
('gilbert') space-filling curve for arbitrary-sized 12 | 2D rectangular grids. Generates discrete 2D coordinates to fill a rectangle 13 | of size (width x height). 14 | """ 15 | 16 | if width >= height: 17 | yield from generate2d(0, 0, width, 0, 0, height) 18 | else: 19 | yield from generate2d(0, 0, 0, height, width, 0) 20 | 21 | 22 | def sgn(x): 23 | return -1 if x < 0 else (1 if x > 0 else 0) 24 | 25 | 26 | def generate2d(x, y, ax, ay, bx, by): 27 | w = abs(ax + ay) 28 | h = abs(bx + by) 29 | 30 | (dax, day) = (sgn(ax), sgn(ay)) # unit major direction 31 | (dbx, dby) = (sgn(bx), sgn(by)) # unit orthogonal direction 32 | 33 | if h == 1: 34 | # trivial row fill 35 | for i in range(0, w): 36 | yield (x, y) 37 | (x, y) = (x + dax, y + day) 38 | return 39 | 40 | if w == 1: 41 | # trivial column fill 42 | for i in range(0, h): 43 | yield (x, y) 44 | (x, y) = (x + dbx, y + dby) 45 | return 46 | 47 | (ax2, ay2) = (ax // 2, ay // 2) 48 | (bx2, by2) = (bx // 2, by // 2) 49 | 50 | w2 = abs(ax2 + ay2) 51 | h2 = abs(bx2 + by2) 52 | 53 | if 2 * w > 3 * h: 54 | if (w2 % 2) and (w > 2): 55 | # prefer even steps 56 | (ax2, ay2) = (ax2 + dax, ay2 + day) 57 | 58 | # long case: split in two parts only 59 | yield from generate2d(x, y, ax2, ay2, bx, by) 60 | yield from generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by) 61 | 62 | else: 63 | if (h2 % 2) and (h > 2): 64 | # prefer even steps 65 | (bx2, by2) = (bx2 + dbx, by2 + dby) 66 | 67 | # standard case: one step up, one long horizontal, one step down 68 | yield from generate2d(x, y, bx2, by2, ax2, ay2) 69 | yield from generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2) 70 | yield from generate2d( 71 | x + (ax - dax) + (bx2 - dbx), 72 | y + (ay - day) + (by2 - dby), 73 | -bx2, 74 | -by2, 75 | -(ax - ax2), 76 | -(ay - ay2), 77 | ) 78 | 79 | 80 | @lru_cache(maxsize=16) 81 | def make_hilbert_gather_idx(width, height, inverse=False) -> List[int]: 82 | points = list(gilbert2d(width, height)) 83 | mapping = {y * width + x: i for 
i, (x, y) in enumerate(points)} 84 | 85 | if inverse: 86 | mapping = {v: k for k, v in mapping.items()} 87 | 88 | gather_idx = [0] * (width * height) 89 | for old_idx, new_idx in mapping.items(): 90 | gather_idx[new_idx] = old_idx 91 | return gather_idx 92 | 93 | 94 | def get_hilbert_permutation(B, N, H, W, C, device, inverse=False): 95 | """Get the gather indices for the Hilbert permutation. 96 | 97 | Args: 98 | inverse (bool): If True, return the gather indices from original to Hilbert permutated. 99 | Otherwise, return the gather indices from Hilbert permutated to original. 100 | 101 | Returns: 102 | gather_idx: gather indices from original to Hilbert permutated 103 | shape: (B, N * H * W, C) 104 | """ 105 | gather_idx = make_hilbert_gather_idx(W, H, inverse=inverse) 106 | gather_idx = torch.tensor(gather_idx, dtype=torch.int64, device=device) 107 | gather_idx = repeat(gather_idx, "T -> B T C", B=B, C=C) # (B, H * W, C) 108 | 109 | # Applye to every frame 110 | frame_offset = torch.arange(N, device=device) * H * W 111 | frame_offset = rearrange(frame_offset, "N -> 1 N 1 1") 112 | gather_idx = repeat(gather_idx, "B T C -> B N T C", N=N) 113 | gather_idx = gather_idx + frame_offset 114 | 115 | gather_idx = rearrange(gather_idx, "B N T C -> B (N T) C") 116 | return gather_idx 117 | 118 | 119 | def hilbert_permute(tokens, H, W, S, reverse=False): 120 | """Apply Hilbert permutation on patch tokens. 
121 | 122 | Args: 123 | tokens: (..., H * W + S, C) 124 | N: number of frames 125 | H, W: height and width 126 | S: number of special tokens per frame 127 | reverse: if True, apply inverse Hilbert permutation 128 | 129 | Return: 130 | tokens: (..., H * W + S, C) 131 | """ 132 | assert tokens.shape[-2] == H * W + S, f"{tokens.shape=}" 133 | C = tokens.shape[-1] 134 | 135 | # Squash batch dims into one dim 136 | batch_dims = tokens.shape[:-2] 137 | B_flat = math.prod(batch_dims) if len(batch_dims) > 0 else 1 138 | tokens = tokens.reshape(B_flat, H * W + S, C) 139 | 140 | # Separate patch and special tokens 141 | x_special = tokens[:, :S, :].contiguous() 142 | x_patch = tokens[:, S:, :].contiguous() # (B * N, H * W, C) 143 | 144 | # Hilbert permutation 145 | gather_idx_o2h = get_hilbert_permutation( 146 | B_flat, 1, H, W, C, device=x_patch.device, inverse=reverse 147 | ) 148 | x_patch = torch.gather(x_patch, dim=-2, index=gather_idx_o2h) 149 | 150 | # Combine patch and special tokens 151 | tokens = torch.cat([x_special, x_patch], dim=-2) 152 | tokens = tokens.reshape(batch_dims + (H * W + S, C)) 153 | return tokens 154 | -------------------------------------------------------------------------------- /src/sparse_vggt/utils/sparse_wrapper.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import spas_sage_attn._qattn as qattn 4 | import torch 5 | from spas_sage_attn.quant_per_block import per_block_int8 6 | from spas_sage_attn.utils import ( 7 | block_map_lut_triton, 8 | fill_block_map_triton, 9 | hyperparameter_check, 10 | ) 11 | 12 | 13 | def check_sparse_mode(sparse_ratio, cdf_threshold): 14 | """Check the valid combinations of sparse_ratio and cdf_threshold for sparse inference; raises AssertionError unless one of the modes below is used. 15 | 16 | Args: 17 | sparse_ratio (float | None): fraction of key blocks to skip; the top (1 - sparse_ratio) fraction of blocks is kept 18 | cdf_threshold (float | None): choose blocks that accumulate to a certain threshold 19 | 20 | Three modes (combinations) are allowed: 21 | 1.
only specify sparse_ratio 22 | 2. only specify cdf_threshold 23 | 3. specify both sparse_ratio and cdf_threshold 24 | This means that the cdf threshold and sparse ratio are BOTH reached (the larger of the two block counts is kept). 25 | """ 26 | use_ratio = sparse_ratio is not None 27 | use_cdf = cdf_threshold is not None 28 | 29 | # Modes 30 | only_use_ratio = use_ratio and (not use_cdf) 31 | only_use_cdf = (not use_ratio) and use_cdf 32 | use_ratio_and_cdf = use_ratio and use_cdf 33 | 34 | assert ( 35 | only_use_ratio + only_use_cdf + use_ratio_and_cdf == 1 36 | ), f"Current: {sparse_ratio=}, {cdf_threshold=}" 37 | 38 | 39 | def check_sparse_mode_three_type(topk, sparse_ratio, cdf_threshold): 40 | """Check the valid combinations of topk, sparse_ratio, and cdf_threshold for sparse inference; raises AssertionError unless exactly one of the modes below is used. 41 | 42 | Args: 43 | topk (int | None): choose the top-k key blocks for each query block 44 | sparse_ratio (float | None): fraction of key blocks to skip; the top (1 - sparse_ratio) fraction of blocks is kept 45 | cdf_threshold (float | None): choose blocks that accumulate to a certain threshold 46 | 47 | Four modes (combinations) are allowed: 48 | 1. only specify topk 49 | 2. only specify sparse_ratio 50 | 3. only specify cdf_threshold 51 | 4. specify both sparse_ratio and cdf_threshold 52 | This means that the cdf threshold and sparse ratio are BOTH reached (the larger of the two block counts is kept).
53 | """ 54 | use_topk = topk is not None 55 | use_ratio = sparse_ratio is not None 56 | use_cdf = cdf_threshold is not None 57 | 58 | # Modes 59 | only_use_topk = use_topk and (not use_ratio) and (not use_cdf) 60 | only_use_ratio = (not use_topk) and use_ratio and (not use_cdf) 61 | only_use_cdf = (not use_topk) and (not use_ratio) and use_cdf 62 | use_ratio_and_cdf = use_ratio and use_cdf and (not use_topk) 63 | 64 | assert ( 65 | only_use_topk + only_use_ratio + only_use_cdf + use_ratio_and_cdf == 1 66 | ), f"Current: {topk=}, {sparse_ratio=}, {cdf_threshold=}" 67 | 68 | 69 | def get_block_mask( 70 | pooled_score: torch.Tensor, 71 | sink_blocks: int, 72 | topk: int | None = None, 73 | sparse_ratio: float | None = None, 74 | cdf_threshold: float | None = None, 75 | eps: float = 1e-5, 76 | ) -> torch.Tensor: 77 | """Select, for every query block, the key blocks to attend to. 78 | Args: 79 | pooled_score (Tensor): Pooled attention scores after softmax 80 | Shape: (B, nh, q_blk, k_blk) 81 | where q_blk and k_blk are the number of query and key blocks. 82 | nh: number of heads 83 | sink_blocks (int): number of extra key blocks, appended after the k_blk pooled blocks, that are always selected 84 | topk, sparse_ratio, cdf_threshold: selection mode, validated by check_sparse_mode_three_type; sparse_ratio is converted to topk = int(k_blk * (1 - sparse_ratio)), and when combined with cdf_threshold the larger block count is kept. eps: added to cdf_threshold to avoid numerical error in searchsorted. 85 | Returns: 86 | final_map (Bool Tensor): (B, nh, q_blk, k_blk + sink_blocks) 87 | True means the block is selected.
88 | 89 | """ 90 | check_sparse_mode_three_type(topk, sparse_ratio, cdf_threshold) 91 | 92 | B, nh, q_blk, k_blk = pooled_score.shape 93 | assert sink_blocks >= 0 and sink_blocks <= k_blk 94 | 95 | if sparse_ratio is not None: 96 | # Convert sparse ratio to topk 97 | assert sparse_ratio >= 0 and sparse_ratio <= 1 98 | topk = int(k_blk * (1 - sparse_ratio)) # NOTE(review): int() truncates, so float error can drop a block (k_blk=10, sparse_ratio=0.9 gives 0.9999999999999998 -> 0, not 1); consider rounding first 99 | 100 | if topk is not None: 101 | assert topk >= 0 and topk <= k_blk 102 | 103 | if cdf_threshold is not None: 104 | assert cdf_threshold >= 0 and cdf_threshold <= 1 105 | 106 | sorted_score = torch.sort(pooled_score, dim=-1, descending=True) 107 | 108 | num_to_select = None 109 | if cdf_threshold is not None: 110 | cdf = torch.cumsum(sorted_score.values, dim=-1) 111 | cdfthreshd = hyperparameter_check(cdf_threshold, nh, pooled_score.device) 112 | cdfthreshd_ts = cdfthreshd.view(1, nh, 1, 1) 113 | cdfthreshd_ts = cdfthreshd_ts + eps # to avoid numerical error in searchsorted 114 | cdfthreshd_ts = cdfthreshd_ts.expand(B, -1, q_blk, 1).contiguous() 115 | num_to_select = torch.searchsorted(cdf, cdfthreshd_ts, right=True).squeeze(-1) # (B, nh, q_blk): count of top blocks whose cumulative score is <= threshold + eps 116 | 117 | if topk is not None: 118 | if num_to_select is None: 119 | num_to_select = torch.full((B, nh, q_blk), topk, device=pooled_score.device) 120 | else: 121 | num_to_select = torch.clamp(num_to_select, min=topk) # keep the larger count so both criteria are met 122 | 123 | final_map = torch.zeros_like(pooled_score, dtype=torch.bool) 124 | final_map = fill_block_map_triton(final_map, num_to_select, sorted_score.indices) 125 | 126 | if sink_blocks > 0: 127 | # Always select the sink (special-token) blocks, appended after the pooled key blocks 128 | ones_shape = list(final_map.shape) 129 | ones_shape[-1] = sink_blocks 130 | trailing_ones = torch.ones(ones_shape, device=final_map.device).bool() 131 | final_map = torch.cat([final_map, trailing_ones], dim=-1) 132 | 133 | return final_map 134 | 135 | 136 | def block_sparse_attn_cuda( 137 | query: torch.Tensor, 138 | key: torch.Tensor, 139 | value: torch.Tensor, 140 | pooled_score: torch.Tensor, 141 | topk: int | None = None, 142
| sparse_ratio: float | None = None, 143 | cdf_threshold: float | None = None, 144 | return_sparsity: bool = False, 145 | dtype: torch.dtype = torch.float16, 146 | out_dtype: torch.dtype = torch.float32, 147 | ): 148 | """Block sparse attention using SpargeAttn kernels. 149 | 150 | Args: 151 | query (torch.Tensor): (B, nheads, Tq, head_dim) 152 | key (torch.Tensor): (B, nheads, Tk, head_dim) 153 | value (torch.Tensor): (B, nheads, Tk, head_dim) 154 | sink tokens are appended to the end of key and value. 155 | pooled_score (torch.Tensor): (B, nheads, q_blk, k_blk) 156 | where q_blk and k_blk are the number of query and key blocks. 157 | The score here *doesn't* contain the sink tokens. 158 | topk, sparse_ratio, cdf_threshold: the mode of sparse attention 159 | - topk: choose the top-k key blocks for each query block 160 | - sparse_ratio: skip this fraction of key blocks, keeping the top (1 - sparse_ratio) fraction 161 | - cdf_threshold: choose blocks that accumulate to a certain threshold 162 | return_sparsity: also return the fraction of unselected blocks. dtype: dtype q/k/v are cast to before the kernel. out_dtype: dtype of the returned output. 163 | Returns: 164 | out: Attention output of shape (B, nheads, Tq, head_dim) in out_dtype. If return_sparsity is True, the tuple (out, sparsity) is returned instead, where sparsity is the fraction of unselected blocks (sink blocks included). 165 | """ 166 | # Hardcode some arguments for using SpargeAttn kernels 167 | _is_causal = 0 168 | KBLK = 64 # key block size in tokens; used to count total key blocks including sinks 169 | pvthreshd = 1e10 # presumably large enough to effectively disable PV-threshold skipping; confirm against SpargeAttn 170 | pvthreshd = hyperparameter_check(pvthreshd, query.size(-3), query.device) 171 | 172 | # Get block mask 173 | Tk = key.shape[-2] 174 | orig_Kblk = pooled_score.shape[-1] 175 | total_Kblk = math.ceil(Tk / KBLK) 176 | sink_blocks = total_Kblk - orig_Kblk # key blocks beyond those in pooled_score hold the appended sink tokens 177 | final_map = get_block_mask( 178 | pooled_score, 179 | sink_blocks=sink_blocks, 180 | topk=topk, 181 | sparse_ratio=sparse_ratio, 182 | cdf_threshold=cdf_threshold, 183 | ) 184 | lut, valid_block_num = block_map_lut_triton(final_map) 185 | 186 | # Type conversion 187 | query, key, value = ( 188 | query.contiguous().to(dtype), 189 | key.contiguous().to(dtype), 190 | value.contiguous().to(dtype), 191 | ) 192 | 193 | # Quantization 194 | km = key.mean(dim=-2, keepdim=True) # subtracting km from every key shifts all of a query's scores by the same q @ km, so softmax is unchanged; presumably done to narrow the int8 range 195 | q_int8, q_scale, k_int8, k_scale = per_block_int8(query, key - km) 196 | q_scale =
q_scale.squeeze(-1) 197 | k_scale = k_scale.squeeze(-1) 198 | 199 | # Get softmax scale 200 | hd = query.shape[-1] 201 | scale = 1.0 / (hd**0.5) 202 | 203 | # SpargeAttn attention kernel 204 | o = torch.empty_like(query) # output buffer handed to the kernel (presumably filled in place) 205 | qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( 206 | q_int8, 207 | k_int8, 208 | value, 209 | o, 210 | lut, 211 | valid_block_num, 212 | pvthreshd, 213 | q_scale, 214 | k_scale, 215 | 1, 216 | _is_causal, 217 | 1, 218 | scale, 219 | 0, 220 | ) 221 | o = o.to(out_dtype) 222 | if return_sparsity: 223 | sparsity = 1 - final_map.float().mean().item() # fraction of unselected blocks, sink blocks included 224 | return o, sparsity 225 | else: 226 | return o 227 | -------------------------------------------------------------------------------- /src/sparse_vggt/utils/tokens.py: -------------------------------------------------------------------------------- 1 | def get_patch_tokens(x, N, P, S): 2 | """Get patch tokens (all but the first S tokens of each frame) from the input tensor. 3 | 4 | Args: 5 | x (torch.Tensor): input tensor of shape (..., N * P, C) 6 | 7 | Returns: 8 | patch (torch.Tensor): patch tokens of shape (..., N * (P - S), C) 9 | 10 | where 11 | N: number of frames 12 | P: number of patch tokens + special tokens 13 | S: number of special tokens 14 | """ 15 | assert x.shape[-2] == N * P 16 | 17 | batch_dims = x.shape[:-2] 18 | channels = x.shape[-1] 19 | 20 | x = x.view(batch_dims + (N, P, channels)) 21 | patch = x[..., S:, :] # (..., N, P - S, C) 22 | patch = patch.reshape(batch_dims + (N * (P - S), channels)) 23 | return patch.contiguous() 24 | 25 | 26 | def get_special_tokens(x, N, P, S): 27 | """Get special tokens (the first S tokens of each frame) from the input tensor.
28 | 29 | Args: 30 | x (torch.Tensor): input tensor of shape (..., N * P, C) 31 | N, P, S: number of frames, tokens per frame (patch + special), and special tokens per frame 32 | Returns: 33 | special (torch.Tensor | None): special tokens of shape (..., N * S, C), or None if that tensor is empty (e.g. S == 0) 34 | """ 35 | assert x.shape[-2] == N * P 36 | 37 | batch_dims = x.shape[:-2] 38 | channels = x.shape[-1] 39 | 40 | x = x.view(batch_dims + (N, P, channels)) 41 | special = x[..., :S, :] # (..., N, S, C) 42 | special = special.reshape(batch_dims + (N * S, channels)) 43 | if special.numel() == 0: 44 | return None 45 | return special 46 | -------------------------------------------------------------------------------- /uv.lock: -------------------------------------------------------------------------------- 1 | version = 1 2 | revision = 1 3 | requires-python = ">=3.12" 4 | resolution-markers = [ 5 | "sys_platform == 'darwin'", 6 | "platform_machine == 'aarch64' and sys_platform == 'linux'", 7 | "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')", 8 | ] 9 | 10 | [[package]] 11 | name = "appnope" 12 | version = "0.1.4" 13 | source = { registry = "https://pypi.org/simple" } 14 | sdist = { url = "https://files.pythonhosted.org/packages/35/5d/752690df9ef5b76e169e68d6a129fa6d08a7100ca7f754c89495db3c6019/appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee", size = 4170 } 15 | wheels = [ 16 | { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 }, 17 | ] 18 | 19 | [[package]] 20 | name = "asttokens" 21 | version = "3.0.0" 22 | source = { registry = "https://pypi.org/simple" } 23 | sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash =
"sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978 } 24 | wheels = [ 25 | { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918 }, 26 | ] 27 | 28 | [[package]] 29 | name = "certifi" 30 | version = "2025.8.3" 31 | source = { registry = "https://pypi.org/simple" } 32 | sdist = { url = "https://files.pythonhosted.org/packages/dc/67/960ebe6bf230a96cda2e0abcf73af550ec4f090005363542f0765df162e0/certifi-2025.8.3.tar.gz", hash = "sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407", size = 162386 } 33 | wheels = [ 34 | { url = "https://files.pythonhosted.org/packages/e5/48/1549795ba7742c948d2ad169c1c8cdbae65bc450d6cd753d124b17c8cd32/certifi-2025.8.3-py3-none-any.whl", hash = "sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5", size = 161216 }, 35 | ] 36 | 37 | [[package]] 38 | name = "cffi" 39 | version = "1.17.1" 40 | source = { registry = "https://pypi.org/simple" } 41 | dependencies = [ 42 | { name = "pycparser" }, 43 | ] 44 | sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621 } 45 | wheels = [ 46 | { url = "https://files.pythonhosted.org/packages/5a/84/e94227139ee5fb4d600a7a4927f322e1d4aea6fdc50bd3fca8493caba23f/cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4", size = 183178 }, 47 | { url = "https://files.pythonhosted.org/packages/da/ee/fb72c2b48656111c4ef27f0f91da355e130a923473bf5ee75c5643d00cca/cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c", size = 178840 
}, 48 | { url = "https://files.pythonhosted.org/packages/cc/b6/db007700f67d151abadf508cbfd6a1884f57eab90b1bb985c4c8c02b0f28/cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36", size = 454803 }, 49 | { url = "https://files.pythonhosted.org/packages/1a/df/f8d151540d8c200eb1c6fba8cd0dfd40904f1b0682ea705c36e6c2e97ab3/cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5", size = 478850 }, 50 | { url = "https://files.pythonhosted.org/packages/28/c0/b31116332a547fd2677ae5b78a2ef662dfc8023d67f41b2a83f7c2aa78b1/cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff", size = 485729 }, 51 | { url = "https://files.pythonhosted.org/packages/91/2b/9a1ddfa5c7f13cab007a2c9cc295b70fbbda7cb10a286aa6810338e60ea1/cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99", size = 471256 }, 52 | { url = "https://files.pythonhosted.org/packages/b2/d5/da47df7004cb17e4955df6a43d14b3b4ae77737dff8bf7f8f333196717bf/cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93", size = 479424 }, 53 | { url = "https://files.pythonhosted.org/packages/0b/ac/2a28bcf513e93a219c8a4e8e125534f4f6db03e3179ba1c45e949b76212c/cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3", size = 484568 }, 54 | { url = "https://files.pythonhosted.org/packages/d4/38/ca8a4f639065f14ae0f1d9751e70447a261f1a30fa7547a828ae08142465/cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = 
"sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8", size = 488736 }, 55 | { url = "https://files.pythonhosted.org/packages/86/c5/28b2d6f799ec0bdecf44dced2ec5ed43e0eb63097b0f58c293583b406582/cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65", size = 172448 }, 56 | { url = "https://files.pythonhosted.org/packages/50/b9/db34c4755a7bd1cb2d1603ac3863f22bcecbd1ba29e5ee841a4bc510b294/cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903", size = 181976 }, 57 | { url = "https://files.pythonhosted.org/packages/8d/f8/dd6c246b148639254dad4d6803eb6a54e8c85c6e11ec9df2cffa87571dbe/cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e", size = 182989 }, 58 | { url = "https://files.pythonhosted.org/packages/8b/f1/672d303ddf17c24fc83afd712316fda78dc6fce1cd53011b839483e1ecc8/cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2", size = 178802 }, 59 | { url = "https://files.pythonhosted.org/packages/0e/2d/eab2e858a91fdff70533cab61dcff4a1f55ec60425832ddfdc9cd36bc8af/cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3", size = 454792 }, 60 | { url = "https://files.pythonhosted.org/packages/75/b2/fbaec7c4455c604e29388d55599b99ebcc250a60050610fadde58932b7ee/cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683", size = 478893 }, 61 | { url = "https://files.pythonhosted.org/packages/4f/b7/6e4a2162178bf1935c336d4da8a9352cccab4d3a5d7914065490f08c0690/cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5", size = 485810 }, 62 | { url = "https://files.pythonhosted.org/packages/c7/8a/1d0e4a9c26e54746dc08c2c6c037889124d4f59dffd853a659fa545f1b40/cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4", size = 471200 }, 63 | { url = "https://files.pythonhosted.org/packages/26/9f/1aab65a6c0db35f43c4d1b4f580e8df53914310afc10ae0397d29d697af4/cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd", size = 479447 }, 64 | { url = "https://files.pythonhosted.org/packages/5f/e4/fb8b3dd8dc0e98edf1135ff067ae070bb32ef9d509d6cb0f538cd6f7483f/cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed", size = 484358 }, 65 | { url = "https://files.pythonhosted.org/packages/f1/47/d7145bf2dc04684935d57d67dff9d6d795b2ba2796806bb109864be3a151/cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9", size = 488469 }, 66 | { url = "https://files.pythonhosted.org/packages/bf/ee/f94057fa6426481d663b88637a9a10e859e492c73d0384514a17d78ee205/cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d", size = 172475 }, 67 | { url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 }, 68 | ] 69 | 70 | [[package]] 71 | name = "cfgv" 72 | version = "3.4.0" 73 | source = { registry = "https://pypi.org/simple" } 74 | sdist = { url = 
"https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 } 75 | wheels = [ 76 | { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 }, 77 | ] 78 | 79 | [[package]] 80 | name = "charset-normalizer" 81 | version = "3.4.3" 82 | source = { registry = "https://pypi.org/simple" } 83 | sdist = { url = "https://files.pythonhosted.org/packages/83/2d/5fd176ceb9b2fc619e63405525573493ca23441330fcdaee6bef9460e924/charset_normalizer-3.4.3.tar.gz", hash = "sha256:6fce4b8500244f6fcb71465d4a4930d132ba9ab8e71a7859e6a5d59851068d14", size = 122371 } 84 | wheels = [ 85 | { url = "https://files.pythonhosted.org/packages/e9/5e/14c94999e418d9b87682734589404a25854d5f5d0408df68bc15b6ff54bb/charset_normalizer-3.4.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e28e334d3ff134e88989d90ba04b47d84382a828c061d0d1027b1b12a62b39b1", size = 205655 }, 86 | { url = "https://files.pythonhosted.org/packages/7d/a8/c6ec5d389672521f644505a257f50544c074cf5fc292d5390331cd6fc9c3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0cacf8f7297b0c4fcb74227692ca46b4a5852f8f4f24b3c766dd94a1075c4884", size = 146223 }, 87 | { url = "https://files.pythonhosted.org/packages/fc/eb/a2ffb08547f4e1e5415fb69eb7db25932c52a52bed371429648db4d84fb1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c6fd51128a41297f5409deab284fecbe5305ebd7e5a1f959bee1c054622b7018", size = 159366 }, 88 | { url = 
"https://files.pythonhosted.org/packages/82/10/0fd19f20c624b278dddaf83b8464dcddc2456cb4b02bb902a6da126b87a1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3cfb2aad70f2c6debfbcb717f23b7eb55febc0bb23dcffc0f076009da10c6392", size = 157104 }, 89 | { url = "https://files.pythonhosted.org/packages/16/ab/0233c3231af734f5dfcf0844aa9582d5a1466c985bbed6cedab85af9bfe3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1606f4a55c0fd363d754049cdf400175ee96c992b1f8018b993941f221221c5f", size = 151830 }, 90 | { url = "https://files.pythonhosted.org/packages/ae/02/e29e22b4e02839a0e4a06557b1999d0a47db3567e82989b5bb21f3fbbd9f/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:027b776c26d38b7f15b26a5da1044f376455fb3766df8fc38563b4efbc515154", size = 148854 }, 91 | { url = "https://files.pythonhosted.org/packages/05/6b/e2539a0a4be302b481e8cafb5af8792da8093b486885a1ae4d15d452bcec/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:42e5088973e56e31e4fa58eb6bd709e42fc03799c11c42929592889a2e54c491", size = 160670 }, 92 | { url = "https://files.pythonhosted.org/packages/31/e7/883ee5676a2ef217a40ce0bffcc3d0dfbf9e64cbcfbdf822c52981c3304b/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cc34f233c9e71701040d772aa7490318673aa7164a0efe3172b2981218c26d93", size = 158501 }, 93 | { url = "https://files.pythonhosted.org/packages/c1/35/6525b21aa0db614cf8b5792d232021dca3df7f90a1944db934efa5d20bb1/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:320e8e66157cc4e247d9ddca8e21f427efc7a04bbd0ac8a9faf56583fa543f9f", size = 153173 }, 94 | { url = "https://files.pythonhosted.org/packages/50/ee/f4704bad8201de513fdc8aac1cabc87e38c5818c93857140e06e772b5892/charset_normalizer-3.4.3-cp312-cp312-win32.whl", hash = 
"sha256:fb6fecfd65564f208cbf0fba07f107fb661bcd1a7c389edbced3f7a493f70e37", size = 99822 }, 95 | { url = "https://files.pythonhosted.org/packages/39/f5/3b3836ca6064d0992c58c7561c6b6eee1b3892e9665d650c803bd5614522/charset_normalizer-3.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:86df271bf921c2ee3818f0522e9a5b8092ca2ad8b065ece5d7d9d0e9f4849bcc", size = 107543 }, 96 | { url = "https://files.pythonhosted.org/packages/65/ca/2135ac97709b400c7654b4b764daf5c5567c2da45a30cdd20f9eefe2d658/charset_normalizer-3.4.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:14c2a87c65b351109f6abfc424cab3927b3bdece6f706e4d12faaf3d52ee5efe", size = 205326 }, 97 | { url = "https://files.pythonhosted.org/packages/71/11/98a04c3c97dd34e49c7d247083af03645ca3730809a5509443f3c37f7c99/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:41d1fc408ff5fdfb910200ec0e74abc40387bccb3252f3f27c0676731df2b2c8", size = 146008 }, 98 | { url = "https://files.pythonhosted.org/packages/60/f5/4659a4cb3c4ec146bec80c32d8bb16033752574c20b1252ee842a95d1a1e/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:1bb60174149316da1c35fa5233681f7c0f9f514509b8e399ab70fea5f17e45c9", size = 159196 }, 99 | { url = "https://files.pythonhosted.org/packages/86/9e/f552f7a00611f168b9a5865a1414179b2c6de8235a4fa40189f6f79a1753/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30d006f98569de3459c2fc1f2acde170b7b2bd265dc1943e87e1a4efe1b67c31", size = 156819 }, 100 | { url = "https://files.pythonhosted.org/packages/7e/95/42aa2156235cbc8fa61208aded06ef46111c4d3f0de233107b3f38631803/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:416175faf02e4b0810f1f38bcb54682878a4af94059a1cd63b8747244420801f", size = 151350 }, 101 | { url = 
"https://files.pythonhosted.org/packages/c2/a9/3865b02c56f300a6f94fc631ef54f0a8a29da74fb45a773dfd3dcd380af7/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6aab0f181c486f973bc7262a97f5aca3ee7e1437011ef0c2ec04b5a11d16c927", size = 148644 }, 102 | { url = "https://files.pythonhosted.org/packages/77/d9/cbcf1a2a5c7d7856f11e7ac2d782aec12bdfea60d104e60e0aa1c97849dc/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabf8315679312cfa71302f9bd509ded4f2f263fb5b765cf1433b39106c3cc9", size = 160468 }, 103 | { url = "https://files.pythonhosted.org/packages/f6/42/6f45efee8697b89fda4d50580f292b8f7f9306cb2971d4b53f8914e4d890/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:bd28b817ea8c70215401f657edef3a8aa83c29d447fb0b622c35403780ba11d5", size = 158187 }, 104 | { url = "https://files.pythonhosted.org/packages/70/99/f1c3bdcfaa9c45b3ce96f70b14f070411366fa19549c1d4832c935d8e2c3/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:18343b2d246dc6761a249ba1fb13f9ee9a2bcd95decc767319506056ea4ad4dc", size = 152699 }, 105 | { url = "https://files.pythonhosted.org/packages/a3/ad/b0081f2f99a4b194bcbb1934ef3b12aa4d9702ced80a37026b7607c72e58/charset_normalizer-3.4.3-cp313-cp313-win32.whl", hash = "sha256:6fb70de56f1859a3f71261cbe41005f56a7842cc348d3aeb26237560bfa5e0ce", size = 99580 }, 106 | { url = "https://files.pythonhosted.org/packages/9a/8f/ae790790c7b64f925e5c953b924aaa42a243fb778fed9e41f147b2a5715a/charset_normalizer-3.4.3-cp313-cp313-win_amd64.whl", hash = "sha256:cf1ebb7d78e1ad8ec2a8c4732c7be2e736f6e5123a4146c5b89c9d1f585f8cef", size = 107366 }, 107 | { url = "https://files.pythonhosted.org/packages/8e/91/b5a06ad970ddc7a0e513112d40113e834638f4ca1120eb727a249fb2715e/charset_normalizer-3.4.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:3cd35b7e8aedeb9e34c41385fda4f73ba609e561faedfae0a9e75e44ac558a15", size = 204342 }, 108 | { url = 
"https://files.pythonhosted.org/packages/ce/ec/1edc30a377f0a02689342f214455c3f6c2fbedd896a1d2f856c002fc3062/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b89bc04de1d83006373429975f8ef9e7932534b8cc9ca582e4db7d20d91816db", size = 145995 }, 109 | { url = "https://files.pythonhosted.org/packages/17/e5/5e67ab85e6d22b04641acb5399c8684f4d37caf7558a53859f0283a650e9/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2001a39612b241dae17b4687898843f254f8748b796a2e16f1051a17078d991d", size = 158640 }, 110 | { url = "https://files.pythonhosted.org/packages/f1/e5/38421987f6c697ee3722981289d554957c4be652f963d71c5e46a262e135/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8dcfc373f888e4fb39a7bc57e93e3b845e7f462dacc008d9749568b1c4ece096", size = 156636 }, 111 | { url = "https://files.pythonhosted.org/packages/a0/e4/5a075de8daa3ec0745a9a3b54467e0c2967daaaf2cec04c845f73493e9a1/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18b97b8404387b96cdbd30ad660f6407799126d26a39ca65729162fd810a99aa", size = 150939 }, 112 | { url = "https://files.pythonhosted.org/packages/02/f7/3611b32318b30974131db62b4043f335861d4d9b49adc6d57c1149cc49d4/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ccf600859c183d70eb47e05a44cd80a4ce77394d1ac0f79dbd2dd90a69a3a049", size = 148580 }, 113 | { url = "https://files.pythonhosted.org/packages/7e/61/19b36f4bd67f2793ab6a99b979b4e4f3d8fc754cbdffb805335df4337126/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:53cd68b185d98dde4ad8990e56a58dea83a4162161b1ea9272e5c9182ce415e0", size = 159870 }, 114 | { url = 
"https://files.pythonhosted.org/packages/06/57/84722eefdd338c04cf3030ada66889298eaedf3e7a30a624201e0cbe424a/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:30a96e1e1f865f78b030d65241c1ee850cdf422d869e9028e2fc1d5e4db73b92", size = 157797 }, 115 | { url = "https://files.pythonhosted.org/packages/72/2a/aff5dd112b2f14bcc3462c312dce5445806bfc8ab3a7328555da95330e4b/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d716a916938e03231e86e43782ca7878fb602a125a91e7acb8b5112e2e96ac16", size = 152224 }, 116 | { url = "https://files.pythonhosted.org/packages/b7/8c/9839225320046ed279c6e839d51f028342eb77c91c89b8ef2549f951f3ec/charset_normalizer-3.4.3-cp314-cp314-win32.whl", hash = "sha256:c6dbd0ccdda3a2ba7c2ecd9d77b37f3b5831687d8dc1b6ca5f56a4880cc7b7ce", size = 100086 }, 117 | { url = "https://files.pythonhosted.org/packages/ee/7a/36fbcf646e41f710ce0a563c1c9a343c6edf9be80786edeb15b6f62e17db/charset_normalizer-3.4.3-cp314-cp314-win_amd64.whl", hash = "sha256:73dc19b562516fc9bcf6e5d6e596df0b4eb98d87e4f79f3ae71840e6ed21361c", size = 107400 }, 118 | { url = "https://files.pythonhosted.org/packages/8a/1f/f041989e93b001bc4e44bb1669ccdcf54d3f00e628229a85b08d330615c5/charset_normalizer-3.4.3-py3-none-any.whl", hash = "sha256:ce571ab16d890d23b5c278547ba694193a45011ff86a9162a71307ed9f86759a", size = 53175 }, 119 | ] 120 | 121 | [[package]] 122 | name = "colorama" 123 | version = "0.4.6" 124 | source = { registry = "https://pypi.org/simple" } 125 | sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } 126 | wheels = [ 127 | { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = 
"sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, 128 | ] 129 | 130 | [[package]] 131 | name = "comm" 132 | version = "0.2.3" 133 | source = { registry = "https://pypi.org/simple" } 134 | sdist = { url = "https://files.pythonhosted.org/packages/4c/13/7d740c5849255756bc17888787313b61fd38a0a8304fc4f073dfc46122aa/comm-0.2.3.tar.gz", hash = "sha256:2dc8048c10962d55d7ad693be1e7045d891b7ce8d999c97963a5e3e99c055971", size = 6319 } 135 | wheels = [ 136 | { url = "https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl", hash = "sha256:c615d91d75f7f04f095b30d1c1711babd43bdc6419c1be9886a85f2f4e489417", size = 7294 }, 137 | ] 138 | 139 | [[package]] 140 | name = "debugpy" 141 | version = "1.8.16" 142 | source = { registry = "https://pypi.org/simple" } 143 | sdist = { url = "https://files.pythonhosted.org/packages/ca/d4/722d0bcc7986172ac2ef3c979ad56a1030e3afd44ced136d45f8142b1f4a/debugpy-1.8.16.tar.gz", hash = "sha256:31e69a1feb1cf6b51efbed3f6c9b0ef03bc46ff050679c4be7ea6d2e23540870", size = 1643809 } 144 | wheels = [ 145 | { url = "https://files.pythonhosted.org/packages/61/fb/0387c0e108d842c902801bc65ccc53e5b91d8c169702a9bbf4f7efcedf0c/debugpy-1.8.16-cp312-cp312-macosx_14_0_universal2.whl", hash = "sha256:b202e2843e32e80b3b584bcebfe0e65e0392920dc70df11b2bfe1afcb7a085e4", size = 2511822 }, 146 | { url = "https://files.pythonhosted.org/packages/37/44/19e02745cae22bf96440141f94e15a69a1afaa3a64ddfc38004668fcdebf/debugpy-1.8.16-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64473c4a306ba11a99fe0bb14622ba4fbd943eb004847d9b69b107bde45aa9ea", size = 4230135 }, 147 | { url = "https://files.pythonhosted.org/packages/f3/0b/19b1ba5ee4412f303475a2c7ad5858efb99c90eae5ec627aa6275c439957/debugpy-1.8.16-cp312-cp312-win32.whl", hash = "sha256:833a61ed446426e38b0dd8be3e9d45ae285d424f5bf6cd5b2b559c8f12305508", 
size = 5281271 }, 148 | { url = "https://files.pythonhosted.org/packages/b1/e0/bc62e2dc141de53bd03e2c7cb9d7011de2e65e8bdcdaa26703e4d28656ba/debugpy-1.8.16-cp312-cp312-win_amd64.whl", hash = "sha256:75f204684581e9ef3dc2f67687c3c8c183fde2d6675ab131d94084baf8084121", size = 5323149 }, 149 | { url = "https://files.pythonhosted.org/packages/62/66/607ab45cc79e60624df386e233ab64a6d8d39ea02e7f80e19c1d451345bb/debugpy-1.8.16-cp313-cp313-macosx_14_0_universal2.whl", hash = "sha256:85df3adb1de5258dca910ae0bb185e48c98801ec15018a263a92bb06be1c8787", size = 2496157 }, 150 | { url = "https://files.pythonhosted.org/packages/4d/a0/c95baae08a75bceabb79868d663a0736655e427ab9c81fb848da29edaeac/debugpy-1.8.16-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bee89e948bc236a5c43c4214ac62d28b29388453f5fd328d739035e205365f0b", size = 4222491 }, 151 | { url = "https://files.pythonhosted.org/packages/5b/2f/1c8db6ddd8a257c3cd2c46413b267f1d5fa3df910401c899513ce30392d6/debugpy-1.8.16-cp313-cp313-win32.whl", hash = "sha256:cf358066650439847ec5ff3dae1da98b5461ea5da0173d93d5e10f477c94609a", size = 5281126 }, 152 | { url = "https://files.pythonhosted.org/packages/d3/ba/c3e154ab307366d6c5a9c1b68de04914e2ce7fa2f50d578311d8cc5074b2/debugpy-1.8.16-cp313-cp313-win_amd64.whl", hash = "sha256:b5aea1083f6f50023e8509399d7dc6535a351cc9f2e8827d1e093175e4d9fa4c", size = 5323094 }, 153 | { url = "https://files.pythonhosted.org/packages/52/57/ecc9ae29fa5b2d90107cd1d9bf8ed19aacb74b2264d986ae9d44fe9bdf87/debugpy-1.8.16-py2.py3-none-any.whl", hash = "sha256:19c9521962475b87da6f673514f7fd610328757ec993bf7ec0d8c96f9a325f9e", size = 5287700 }, 154 | ] 155 | 156 | [[package]] 157 | name = "decorator" 158 | version = "5.2.1" 159 | source = { registry = "https://pypi.org/simple" } 160 | sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = 
"sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711 } 161 | wheels = [ 162 | { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190 }, 163 | ] 164 | 165 | [[package]] 166 | name = "distlib" 167 | version = "0.4.0" 168 | source = { registry = "https://pypi.org/simple" } 169 | sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605 } 170 | wheels = [ 171 | { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047 }, 172 | ] 173 | 174 | [[package]] 175 | name = "einops" 176 | version = "0.8.1" 177 | source = { registry = "https://pypi.org/simple" } 178 | sdist = { url = "https://files.pythonhosted.org/packages/e5/81/df4fbe24dff8ba3934af99044188e20a98ed441ad17a274539b74e82e126/einops-0.8.1.tar.gz", hash = "sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84", size = 54805 } 179 | wheels = [ 180 | { url = "https://files.pythonhosted.org/packages/87/62/9773de14fe6c45c23649e98b83231fffd7b9892b6cf863251dc2afa73643/einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737", size = 64359 }, 181 | ] 182 | 183 | [[package]] 184 | name = "executing" 185 | version = "2.2.1" 186 | source = { registry = "https://pypi.org/simple" } 187 | sdist = { url = "https://files.pythonhosted.org/packages/cc/28/c14e053b6762b1044f34a13aab6859bbf40456d37d23aa286ac24cfd9a5d/executing-2.2.1.tar.gz", hash = 
"sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4", size = 1129488 } 188 | wheels = [ 189 | { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317 }, 190 | ] 191 | 192 | [[package]] 193 | name = "filelock" 194 | version = "3.19.1" 195 | source = { registry = "https://pypi.org/simple" } 196 | sdist = { url = "https://files.pythonhosted.org/packages/40/bb/0ab3e58d22305b6f5440629d20683af28959bf793d98d11950e305c1c326/filelock-3.19.1.tar.gz", hash = "sha256:66eda1888b0171c998b35be2bcc0f6d75c388a7ce20c3f3f37aa8e96c2dddf58", size = 17687 } 197 | wheels = [ 198 | { url = "https://files.pythonhosted.org/packages/42/14/42b2651a2f46b022ccd948bca9f2d5af0fd8929c4eec235b8d6d844fbe67/filelock-3.19.1-py3-none-any.whl", hash = "sha256:d38e30481def20772f5baf097c122c3babc4fcdb7e14e57049eb9d88c6dc017d", size = 15988 }, 199 | ] 200 | 201 | [[package]] 202 | name = "fsspec" 203 | version = "2025.7.0" 204 | source = { registry = "https://pypi.org/simple" } 205 | sdist = { url = "https://files.pythonhosted.org/packages/8b/02/0835e6ab9cfc03916fe3f78c0956cfcdb6ff2669ffa6651065d5ebf7fc98/fsspec-2025.7.0.tar.gz", hash = "sha256:786120687ffa54b8283d942929540d8bc5ccfa820deb555a2b5d0ed2b737bf58", size = 304432 } 206 | wheels = [ 207 | { url = "https://files.pythonhosted.org/packages/2f/e0/014d5d9d7a4564cf1c40b5039bc882db69fd881111e03ab3657ac0b218e2/fsspec-2025.7.0-py3-none-any.whl", hash = "sha256:8b012e39f63c7d5f10474de957f3ab793b47b45ae7d39f2fb735f8bbe25c0e21", size = 199597 }, 208 | ] 209 | 210 | [[package]] 211 | name = "hf-xet" 212 | version = "1.1.9" 213 | source = { registry = "https://pypi.org/simple" } 214 | sdist = { url = "https://files.pythonhosted.org/packages/23/0f/5b60fc28ee7f8cc17a5114a584fd6b86e11c3e0a6e142a7f97a161e9640a/hf_xet-1.1.9.tar.gz", hash = 
"sha256:c99073ce404462e909f1d5839b2d14a3827b8fe75ed8aed551ba6609c026c803", size = 484242 } 215 | wheels = [ 216 | { url = "https://files.pythonhosted.org/packages/de/12/56e1abb9a44cdef59a411fe8a8673313195711b5ecce27880eb9c8fa90bd/hf_xet-1.1.9-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:a3b6215f88638dd7a6ff82cb4e738dcbf3d863bf667997c093a3c990337d1160", size = 2762553 }, 217 | { url = "https://files.pythonhosted.org/packages/3a/e6/2d0d16890c5f21b862f5df3146519c182e7f0ae49b4b4bf2bd8a40d0b05e/hf_xet-1.1.9-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:9b486de7a64a66f9a172f4b3e0dfe79c9f0a93257c501296a2521a13495a698a", size = 2623216 }, 218 | { url = "https://files.pythonhosted.org/packages/81/42/7e6955cf0621e87491a1fb8cad755d5c2517803cea174229b0ec00ff0166/hf_xet-1.1.9-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4c5a840c2c4e6ec875ed13703a60e3523bc7f48031dfd750923b2a4d1a5fc3c", size = 3186789 }, 219 | { url = "https://files.pythonhosted.org/packages/df/8b/759233bce05457f5f7ec062d63bbfd2d0c740b816279eaaa54be92aa452a/hf_xet-1.1.9-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:96a6139c9e44dad1c52c52520db0fffe948f6bce487cfb9d69c125f254bb3790", size = 3088747 }, 220 | { url = "https://files.pythonhosted.org/packages/6c/3c/28cc4db153a7601a996985bcb564f7b8f5b9e1a706c7537aad4b4809f358/hf_xet-1.1.9-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ad1022e9a998e784c97b2173965d07fe33ee26e4594770b7785a8cc8f922cd95", size = 3251429 }, 221 | { url = "https://files.pythonhosted.org/packages/84/17/7caf27a1d101bfcb05be85850d4aa0a265b2e1acc2d4d52a48026ef1d299/hf_xet-1.1.9-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:86754c2d6d5afb11b0a435e6e18911a4199262fe77553f8c50d75e21242193ea", size = 3354643 }, 222 | { url = "https://files.pythonhosted.org/packages/cd/50/0c39c9eed3411deadcc98749a6699d871b822473f55fe472fad7c01ec588/hf_xet-1.1.9-cp37-abi3-win_amd64.whl", hash = 
"sha256:5aad3933de6b725d61d51034e04174ed1dce7a57c63d530df0014dea15a40127", size = 2804797 }, 223 | ] 224 | 225 | [[package]] 226 | name = "huggingface-hub" 227 | version = "0.34.4" 228 | source = { registry = "https://pypi.org/simple" } 229 | dependencies = [ 230 | { name = "filelock" }, 231 | { name = "fsspec" }, 232 | { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, 233 | { name = "packaging" }, 234 | { name = "pyyaml" }, 235 | { name = "requests" }, 236 | { name = "tqdm" }, 237 | { name = "typing-extensions" }, 238 | ] 239 | sdist = { url = "https://files.pythonhosted.org/packages/45/c9/bdbe19339f76d12985bc03572f330a01a93c04dffecaaea3061bdd7fb892/huggingface_hub-0.34.4.tar.gz", hash = "sha256:a4228daa6fb001be3f4f4bdaf9a0db00e1739235702848df00885c9b5742c85c", size = 459768 } 240 | wheels = [ 241 | { url = "https://files.pythonhosted.org/packages/39/7b/bb06b061991107cd8783f300adff3e7b7f284e330fd82f507f2a1417b11d/huggingface_hub-0.34.4-py3-none-any.whl", hash = "sha256:9b365d781739c93ff90c359844221beef048403f1bc1f1c123c191257c3c890a", size = 561452 }, 242 | ] 243 | 244 | [[package]] 245 | name = "identify" 246 | version = "2.6.13" 247 | source = { registry = "https://pypi.org/simple" } 248 | sdist = { url = "https://files.pythonhosted.org/packages/82/ca/ffbabe3635bb839aa36b3a893c91a9b0d368cb4d8073e03a12896970af82/identify-2.6.13.tar.gz", hash = "sha256:da8d6c828e773620e13bfa86ea601c5a5310ba4bcd65edf378198b56a1f9fb32", size = 99243 } 249 | wheels = [ 250 | { url = "https://files.pythonhosted.org/packages/e7/ce/461b60a3ee109518c055953729bf9ed089a04db895d47e95444071dcdef2/identify-2.6.13-py2.py3-none-any.whl", hash = "sha256:60381139b3ae39447482ecc406944190f690d4a2997f2584062089848361b33b", size = 99153 }, 251 | ] 252 | 253 | [[package]] 254 | name = "idna" 255 | version = "3.10" 256 | source = { registry = "https://pypi.org/simple" } 257 | sdist = { url = 
"https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } 258 | wheels = [ 259 | { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, 260 | ] 261 | 262 | [[package]] 263 | name = "ipykernel" 264 | version = "6.30.1" 265 | source = { registry = "https://pypi.org/simple" } 266 | dependencies = [ 267 | { name = "appnope", marker = "sys_platform == 'darwin'" }, 268 | { name = "comm" }, 269 | { name = "debugpy" }, 270 | { name = "ipython" }, 271 | { name = "jupyter-client" }, 272 | { name = "jupyter-core" }, 273 | { name = "matplotlib-inline" }, 274 | { name = "nest-asyncio" }, 275 | { name = "packaging" }, 276 | { name = "psutil" }, 277 | { name = "pyzmq" }, 278 | { name = "tornado" }, 279 | { name = "traitlets" }, 280 | ] 281 | sdist = { url = "https://files.pythonhosted.org/packages/bb/76/11082e338e0daadc89c8ff866185de11daf67d181901038f9e139d109761/ipykernel-6.30.1.tar.gz", hash = "sha256:6abb270161896402e76b91394fcdce5d1be5d45f456671e5080572f8505be39b", size = 166260 } 282 | wheels = [ 283 | { url = "https://files.pythonhosted.org/packages/fc/c7/b445faca8deb954fe536abebff4ece5b097b923de482b26e78448c89d1dd/ipykernel-6.30.1-py3-none-any.whl", hash = "sha256:aa6b9fb93dca949069d8b85b6c79b2518e32ac583ae9c7d37c51d119e18b3fb4", size = 117484 }, 284 | ] 285 | 286 | [[package]] 287 | name = "ipython" 288 | version = "9.5.0" 289 | source = { registry = "https://pypi.org/simple" } 290 | dependencies = [ 291 | { name = "colorama", marker = "sys_platform == 'win32'" }, 292 | { name = "decorator" }, 293 | { name = "ipython-pygments-lexers" }, 294 | { name = "jedi" }, 295 | { name = "matplotlib-inline" }, 296 | { name = "pexpect", 
marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, 297 | { name = "prompt-toolkit" }, 298 | { name = "pygments" }, 299 | { name = "stack-data" }, 300 | { name = "traitlets" }, 301 | ] 302 | sdist = { url = "https://files.pythonhosted.org/packages/6e/71/a86262bf5a68bf211bcc71fe302af7e05f18a2852fdc610a854d20d085e6/ipython-9.5.0.tar.gz", hash = "sha256:129c44b941fe6d9b82d36fc7a7c18127ddb1d6f02f78f867f402e2e3adde3113", size = 4389137 } 303 | wheels = [ 304 | { url = "https://files.pythonhosted.org/packages/08/2a/5628a99d04acb2d2f2e749cdf4ea571d2575e898df0528a090948018b726/ipython-9.5.0-py3-none-any.whl", hash = "sha256:88369ffa1d5817d609120daa523a6da06d02518e582347c29f8451732a9c5e72", size = 612426 }, 305 | ] 306 | 307 | [[package]] 308 | name = "ipython-pygments-lexers" 309 | version = "1.1.1" 310 | source = { registry = "https://pypi.org/simple" } 311 | dependencies = [ 312 | { name = "pygments" }, 313 | ] 314 | sdist = { url = "https://files.pythonhosted.org/packages/ef/4c/5dd1d8af08107f88c7f741ead7a40854b8ac24ddf9ae850afbcf698aa552/ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81", size = 8393 } 315 | wheels = [ 316 | { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074 }, 317 | ] 318 | 319 | [[package]] 320 | name = "jedi" 321 | version = "0.19.2" 322 | source = { registry = "https://pypi.org/simple" } 323 | dependencies = [ 324 | { name = "parso" }, 325 | ] 326 | sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 } 327 | wheels = [ 328 | { url = 
"https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 }, 329 | ] 330 | 331 | [[package]] 332 | name = "jinja2" 333 | version = "3.1.6" 334 | source = { registry = "https://pypi.org/simple" } 335 | dependencies = [ 336 | { name = "markupsafe" }, 337 | ] 338 | sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115 } 339 | wheels = [ 340 | { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899 }, 341 | ] 342 | 343 | [[package]] 344 | name = "jupyter-client" 345 | version = "8.6.3" 346 | source = { registry = "https://pypi.org/simple" } 347 | dependencies = [ 348 | { name = "jupyter-core" }, 349 | { name = "python-dateutil" }, 350 | { name = "pyzmq" }, 351 | { name = "tornado" }, 352 | { name = "traitlets" }, 353 | ] 354 | sdist = { url = "https://files.pythonhosted.org/packages/71/22/bf9f12fdaeae18019a468b68952a60fe6dbab5d67cd2a103cac7659b41ca/jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419", size = 342019 } 355 | wheels = [ 356 | { url = "https://files.pythonhosted.org/packages/11/85/b0394e0b6fcccd2c1eeefc230978a6f8cb0c5df1e4cd3e7625735a0d7d1e/jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f", size = 106105 }, 357 | ] 358 | 359 | [[package]] 360 | name = "jupyter-core" 361 | version = "5.8.1" 362 | source = { registry = "https://pypi.org/simple" } 363 | dependencies = [ 364 | { name = 
"platformdirs" }, 365 | { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, 366 | { name = "traitlets" }, 367 | ] 368 | sdist = { url = "https://files.pythonhosted.org/packages/99/1b/72906d554acfeb588332eaaa6f61577705e9ec752ddb486f302dafa292d9/jupyter_core-5.8.1.tar.gz", hash = "sha256:0a5f9706f70e64786b75acba995988915ebd4601c8a52e534a40b51c95f59941", size = 88923 } 369 | wheels = [ 370 | { url = "https://files.pythonhosted.org/packages/2f/57/6bffd4b20b88da3800c5d691e0337761576ee688eb01299eae865689d2df/jupyter_core-5.8.1-py3-none-any.whl", hash = "sha256:c28d268fc90fb53f1338ded2eb410704c5449a358406e8a948b75706e24863d0", size = 28880 }, 371 | ] 372 | 373 | [[package]] 374 | name = "markupsafe" 375 | version = "3.0.2" 376 | source = { registry = "https://pypi.org/simple" } 377 | sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537 } 378 | wheels = [ 379 | { url = "https://files.pythonhosted.org/packages/22/09/d1f21434c97fc42f09d290cbb6350d44eb12f09cc62c9476effdb33a18aa/MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf", size = 14274 }, 380 | { url = "https://files.pythonhosted.org/packages/6b/b0/18f76bba336fa5aecf79d45dcd6c806c280ec44538b3c13671d49099fdd0/MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225", size = 12348 }, 381 | { url = "https://files.pythonhosted.org/packages/e0/25/dd5c0f6ac1311e9b40f4af06c78efde0f3b5cbf02502f8ef9501294c425b/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028", size = 24149 }, 382 | { url = 
"https://files.pythonhosted.org/packages/f3/f0/89e7aadfb3749d0f52234a0c8c7867877876e0a20b60e2188e9850794c17/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8", size = 23118 }, 383 | { url = "https://files.pythonhosted.org/packages/d5/da/f2eeb64c723f5e3777bc081da884b414671982008c47dcc1873d81f625b6/MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c", size = 22993 }, 384 | { url = "https://files.pythonhosted.org/packages/da/0e/1f32af846df486dce7c227fe0f2398dc7e2e51d4a370508281f3c1c5cddc/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557", size = 24178 }, 385 | { url = "https://files.pythonhosted.org/packages/c4/f6/bb3ca0532de8086cbff5f06d137064c8410d10779c4c127e0e47d17c0b71/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22", size = 23319 }, 386 | { url = "https://files.pythonhosted.org/packages/a2/82/8be4c96ffee03c5b4a034e60a31294daf481e12c7c43ab8e34a1453ee48b/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48", size = 23352 }, 387 | { url = "https://files.pythonhosted.org/packages/51/ae/97827349d3fcffee7e184bdf7f41cd6b88d9919c80f0263ba7acd1bbcb18/MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30", size = 15097 }, 388 | { url = "https://files.pythonhosted.org/packages/c1/80/a61f99dc3a936413c3ee4e1eecac96c0da5ed07ad56fd975f1a9da5bc630/MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87", size = 15601 }, 389 | { url = 
"https://files.pythonhosted.org/packages/83/0e/67eb10a7ecc77a0c2bbe2b0235765b98d164d81600746914bebada795e97/MarkupSafe-3.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd", size = 14274 }, 390 | { url = "https://files.pythonhosted.org/packages/2b/6d/9409f3684d3335375d04e5f05744dfe7e9f120062c9857df4ab490a1031a/MarkupSafe-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430", size = 12352 }, 391 | { url = "https://files.pythonhosted.org/packages/d2/f5/6eadfcd3885ea85fe2a7c128315cc1bb7241e1987443d78c8fe712d03091/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094", size = 24122 }, 392 | { url = "https://files.pythonhosted.org/packages/0c/91/96cf928db8236f1bfab6ce15ad070dfdd02ed88261c2afafd4b43575e9e9/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396", size = 23085 }, 393 | { url = "https://files.pythonhosted.org/packages/c2/cf/c9d56af24d56ea04daae7ac0940232d31d5a8354f2b457c6d856b2057d69/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79", size = 22978 }, 394 | { url = "https://files.pythonhosted.org/packages/2a/9f/8619835cd6a711d6272d62abb78c033bda638fdc54c4e7f4272cf1c0962b/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a", size = 24208 }, 395 | { url = "https://files.pythonhosted.org/packages/f9/bf/176950a1792b2cd2102b8ffeb5133e1ed984547b75db47c25a67d3359f77/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca", 
size = 23357 }, 396 | { url = "https://files.pythonhosted.org/packages/ce/4f/9a02c1d335caabe5c4efb90e1b6e8ee944aa245c1aaaab8e8a618987d816/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c", size = 23344 }, 397 | { url = "https://files.pythonhosted.org/packages/ee/55/c271b57db36f748f0e04a759ace9f8f759ccf22b4960c270c78a394f58be/MarkupSafe-3.0.2-cp313-cp313-win32.whl", hash = "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1", size = 15101 }, 398 | { url = "https://files.pythonhosted.org/packages/29/88/07df22d2dd4df40aba9f3e402e6dc1b8ee86297dddbad4872bd5e7b0094f/MarkupSafe-3.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f", size = 15603 }, 399 | { url = "https://files.pythonhosted.org/packages/62/6a/8b89d24db2d32d433dffcd6a8779159da109842434f1dd2f6e71f32f738c/MarkupSafe-3.0.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c", size = 14510 }, 400 | { url = "https://files.pythonhosted.org/packages/7a/06/a10f955f70a2e5a9bf78d11a161029d278eeacbd35ef806c3fd17b13060d/MarkupSafe-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb", size = 12486 }, 401 | { url = "https://files.pythonhosted.org/packages/34/cf/65d4a571869a1a9078198ca28f39fba5fbb910f952f9dbc5220afff9f5e6/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c", size = 25480 }, 402 | { url = "https://files.pythonhosted.org/packages/0c/e3/90e9651924c430b885468b56b3d597cabf6d72be4b24a0acd1fa0e12af67/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d", size = 23914 }, 403 | { url = 
"https://files.pythonhosted.org/packages/66/8c/6c7cf61f95d63bb866db39085150df1f2a5bd3335298f14a66b48e92659c/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe", size = 23796 }, 404 | { url = "https://files.pythonhosted.org/packages/bb/35/cbe9238ec3f47ac9a7c8b3df7a808e7cb50fe149dc7039f5f454b3fba218/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5", size = 25473 }, 405 | { url = "https://files.pythonhosted.org/packages/e6/32/7621a4382488aa283cc05e8984a9c219abad3bca087be9ec77e89939ded9/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a", size = 24114 }, 406 | { url = "https://files.pythonhosted.org/packages/0d/80/0985960e4b89922cb5a0bac0ed39c5b96cbc1a536a99f30e8c220a996ed9/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9", size = 24098 }, 407 | { url = "https://files.pythonhosted.org/packages/82/78/fedb03c7d5380df2427038ec8d973587e90561b2d90cd472ce9254cf348b/MarkupSafe-3.0.2-cp313-cp313t-win32.whl", hash = "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6", size = 15208 }, 408 | { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739 }, 409 | ] 410 | 411 | [[package]] 412 | name = "matplotlib-inline" 413 | version = "0.1.7" 414 | source = { registry = "https://pypi.org/simple" } 415 | dependencies = [ 416 | { name = "traitlets" }, 417 | ] 418 | sdist = { url = 
"https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159 } 419 | wheels = [ 420 | { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899 }, 421 | ] 422 | 423 | [[package]] 424 | name = "mpmath" 425 | version = "1.3.0" 426 | source = { registry = "https://pypi.org/simple" } 427 | sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106 } 428 | wheels = [ 429 | { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 }, 430 | ] 431 | 432 | [[package]] 433 | name = "nest-asyncio" 434 | version = "1.6.0" 435 | source = { registry = "https://pypi.org/simple" } 436 | sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418 } 437 | wheels = [ 438 | { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195 }, 439 | ] 440 | 441 | [[package]] 442 | name = "networkx" 443 | version = "3.5" 444 | source = { registry = "https://pypi.org/simple" } 445 | sdist = { url = 
"https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065 } 446 | wheels = [ 447 | { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406 }, 448 | ] 449 | 450 | [[package]] 451 | name = "ninja" 452 | version = "1.13.0" 453 | source = { registry = "https://pypi.org/simple" } 454 | sdist = { url = "https://files.pythonhosted.org/packages/43/73/79a0b22fc731989c708068427579e840a6cf4e937fe7ae5c5d0b7356ac22/ninja-1.13.0.tar.gz", hash = "sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978", size = 242558 } 455 | wheels = [ 456 | { url = "https://files.pythonhosted.org/packages/3c/74/d02409ed2aa865e051b7edda22ad416a39d81a84980f544f8de717cab133/ninja-1.13.0-py3-none-macosx_10_9_universal2.whl", hash = "sha256:fa2a8bfc62e31b08f83127d1613d10821775a0eb334197154c4d6067b7068ff1", size = 310125 }, 457 | { url = "https://files.pythonhosted.org/packages/8e/de/6e1cd6b84b412ac1ef327b76f0641aeb5dcc01e9d3f9eee0286d0c34fd93/ninja-1.13.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630", size = 177467 }, 458 | { url = "https://files.pythonhosted.org/packages/c8/83/49320fb6e58ae3c079381e333575fdbcf1cca3506ee160a2dcce775046fa/ninja-1.13.0-py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c", size = 187834 }, 459 | { url = "https://files.pythonhosted.org/packages/56/c7/ba22748fb59f7f896b609cd3e568d28a0a367a6d953c24c461fe04fc4433/ninja-1.13.0-py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = 
"sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e", size = 202736 }, 460 | { url = "https://files.pythonhosted.org/packages/79/22/d1de07632b78ac8e6b785f41fa9aad7a978ec8c0a1bf15772def36d77aac/ninja-1.13.0-py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988", size = 179034 }, 461 | { url = "https://files.pythonhosted.org/packages/ed/de/0e6edf44d6a04dabd0318a519125ed0415ce437ad5a1ec9b9be03d9048cf/ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa", size = 180716 }, 462 | { url = "https://files.pythonhosted.org/packages/54/28/938b562f9057aaa4d6bfbeaa05e81899a47aebb3ba6751e36c027a7f5ff7/ninja-1.13.0-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1", size = 146843 }, 463 | { url = "https://files.pythonhosted.org/packages/2a/fb/d06a3838de4f8ab866e44ee52a797b5491df823901c54943b2adb0389fbb/ninja-1.13.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2", size = 154402 }, 464 | { url = "https://files.pythonhosted.org/packages/31/bf/0d7808af695ceddc763cf251b84a9892cd7f51622dc8b4c89d5012779f06/ninja-1.13.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f", size = 552388 }, 465 | { url = "https://files.pythonhosted.org/packages/9d/70/c99d0c2c809f992752453cce312848abb3b1607e56d4cd1b6cded317351a/ninja-1.13.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714", size = 472501 }, 466 | { url = "https://files.pythonhosted.org/packages/9f/43/c217b1153f0e499652f5e0766da8523ce3480f0a951039c7af115e224d55/ninja-1.13.0-py3-none-musllinux_1_2_i686.whl", hash = 
"sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72", size = 638280 }, 467 | { url = "https://files.pythonhosted.org/packages/8c/45/9151bba2c8d0ae2b6260f71696330590de5850e5574b7b5694dce6023e20/ninja-1.13.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db", size = 642420 }, 468 | { url = "https://files.pythonhosted.org/packages/3c/fb/95752eb635bb8ad27d101d71bef15bc63049de23f299e312878fc21cb2da/ninja-1.13.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5", size = 585106 }, 469 | { url = "https://files.pythonhosted.org/packages/c1/31/aa56a1a286703800c0cbe39fb4e82811c277772dc8cd084f442dd8e2938a/ninja-1.13.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96", size = 707138 }, 470 | { url = "https://files.pythonhosted.org/packages/34/6f/5f5a54a1041af945130abdb2b8529cbef0cdcbbf9bcf3f4195378319d29a/ninja-1.13.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200", size = 581758 }, 471 | { url = "https://files.pythonhosted.org/packages/95/97/51359c77527d45943fe7a94d00a3843b81162e6c4244b3579fe8fc54cb9c/ninja-1.13.0-py3-none-win32.whl", hash = "sha256:8cfbb80b4a53456ae8a39f90ae3d7a2129f45ea164f43fadfa15dc38c4aef1c9", size = 267201 }, 472 | { url = "https://files.pythonhosted.org/packages/29/45/c0adfbfb0b5895aa18cec400c535b4f7ff3e52536e0403602fc1a23f7de9/ninja-1.13.0-py3-none-win_amd64.whl", hash = "sha256:fb8ee8719f8af47fed145cced4a85f0755dd55d45b2bddaf7431fa89803c5f3e", size = 309975 }, 473 | { url = "https://files.pythonhosted.org/packages/df/93/a7b983643d1253bb223234b5b226e69de6cda02b76cdca7770f684b795f5/ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9", size = 290806 }, 474 | ] 475 | 476 | [[package]] 477 | name = 
"nodeenv" 478 | version = "1.9.1" 479 | source = { registry = "https://pypi.org/simple" } 480 | sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } 481 | wheels = [ 482 | { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, 483 | ] 484 | 485 | [[package]] 486 | name = "numpy" 487 | version = "1.26.4" 488 | source = { registry = "https://pypi.org/simple" } 489 | sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129 } 490 | wheels = [ 491 | { url = "https://files.pythonhosted.org/packages/95/12/8f2020a8e8b8383ac0177dc9570aad031a3beb12e38847f7129bacd96228/numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218", size = 20335901 }, 492 | { url = "https://files.pythonhosted.org/packages/75/5b/ca6c8bd14007e5ca171c7c03102d17b4f4e0ceb53957e8c44343a9546dcc/numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b", size = 13685868 }, 493 | { url = "https://files.pythonhosted.org/packages/79/f8/97f10e6755e2a7d027ca783f63044d5b1bc1ae7acb12afe6a9b4286eac17/numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b", size = 13925109 }, 494 | { url = 
"https://files.pythonhosted.org/packages/0f/50/de23fde84e45f5c4fda2488c759b69990fd4512387a8632860f3ac9cd225/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed", size = 17950613 }, 495 | { url = "https://files.pythonhosted.org/packages/4c/0c/9c603826b6465e82591e05ca230dfc13376da512b25ccd0894709b054ed0/numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a", size = 13572172 }, 496 | { url = "https://files.pythonhosted.org/packages/76/8c/2ba3902e1a0fc1c74962ea9bb33a534bb05984ad7ff9515bf8d07527cadd/numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0", size = 17786643 }, 497 | { url = "https://files.pythonhosted.org/packages/28/4a/46d9e65106879492374999e76eb85f87b15328e06bd1550668f79f7b18c6/numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110", size = 5677803 }, 498 | { url = "https://files.pythonhosted.org/packages/16/2e/86f24451c2d530c88daf997cb8d6ac622c1d40d19f5a031ed68a4b73a374/numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818", size = 15517754 }, 499 | ] 500 | 501 | [[package]] 502 | name = "nvidia-cublas-cu12" 503 | version = "12.8.4.1" 504 | source = { registry = "https://pypi.org/simple" } 505 | wheels = [ 506 | { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921 }, 507 | ] 508 | 509 | [[package]] 510 | name = "nvidia-cuda-cupti-cu12" 511 | version = "12.8.90" 512 | source = { registry = "https://pypi.org/simple" } 513 | wheels = [ 514 | { url = 
"https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621 }, 515 | ] 516 | 517 | [[package]] 518 | name = "nvidia-cuda-nvrtc-cu12" 519 | version = "12.8.93" 520 | source = { registry = "https://pypi.org/simple" } 521 | wheels = [ 522 | { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029 }, 523 | ] 524 | 525 | [[package]] 526 | name = "nvidia-cuda-runtime-cu12" 527 | version = "12.8.90" 528 | source = { registry = "https://pypi.org/simple" } 529 | wheels = [ 530 | { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765 }, 531 | ] 532 | 533 | [[package]] 534 | name = "nvidia-cudnn-cu12" 535 | version = "9.10.2.21" 536 | source = { registry = "https://pypi.org/simple" } 537 | dependencies = [ 538 | { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, 539 | ] 540 | wheels = [ 541 | { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467 }, 542 | ] 543 | 544 | [[package]] 545 | name = "nvidia-cufft-cu12" 546 | version = 
"11.3.3.83" 547 | source = { registry = "https://pypi.org/simple" } 548 | dependencies = [ 549 | { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, 550 | ] 551 | wheels = [ 552 | { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695 }, 553 | ] 554 | 555 | [[package]] 556 | name = "nvidia-cufile-cu12" 557 | version = "1.13.1.3" 558 | source = { registry = "https://pypi.org/simple" } 559 | wheels = [ 560 | { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834 }, 561 | ] 562 | 563 | [[package]] 564 | name = "nvidia-curand-cu12" 565 | version = "10.3.9.90" 566 | source = { registry = "https://pypi.org/simple" } 567 | wheels = [ 568 | { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976 }, 569 | ] 570 | 571 | [[package]] 572 | name = "nvidia-cusolver-cu12" 573 | version = "11.7.3.90" 574 | source = { registry = "https://pypi.org/simple" } 575 | dependencies = [ 576 | { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, 577 | { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform 
!= 'darwin' and sys_platform != 'linux')" }, 578 | { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, 579 | ] 580 | wheels = [ 581 | { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905 }, 582 | ] 583 | 584 | [[package]] 585 | name = "nvidia-cusparse-cu12" 586 | version = "12.5.8.93" 587 | source = { registry = "https://pypi.org/simple" } 588 | dependencies = [ 589 | { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, 590 | ] 591 | wheels = [ 592 | { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466 }, 593 | ] 594 | 595 | [[package]] 596 | name = "nvidia-cusparselt-cu12" 597 | version = "0.7.1" 598 | source = { registry = "https://pypi.org/simple" } 599 | wheels = [ 600 | { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691 }, 601 | ] 602 | 603 | [[package]] 604 | name = "nvidia-nccl-cu12" 605 | version = "2.27.3" 606 | source = { registry = "https://pypi.org/simple" } 607 | wheels = [ 608 | { url = 
"https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134 }, 609 | ] 610 | 611 | [[package]] 612 | name = "nvidia-nvjitlink-cu12" 613 | version = "12.8.93" 614 | source = { registry = "https://pypi.org/simple" } 615 | wheels = [ 616 | { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836 }, 617 | ] 618 | 619 | [[package]] 620 | name = "nvidia-nvtx-cu12" 621 | version = "12.8.90" 622 | source = { registry = "https://pypi.org/simple" } 623 | wheels = [ 624 | { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954 }, 625 | ] 626 | 627 | [[package]] 628 | name = "opencv-python" 629 | version = "4.11.0.86" 630 | source = { registry = "https://pypi.org/simple" } 631 | dependencies = [ 632 | { name = "numpy" }, 633 | ] 634 | sdist = { url = "https://files.pythonhosted.org/packages/17/06/68c27a523103dad5837dc5b87e71285280c4f098c60e4fe8a8db6486ab09/opencv-python-4.11.0.86.tar.gz", hash = "sha256:03d60ccae62304860d232272e4a4fda93c39d595780cb40b161b310244b736a4", size = 95171956 } 635 | wheels = [ 636 | { url = "https://files.pythonhosted.org/packages/05/4d/53b30a2a3ac1f75f65a59eb29cf2ee7207ce64867db47036ad61743d5a23/opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_arm64.whl", hash = "sha256:432f67c223f1dc2824f5e73cdfcd9db0efc8710647d4e813012195dc9122a52a", size = 37326322 }, 
637 | { url = "https://files.pythonhosted.org/packages/3b/84/0a67490741867eacdfa37bc18df96e08a9d579583b419010d7f3da8ff503/opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_x86_64.whl", hash = "sha256:9d05ef13d23fe97f575153558653e2d6e87103995d54e6a35db3f282fe1f9c66", size = 56723197 }, 638 | { url = "https://files.pythonhosted.org/packages/f3/bd/29c126788da65c1fb2b5fb621b7fed0ed5f9122aa22a0868c5e2c15c6d23/opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b92ae2c8852208817e6776ba1ea0d6b1e0a1b5431e971a2a0ddd2a8cc398202", size = 42230439 }, 639 | { url = "https://files.pythonhosted.org/packages/2c/8b/90eb44a40476fa0e71e05a0283947cfd74a5d36121a11d926ad6f3193cc4/opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b02611523803495003bd87362db3e1d2a0454a6a63025dc6658a9830570aa0d", size = 62986597 }, 640 | { url = "https://files.pythonhosted.org/packages/fb/d7/1d5941a9dde095468b288d989ff6539dd69cd429dbf1b9e839013d21b6f0/opencv_python-4.11.0.86-cp37-abi3-win32.whl", hash = "sha256:810549cb2a4aedaa84ad9a1c92fbfdfc14090e2749cedf2c1589ad8359aa169b", size = 29384337 }, 641 | { url = "https://files.pythonhosted.org/packages/a4/7d/f1c30a92854540bf789e9cd5dde7ef49bbe63f855b85a2e6b3db8135c591/opencv_python-4.11.0.86-cp37-abi3-win_amd64.whl", hash = "sha256:085ad9b77c18853ea66283e98affefe2de8cc4c1f43eda4c100cf9b2721142ec", size = 39488044 }, 642 | ] 643 | 644 | [[package]] 645 | name = "packaging" 646 | version = "25.0" 647 | source = { registry = "https://pypi.org/simple" } 648 | sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727 } 649 | wheels = [ 650 | { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", 
hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469 }, 651 | ] 652 | 653 | [[package]] 654 | name = "parso" 655 | version = "0.8.5" 656 | source = { registry = "https://pypi.org/simple" } 657 | sdist = { url = "https://files.pythonhosted.org/packages/d4/de/53e0bcf53d13e005bd8c92e7855142494f41171b34c2536b86187474184d/parso-0.8.5.tar.gz", hash = "sha256:034d7354a9a018bdce352f48b2a8a450f05e9d6ee85db84764e9b6bd96dafe5a", size = 401205 } 658 | wheels = [ 659 | { url = "https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl", hash = "sha256:646204b5ee239c396d040b90f9e272e9a8017c630092bf59980beb62fd033887", size = 106668 }, 660 | ] 661 | 662 | [[package]] 663 | name = "pexpect" 664 | version = "4.9.0" 665 | source = { registry = "https://pypi.org/simple" } 666 | dependencies = [ 667 | { name = "ptyprocess" }, 668 | ] 669 | sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450 } 670 | wheels = [ 671 | { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 }, 672 | ] 673 | 674 | [[package]] 675 | name = "pillow" 676 | version = "10.3.0" 677 | source = { registry = "https://pypi.org/simple" } 678 | sdist = { url = "https://files.pythonhosted.org/packages/ef/43/c50c17c5f7d438e836c169e343695534c38c77f60e7c90389bd77981bc21/pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d", size = 46572854 } 679 | wheels = [ 680 | { url = 
"https://files.pythonhosted.org/packages/cc/5d/b7fcd38cba0f7706f64c1674fc9f018e4c64f791770598c44affadea7c2f/pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84", size = 3528535 }, 681 | { url = "https://files.pythonhosted.org/packages/5e/77/4cf407e7b033b4d8e5fcaac295b6e159cf1c70fa105d769f01ea2e1e5eca/pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19", size = 3352281 }, 682 | { url = "https://files.pythonhosted.org/packages/53/7b/4f7b153a776725a87797d744ea1c73b83ac0b723f5e379297605dee118eb/pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338", size = 4321427 }, 683 | { url = "https://files.pythonhosted.org/packages/45/08/d2cc751b790e77464f8648aa707e2327d6da5d95cf236a532e99c2e7a499/pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1", size = 4435915 }, 684 | { url = "https://files.pythonhosted.org/packages/ef/97/f69d1932cf45bf5bd9fa1e2ae57bdf716524faa4fa9fb7dc62cdb1a19113/pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462", size = 4347392 }, 685 | { url = "https://files.pythonhosted.org/packages/c6/c1/3521ddb9c1f3ac106af3e4512a98c785b6ed8a39e0f778480b8a4d340165/pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a", size = 4514536 }, 686 | { url = "https://files.pythonhosted.org/packages/c0/6f/347c241904a6514e59515284b01ba6f61765269a0d1a19fd2e6cbe331c8a/pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef", size = 4555987 }, 687 | { url = 
"https://files.pythonhosted.org/packages/c3/e2/3cc490c6b2e262713da82ce849c34bd8e6c31242afb53be8595d820b9877/pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3", size = 4623526 }, 688 | { url = "https://files.pythonhosted.org/packages/c1/b3/0209f70fa29b383e7618e47db95712a45788dea03bb960601753262a2883/pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d", size = 2217547 }, 689 | { url = "https://files.pythonhosted.org/packages/d3/23/3927d888481ff7c44fdbca3bc2a2e97588c933db46723bf115201377c436/pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b", size = 2531641 }, 690 | { url = "https://files.pythonhosted.org/packages/db/36/1ecaa0541d3a1b1362f937d386eeb1875847bfa06d5225f1b0e1588d1007/pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a", size = 2229746 }, 691 | ] 692 | 693 | [[package]] 694 | name = "platformdirs" 695 | version = "4.4.0" 696 | source = { registry = "https://pypi.org/simple" } 697 | sdist = { url = "https://files.pythonhosted.org/packages/23/e8/21db9c9987b0e728855bd57bff6984f67952bea55d6f75e055c46b5383e8/platformdirs-4.4.0.tar.gz", hash = "sha256:ca753cf4d81dc309bc67b0ea38fd15dc97bc30ce419a7f58d13eb3bf14c4febf", size = 21634 } 698 | wheels = [ 699 | { url = "https://files.pythonhosted.org/packages/40/4b/2028861e724d3bd36227adfa20d3fd24c3fc6d52032f4a93c133be5d17ce/platformdirs-4.4.0-py3-none-any.whl", hash = "sha256:abd01743f24e5287cd7a5db3752faf1a2d65353f38ec26d98e25a6db65958c85", size = 18654 }, 700 | ] 701 | 702 | [[package]] 703 | name = "plyfile" 704 | version = "1.1.2" 705 | source = { registry = "https://pypi.org/simple" } 706 | dependencies = [ 707 | { name = "numpy" }, 708 | ] 709 | sdist = { url = 
"https://files.pythonhosted.org/packages/3e/2e/212e061f2dfa6c6510d4d144819a3e23039c6d4c30731bd023c32da15d8c/plyfile-1.1.2.tar.gz", hash = "sha256:bfb2d88cb3e369cae56df8cf4d107fbfd7beef49ed6cf8ef6b2931b321af9def", size = 36078 } 710 | wheels = [ 711 | { url = "https://files.pythonhosted.org/packages/55/48/600cc57763e82a7d305f816d79fe8267c43e74487e144744349c2aa9f6fd/plyfile-1.1.2-py3-none-any.whl", hash = "sha256:f5529face448b14e1927382bb05231827268650f3411b467b17bd6438eb0b83d", size = 36357 }, 712 | ] 713 | 714 | [[package]] 715 | name = "pre-commit" 716 | version = "4.3.0" 717 | source = { registry = "https://pypi.org/simple" } 718 | dependencies = [ 719 | { name = "cfgv" }, 720 | { name = "identify" }, 721 | { name = "nodeenv" }, 722 | { name = "pyyaml" }, 723 | { name = "virtualenv" }, 724 | ] 725 | sdist = { url = "https://files.pythonhosted.org/packages/ff/29/7cf5bbc236333876e4b41f56e06857a87937ce4bf91e117a6991a2dbb02a/pre_commit-4.3.0.tar.gz", hash = "sha256:499fe450cc9d42e9d58e606262795ecb64dd05438943c62b66f6a8673da30b16", size = 193792 } 726 | wheels = [ 727 | { url = "https://files.pythonhosted.org/packages/5b/a5/987a405322d78a73b66e39e4a90e4ef156fd7141bf71df987e50717c321b/pre_commit-4.3.0-py2.py3-none-any.whl", hash = "sha256:2b0747ad7e6e967169136edffee14c16e148a778a54e4f967921aa1ebf2308d8", size = 220965 }, 728 | ] 729 | 730 | [[package]] 731 | name = "prompt-toolkit" 732 | version = "3.0.52" 733 | source = { registry = "https://pypi.org/simple" } 734 | dependencies = [ 735 | { name = "wcwidth" }, 736 | ] 737 | sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198 } 738 | wheels = [ 739 | { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = 
"sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431 }, 740 | ] 741 | 742 | [[package]] 743 | name = "psutil" 744 | version = "7.0.0" 745 | source = { registry = "https://pypi.org/simple" } 746 | sdist = { url = "https://files.pythonhosted.org/packages/2a/80/336820c1ad9286a4ded7e845b2eccfcb27851ab8ac6abece774a6ff4d3de/psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456", size = 497003 } 747 | wheels = [ 748 | { url = "https://files.pythonhosted.org/packages/ed/e6/2d26234410f8b8abdbf891c9da62bee396583f713fb9f3325a4760875d22/psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25", size = 238051 }, 749 | { url = "https://files.pythonhosted.org/packages/04/8b/30f930733afe425e3cbfc0e1468a30a18942350c1a8816acfade80c005c4/psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da", size = 239535 }, 750 | { url = "https://files.pythonhosted.org/packages/2a/ed/d362e84620dd22876b55389248e522338ed1bf134a5edd3b8231d7207f6d/psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91", size = 275004 }, 751 | { url = "https://files.pythonhosted.org/packages/bf/b9/b0eb3f3cbcb734d930fdf839431606844a825b23eaf9a6ab371edac8162c/psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34", size = 277986 }, 752 | { url = "https://files.pythonhosted.org/packages/eb/a2/709e0fe2f093556c17fbafda93ac032257242cabcc7ff3369e2cb76a97aa/psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993", size = 279544 }, 753 | { 
url = "https://files.pythonhosted.org/packages/50/e6/eecf58810b9d12e6427369784efe814a1eec0f492084ce8eb8f4d89d6d61/psutil-7.0.0-cp37-abi3-win32.whl", hash = "sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99", size = 241053 }, 754 | { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885 }, 755 | ] 756 | 757 | [[package]] 758 | name = "ptyprocess" 759 | version = "0.7.0" 760 | source = { registry = "https://pypi.org/simple" } 761 | sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762 } 762 | wheels = [ 763 | { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993 }, 764 | ] 765 | 766 | [[package]] 767 | name = "pure-eval" 768 | version = "0.2.3" 769 | source = { registry = "https://pypi.org/simple" } 770 | sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752 } 771 | wheels = [ 772 | { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842 }, 773 | ] 774 | 775 | [[package]] 776 | name = "pycparser" 777 | version = "2.22" 778 | source = { registry = "https://pypi.org/simple" } 779 | sdist = { url = 
"https://files.pythonhosted.org/packages/1d/b2/31537cf4b1ca988837256c910a668b553fceb8f069bedc4b1c826024b52c/pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", size = 172736 } 780 | wheels = [ 781 | { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 }, 782 | ] 783 | 784 | [[package]] 785 | name = "pygments" 786 | version = "2.19.2" 787 | source = { registry = "https://pypi.org/simple" } 788 | sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631 } 789 | wheels = [ 790 | { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217 }, 791 | ] 792 | 793 | [[package]] 794 | name = "python-dateutil" 795 | version = "2.9.0.post0" 796 | source = { registry = "https://pypi.org/simple" } 797 | dependencies = [ 798 | { name = "six" }, 799 | ] 800 | sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432 } 801 | wheels = [ 802 | { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, 803 | ] 804 | 805 | [[package]] 806 | name = "pywin32" 807 | version = 
"311" 808 | source = { registry = "https://pypi.org/simple" } 809 | wheels = [ 810 | { url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543 }, 811 | { url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040 }, 812 | { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102 }, 813 | { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700 }, 814 | { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700 }, 815 | { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318 }, 816 | { url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714 }, 817 | { url = 
"https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800 }, 818 | { url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540 }, 819 | ] 820 | 821 | [[package]] 822 | name = "pyyaml" 823 | version = "6.0.2" 824 | source = { registry = "https://pypi.org/simple" } 825 | sdist = { url = "https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631 } 826 | wheels = [ 827 | { url = "https://files.pythonhosted.org/packages/86/0c/c581167fc46d6d6d7ddcfb8c843a4de25bdd27e4466938109ca68492292c/PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", size = 183873 }, 828 | { url = "https://files.pythonhosted.org/packages/a8/0c/38374f5bb272c051e2a69281d71cba6fdb983413e6758b84482905e29a5d/PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", size = 173302 }, 829 | { url = "https://files.pythonhosted.org/packages/c3/93/9916574aa8c00aa06bbac729972eb1071d002b8e158bd0e83a3b9a20a1f7/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", size = 739154 }, 830 | { url = "https://files.pythonhosted.org/packages/95/0f/b8938f1cbd09739c6da569d172531567dbcc9789e0029aa070856f123984/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", size = 766223 }, 831 | { url = "https://files.pythonhosted.org/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", size = 767542 }, 832 | { url = "https://files.pythonhosted.org/packages/d4/00/dd137d5bcc7efea1836d6264f049359861cf548469d18da90cd8216cf05f/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", size = 731164 }, 833 | { url = "https://files.pythonhosted.org/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611 }, 834 | { url = "https://files.pythonhosted.org/packages/df/d1/f5a275fdb252768b7a11ec63585bc38d0e87c9e05668a139fea92b80634c/PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", size = 140591 }, 835 | { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338 }, 836 | { url = "https://files.pythonhosted.org/packages/ef/e3/3af305b830494fa85d95f6d95ef7fa73f2ee1cc8ef5b495c7c3269fb835f/PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", size = 181309 }, 837 | { url = "https://files.pythonhosted.org/packages/45/9f/3b1c20a0b7a3200524eb0076cc027a970d320bd3a6592873c85c92a08731/PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", size = 171679 }, 838 
| { url = "https://files.pythonhosted.org/packages/7c/9a/337322f27005c33bcb656c655fa78325b730324c78620e8328ae28b64d0c/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", size = 733428 }, 839 | { url = "https://files.pythonhosted.org/packages/a3/69/864fbe19e6c18ea3cc196cbe5d392175b4cf3d5d0ac1403ec3f2d237ebb5/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", size = 763361 }, 840 | { url = "https://files.pythonhosted.org/packages/04/24/b7721e4845c2f162d26f50521b825fb061bc0a5afcf9a386840f23ea19fa/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", size = 759523 }, 841 | { url = "https://files.pythonhosted.org/packages/2b/b2/e3234f59ba06559c6ff63c4e10baea10e5e7df868092bf9ab40e5b9c56b6/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", size = 726660 }, 842 | { url = "https://files.pythonhosted.org/packages/fe/0f/25911a9f080464c59fab9027482f822b86bf0608957a5fcc6eaac85aa515/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", size = 751597 }, 843 | { url = "https://files.pythonhosted.org/packages/14/0d/e2c3b43bbce3cf6bd97c840b46088a3031085179e596d4929729d8d68270/PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", size = 140527 }, 844 | { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 }, 845 | ] 846 | 847 | [[package]] 848 | name = 
"pyzmq" 849 | version = "27.0.2" 850 | source = { registry = "https://pypi.org/simple" } 851 | dependencies = [ 852 | { name = "cffi", marker = "implementation_name == 'pypy'" }, 853 | ] 854 | sdist = { url = "https://files.pythonhosted.org/packages/f8/66/159f38d184f08b5f971b467f87b1ab142ab1320d5200825c824b32b84b66/pyzmq-27.0.2.tar.gz", hash = "sha256:b398dd713b18de89730447347e96a0240225e154db56e35b6bb8447ffdb07798", size = 281440 } 855 | wheels = [ 856 | { url = "https://files.pythonhosted.org/packages/68/69/b3a729e7b03e412bee2b1823ab8d22e20a92593634f664afd04c6c9d9ac0/pyzmq-27.0.2-cp312-abi3-macosx_10_15_universal2.whl", hash = "sha256:5da05e3c22c95e23bfc4afeee6ff7d4be9ff2233ad6cb171a0e8257cd46b169a", size = 1305910 }, 857 | { url = "https://files.pythonhosted.org/packages/15/b7/f6a6a285193d489b223c340b38ee03a673467cb54914da21c3d7849f1b10/pyzmq-27.0.2-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:4e4520577971d01d47e2559bb3175fce1be9103b18621bf0b241abe0a933d040", size = 895507 }, 858 | { url = "https://files.pythonhosted.org/packages/17/e6/c4ed2da5ef9182cde1b1f5d0051a986e76339d71720ec1a00be0b49275ad/pyzmq-27.0.2-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:56d7de7bf73165b90bd25a8668659ccb134dd28449116bf3c7e9bab5cf8a8ec9", size = 652670 }, 859 | { url = "https://files.pythonhosted.org/packages/0e/66/d781ab0636570d32c745c4e389b1c6b713115905cca69ab6233508622edd/pyzmq-27.0.2-cp312-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:340e7cddc32f147c6c00d116a3f284ab07ee63dbd26c52be13b590520434533c", size = 840581 }, 860 | { url = "https://files.pythonhosted.org/packages/a6/df/f24790caf565d72544f5c8d8500960b9562c1dc848d6f22f3c7e122e73d4/pyzmq-27.0.2-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ba95693f9df8bb4a9826464fb0fe89033936f35fd4a8ff1edff09a473570afa0", size = 1641931 }, 861 | { url = 
"https://files.pythonhosted.org/packages/65/65/77d27b19fc5e845367f9100db90b9fce924f611b14770db480615944c9c9/pyzmq-27.0.2-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:ca42a6ce2d697537da34f77a1960d21476c6a4af3e539eddb2b114c3cf65a78c", size = 2021226 }, 862 | { url = "https://files.pythonhosted.org/packages/5b/65/1ed14421ba27a4207fa694772003a311d1142b7f543179e4d1099b7eb746/pyzmq-27.0.2-cp312-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3e44e665d78a07214b2772ccbd4b9bcc6d848d7895f1b2d7653f047b6318a4f6", size = 1878047 }, 863 | { url = "https://files.pythonhosted.org/packages/dd/dc/e578549b89b40dc78a387ec471c2a360766690c0a045cd8d1877d401012d/pyzmq-27.0.2-cp312-abi3-win32.whl", hash = "sha256:272d772d116615397d2be2b1417b3b8c8bc8671f93728c2f2c25002a4530e8f6", size = 558757 }, 864 | { url = "https://files.pythonhosted.org/packages/b5/89/06600980aefcc535c758414da969f37a5194ea4cdb73b745223f6af3acfb/pyzmq-27.0.2-cp312-abi3-win_amd64.whl", hash = "sha256:734be4f44efba0aa69bf5f015ed13eb69ff29bf0d17ea1e21588b095a3147b8e", size = 619281 }, 865 | { url = "https://files.pythonhosted.org/packages/30/84/df8a5c089552d17c9941d1aea4314b606edf1b1622361dae89aacedc6467/pyzmq-27.0.2-cp312-abi3-win_arm64.whl", hash = "sha256:41f0bd56d9279392810950feb2785a419c2920bbf007fdaaa7f4a07332ae492d", size = 552680 }, 866 | { url = "https://files.pythonhosted.org/packages/b4/7b/b79e976508517ab80dc800f7021ef1fb602a6d55e4caa2d47fb3dca5d8b6/pyzmq-27.0.2-cp313-cp313-android_24_arm64_v8a.whl", hash = "sha256:7f01118133427cd7f34ee133b5098e2af5f70303fa7519785c007bca5aa6f96a", size = 1122259 }, 867 | { url = "https://files.pythonhosted.org/packages/2b/1c/777217b9940ebcb7e71c924184ca5f31e410580a58d9fd93798589f0d31c/pyzmq-27.0.2-cp313-cp313-android_24_x86_64.whl", hash = "sha256:e4b860edf6379a7234ccbb19b4ed2c57e3ff569c3414fadfb49ae72b61a8ef07", size = 1156113 }, 868 | { url = 
"https://files.pythonhosted.org/packages/59/7d/654657a4c6435f41538182e71b61eac386a789a2bbb6f30171915253a9a7/pyzmq-27.0.2-cp313-cp313t-macosx_10_15_universal2.whl", hash = "sha256:cb77923ea163156da14295c941930bd525df0d29c96c1ec2fe3c3806b1e17cb3", size = 1341437 }, 869 | { url = "https://files.pythonhosted.org/packages/20/a0/5ed7710037f9c096017adc748bcb1698674a2d297f8b9422d38816f7b56a/pyzmq-27.0.2-cp313-cp313t-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:61678b7407b04df8f9423f188156355dc94d0fb52d360ae79d02ed7e0d431eea", size = 897888 }, 870 | { url = "https://files.pythonhosted.org/packages/2c/8a/6e4699a60931c17e7406641d201d7f2c121e2a38979bc83226a6d8f1ba32/pyzmq-27.0.2-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e3c824b70925963bdc8e39a642672c15ffaa67e7d4b491f64662dd56d6271263", size = 660727 }, 871 | { url = "https://files.pythonhosted.org/packages/7b/d8/d761e438c186451bd89ce63a665cde5690c084b61cd8f5d7b51e966e875a/pyzmq-27.0.2-cp313-cp313t-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c4833e02fcf2751975457be1dfa2f744d4d09901a8cc106acaa519d868232175", size = 848136 }, 872 | { url = "https://files.pythonhosted.org/packages/43/f1/a0f31684efdf3eb92f46b7dd2117e752208115e89d278f8ca5f413c5bb85/pyzmq-27.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b18045668d09cf0faa44918af2a67f0dbbef738c96f61c2f1b975b1ddb92ccfc", size = 1650402 }, 873 | { url = "https://files.pythonhosted.org/packages/41/fd/0d7f2a1732812df02c85002770da4a7864c79b210084bcdab01ea57e8d92/pyzmq-27.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:bbbb7e2f3ac5a22901324e7b086f398b8e16d343879a77b15ca3312e8cd8e6d5", size = 2024587 }, 874 | { url = "https://files.pythonhosted.org/packages/f1/73/358be69e279a382dd09e46dda29df8446365cddee4f79ef214e71e5b2b5a/pyzmq-27.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:b751914a73604d40d88a061bab042a11d4511b3ddbb7624cd83c39c8a498564c", size = 1885493 }, 875 | { url = 
"https://files.pythonhosted.org/packages/c5/7b/e9951ad53b3dfed8cfb4c2cfd6e0097c9b454e5c0d0e6df5f2b60d7c8c3d/pyzmq-27.0.2-cp313-cp313t-win32.whl", hash = "sha256:3e8f833dd82af11db5321c414638045c70f61009f72dd61c88db4a713c1fb1d2", size = 574934 }, 876 | { url = "https://files.pythonhosted.org/packages/55/33/1a7fc3a92f2124a63e6e2a6afa0af471a5c0c713e776b476d4eda5111b13/pyzmq-27.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:5b45153cb8eadcab14139970643a84f7a7b08dda541fbc1f6f4855c49334b549", size = 640932 }, 877 | { url = "https://files.pythonhosted.org/packages/2a/52/2598a94ac251a7c83f3887866225eea1952b0d4463a68df5032eb00ff052/pyzmq-27.0.2-cp313-cp313t-win_arm64.whl", hash = "sha256:86898f5c9730df23427c1ee0097d8aa41aa5f89539a79e48cd0d2c22d059f1b7", size = 561315 }, 878 | { url = "https://files.pythonhosted.org/packages/42/7d/10ef02ea36590b29d48ef88eb0831f0af3eb240cccca2752556faec55f59/pyzmq-27.0.2-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:d2b4b261dce10762be5c116b6ad1f267a9429765b493c454f049f33791dd8b8a", size = 1341463 }, 879 | { url = "https://files.pythonhosted.org/packages/94/36/115d18dade9a3d4d3d08dd8bfe5459561b8e02815f99df040555fdd7768e/pyzmq-27.0.2-cp314-cp314t-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:4e4d88b6cff156fed468903006b24bbd85322612f9c2f7b96e72d5016fd3f543", size = 897840 }, 880 | { url = "https://files.pythonhosted.org/packages/39/66/083b37839b95c386a95f1537bb41bdbf0c002b7c55b75ee737949cecb11f/pyzmq-27.0.2-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8426c0ebbc11ed8416a6e9409c194142d677c2c5c688595f2743664e356d9e9b", size = 660704 }, 881 | { url = "https://files.pythonhosted.org/packages/76/5a/196ab46e549ba35bf3268f575e10cfac0dc86b78dcaa7a3e36407ecda752/pyzmq-27.0.2-cp314-cp314t-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:565bee96a155fe6452caed5fb5f60c9862038e6b51a59f4f632562081cdb4004", size = 848037 }, 882 | { url = 
"https://files.pythonhosted.org/packages/70/ea/a27b9eb44b2e615a9ecb8510ebb023cc1d2d251181e4a1e50366bfbf94d6/pyzmq-27.0.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5de735c745ca5cefe9c2d1547d8f28cfe1b1926aecb7483ab1102fd0a746c093", size = 1650278 }, 883 | { url = "https://files.pythonhosted.org/packages/62/ac/3e9af036bfaf718ab5e69ded8f6332da392c5450ad43e8e3ca66797f145a/pyzmq-27.0.2-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:ea4f498f8115fd90d7bf03a3e83ae3e9898e43362f8e8e8faec93597206e15cc", size = 2024504 }, 884 | { url = "https://files.pythonhosted.org/packages/ae/e9/3202d31788df8ebaa176b23d846335eb9c768d8b43c0506bbd6265ad36a0/pyzmq-27.0.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d00e81cb0afd672915257a3927124ee2ad117ace3c256d39cd97ca3f190152ad", size = 1885381 }, 885 | { url = "https://files.pythonhosted.org/packages/4b/ed/42de80b7ab4e8fcf13376f81206cf8041740672ac1fd2e1c598d63f595bf/pyzmq-27.0.2-cp314-cp314t-win32.whl", hash = "sha256:0f6e9b00d81b58f859fffc112365d50413954e02aefe36c5b4c8fb4af79f8cc3", size = 587526 }, 886 | { url = "https://files.pythonhosted.org/packages/ed/c8/8f3c72d6f0bfbf090aa5e283576073ca5c59839b85a5cc8c66ddb9b59801/pyzmq-27.0.2-cp314-cp314t-win_amd64.whl", hash = "sha256:2e73cf3b127a437fef4100eb3ac2ebe6b49e655bb721329f667f59eca0a26221", size = 661368 }, 887 | { url = "https://files.pythonhosted.org/packages/69/a4/7ee652ea1c77d872f5d99ed937fa8bbd1f6f4b7a39a6d3a0076c286e0c3e/pyzmq-27.0.2-cp314-cp314t-win_arm64.whl", hash = "sha256:4108785f2e5ac865d06f678a07a1901e3465611356df21a545eeea8b45f56265", size = 574901 }, 888 | ] 889 | 890 | [[package]] 891 | name = "requests" 892 | version = "2.32.5" 893 | source = { registry = "https://pypi.org/simple" } 894 | dependencies = [ 895 | { name = "certifi" }, 896 | { name = "charset-normalizer" }, 897 | { name = "idna" }, 898 | { name = "urllib3" }, 899 | ] 900 | sdist = { url = 
"https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517 } 901 | wheels = [ 902 | { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738 }, 903 | ] 904 | 905 | [[package]] 906 | name = "ruff" 907 | version = "0.12.11" 908 | source = { registry = "https://pypi.org/simple" } 909 | sdist = { url = "https://files.pythonhosted.org/packages/de/55/16ab6a7d88d93001e1ae4c34cbdcfb376652d761799459ff27c1dc20f6fa/ruff-0.12.11.tar.gz", hash = "sha256:c6b09ae8426a65bbee5425b9d0b82796dbb07cb1af045743c79bfb163001165d", size = 5347103 } 910 | wheels = [ 911 | { url = "https://files.pythonhosted.org/packages/d6/a2/3b3573e474de39a7a475f3fbaf36a25600bfeb238e1a90392799163b64a0/ruff-0.12.11-py3-none-linux_armv6l.whl", hash = "sha256:93fce71e1cac3a8bf9200e63a38ac5c078f3b6baebffb74ba5274fb2ab276065", size = 11979885 }, 912 | { url = "https://files.pythonhosted.org/packages/76/e4/235ad6d1785a2012d3ded2350fd9bc5c5af8c6f56820e696b0118dfe7d24/ruff-0.12.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8e33ac7b28c772440afa80cebb972ffd823621ded90404f29e5ab6d1e2d4b93", size = 12742364 }, 913 | { url = "https://files.pythonhosted.org/packages/2c/0d/15b72c5fe6b1e402a543aa9d8960e0a7e19dfb079f5b0b424db48b7febab/ruff-0.12.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d69fb9d4937aa19adb2e9f058bc4fbfe986c2040acb1a4a9747734834eaa0bfd", size = 11920111 }, 914 | { url = "https://files.pythonhosted.org/packages/3e/c0/f66339d7893798ad3e17fa5a1e587d6fd9806f7c1c062b63f8b09dda6702/ruff-0.12.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:411954eca8464595077a93e580e2918d0a01a19317af0a72132283e28ae21bee", size = 
12160060 }, 915 | { url = "https://files.pythonhosted.org/packages/03/69/9870368326db26f20c946205fb2d0008988aea552dbaec35fbacbb46efaa/ruff-0.12.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a2c0a2e1a450f387bf2c6237c727dd22191ae8c00e448e0672d624b2bbd7fb0", size = 11799848 }, 916 | { url = "https://files.pythonhosted.org/packages/25/8c/dd2c7f990e9b3a8a55eee09d4e675027d31727ce33cdb29eab32d025bdc9/ruff-0.12.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ca4c3a7f937725fd2413c0e884b5248a19369ab9bdd850b5781348ba283f644", size = 13536288 }, 917 | { url = "https://files.pythonhosted.org/packages/7a/30/d5496fa09aba59b5e01ea76775a4c8897b13055884f56f1c35a4194c2297/ruff-0.12.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4d1df0098124006f6a66ecf3581a7f7e754c4df7644b2e6704cd7ca80ff95211", size = 14490633 }, 918 | { url = "https://files.pythonhosted.org/packages/9b/2f/81f998180ad53445d403c386549d6946d0748e536d58fce5b5e173511183/ruff-0.12.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a8dd5f230efc99a24ace3b77e3555d3fbc0343aeed3fc84c8d89e75ab2ff793", size = 13888430 }, 919 | { url = "https://files.pythonhosted.org/packages/87/71/23a0d1d5892a377478c61dbbcffe82a3476b050f38b5162171942a029ef3/ruff-0.12.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc75533039d0ed04cd33fb8ca9ac9620b99672fe7ff1533b6402206901c34ee", size = 12913133 }, 920 | { url = "https://files.pythonhosted.org/packages/80/22/3c6cef96627f89b344c933781ed38329bfb87737aa438f15da95907cbfd5/ruff-0.12.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fc58f9266d62c6eccc75261a665f26b4ef64840887fc6cbc552ce5b29f96cc8", size = 13169082 }, 921 | { url = "https://files.pythonhosted.org/packages/05/b5/68b3ff96160d8b49e8dd10785ff3186be18fd650d356036a3770386e6c7f/ruff-0.12.11-py3-none-manylinux_2_31_riscv64.whl", hash = 
"sha256:5a0113bd6eafd545146440225fe60b4e9489f59eb5f5f107acd715ba5f0b3d2f", size = 13139490 }, 922 | { url = "https://files.pythonhosted.org/packages/59/b9/050a3278ecd558f74f7ee016fbdf10591d50119df8d5f5da45a22c6afafc/ruff-0.12.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0d737b4059d66295c3ea5720e6efc152623bb83fde5444209b69cd33a53e2000", size = 11958928 }, 923 | { url = "https://files.pythonhosted.org/packages/f9/bc/93be37347db854806904a43b0493af8d6873472dfb4b4b8cbb27786eb651/ruff-0.12.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:916fc5defee32dbc1fc1650b576a8fed68f5e8256e2180d4d9855aea43d6aab2", size = 11764513 }, 924 | { url = "https://files.pythonhosted.org/packages/7a/a1/1471751e2015a81fd8e166cd311456c11df74c7e8769d4aabfbc7584c7ac/ruff-0.12.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c984f07d7adb42d3ded5be894fb4007f30f82c87559438b4879fe7aa08c62b39", size = 12745154 }, 925 | { url = "https://files.pythonhosted.org/packages/68/ab/2542b14890d0f4872dd81b7b2a6aed3ac1786fae1ce9b17e11e6df9e31e3/ruff-0.12.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e07fbb89f2e9249f219d88331c833860489b49cdf4b032b8e4432e9b13e8a4b9", size = 13227653 }, 926 | { url = "https://files.pythonhosted.org/packages/22/16/2fbfc61047dbfd009c58a28369a693a1484ad15441723be1cd7fe69bb679/ruff-0.12.11-py3-none-win32.whl", hash = "sha256:c792e8f597c9c756e9bcd4d87cf407a00b60af77078c96f7b6366ea2ce9ba9d3", size = 11944270 }, 927 | { url = "https://files.pythonhosted.org/packages/08/a5/34276984705bfe069cd383101c45077ee029c3fe3b28225bf67aa35f0647/ruff-0.12.11-py3-none-win_amd64.whl", hash = "sha256:a3283325960307915b6deb3576b96919ee89432ebd9c48771ca12ee8afe4a0fd", size = 13046600 }, 928 | { url = "https://files.pythonhosted.org/packages/84/a8/001d4a7c2b37623a3fd7463208267fb906df40ff31db496157549cfd6e72/ruff-0.12.11-py3-none-win_arm64.whl", hash = "sha256:bae4d6e6a2676f8fb0f98b74594a048bae1b944aab17e9f5d504062303c6dbea", size = 12135290 }, 929 | ] 930 | 931 | [[package]] 932 | 
name = "safetensors" 933 | version = "0.6.2" 934 | source = { registry = "https://pypi.org/simple" } 935 | sdist = { url = "https://files.pythonhosted.org/packages/ac/cc/738f3011628920e027a11754d9cae9abec1aed00f7ae860abbf843755233/safetensors-0.6.2.tar.gz", hash = "sha256:43ff2aa0e6fa2dc3ea5524ac7ad93a9839256b8703761e76e2d0b2a3fa4f15d9", size = 197968 } 936 | wheels = [ 937 | { url = "https://files.pythonhosted.org/packages/4d/b1/3f5fd73c039fc87dba3ff8b5d528bfc5a32b597fea8e7a6a4800343a17c7/safetensors-0.6.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:9c85ede8ec58f120bad982ec47746981e210492a6db876882aa021446af8ffba", size = 454797 }, 938 | { url = "https://files.pythonhosted.org/packages/8c/c9/bb114c158540ee17907ec470d01980957fdaf87b4aa07914c24eba87b9c6/safetensors-0.6.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d6675cf4b39c98dbd7d940598028f3742e0375a6b4d4277e76beb0c35f4b843b", size = 432206 }, 939 | { url = "https://files.pythonhosted.org/packages/d3/8e/f70c34e47df3110e8e0bb268d90db8d4be8958a54ab0336c9be4fe86dac8/safetensors-0.6.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d2d2b3ce1e2509c68932ca03ab8f20570920cd9754b05063d4368ee52833ecd", size = 473261 }, 940 | { url = "https://files.pythonhosted.org/packages/2a/f5/be9c6a7c7ef773e1996dc214e73485286df1836dbd063e8085ee1976f9cb/safetensors-0.6.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:93de35a18f46b0f5a6a1f9e26d91b442094f2df02e9fd7acf224cfec4238821a", size = 485117 }, 941 | { url = "https://files.pythonhosted.org/packages/c9/55/23f2d0a2c96ed8665bf17a30ab4ce5270413f4d74b6d87dd663258b9af31/safetensors-0.6.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:89a89b505f335640f9120fac65ddeb83e40f1fd081cb8ed88b505bdccec8d0a1", size = 616154 }, 942 | { url = 
"https://files.pythonhosted.org/packages/98/c6/affb0bd9ce02aa46e7acddbe087912a04d953d7a4d74b708c91b5806ef3f/safetensors-0.6.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fc4d0d0b937e04bdf2ae6f70cd3ad51328635fe0e6214aa1fc811f3b576b3bda", size = 520713 }, 943 | { url = "https://files.pythonhosted.org/packages/fe/5d/5a514d7b88e310c8b146e2404e0dc161282e78634d9358975fd56dfd14be/safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8045db2c872db8f4cbe3faa0495932d89c38c899c603f21e9b6486951a5ecb8f", size = 485835 }, 944 | { url = "https://files.pythonhosted.org/packages/7a/7b/4fc3b2ba62c352b2071bea9cfbad330fadda70579f617506ae1a2f129cab/safetensors-0.6.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:81e67e8bab9878bb568cffbc5f5e655adb38d2418351dc0859ccac158f753e19", size = 521503 }, 945 | { url = "https://files.pythonhosted.org/packages/5a/50/0057e11fe1f3cead9254315a6c106a16dd4b1a19cd247f7cc6414f6b7866/safetensors-0.6.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0e4d029ab0a0e0e4fdf142b194514695b1d7d3735503ba700cf36d0fc7136ce", size = 652256 }, 946 | { url = "https://files.pythonhosted.org/packages/e9/29/473f789e4ac242593ac1656fbece6e1ecd860bb289e635e963667807afe3/safetensors-0.6.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:fa48268185c52bfe8771e46325a1e21d317207bcabcb72e65c6e28e9ffeb29c7", size = 747281 }, 947 | { url = "https://files.pythonhosted.org/packages/68/52/f7324aad7f2df99e05525c84d352dc217e0fa637a4f603e9f2eedfbe2c67/safetensors-0.6.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:d83c20c12c2d2f465997c51b7ecb00e407e5f94d7dec3ea0cc11d86f60d3fde5", size = 692286 }, 948 | { url = "https://files.pythonhosted.org/packages/ad/fe/cad1d9762868c7c5dc70c8620074df28ebb1a8e4c17d4c0cb031889c457e/safetensors-0.6.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d944cea65fad0ead848b6ec2c37cc0b197194bec228f8020054742190e9312ac", size = 655957 }, 949 | { url = 
"https://files.pythonhosted.org/packages/59/a7/e2158e17bbe57d104f0abbd95dff60dda916cf277c9f9663b4bf9bad8b6e/safetensors-0.6.2-cp38-abi3-win32.whl", hash = "sha256:cab75ca7c064d3911411461151cb69380c9225798a20e712b102edda2542ddb1", size = 308926 }, 950 | { url = "https://files.pythonhosted.org/packages/2c/c3/c0be1135726618dc1e28d181b8c442403d8dbb9e273fd791de2d4384bcdd/safetensors-0.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:c7b214870df923cbc1593c3faee16bec59ea462758699bd3fee399d00aac072c", size = 320192 }, 951 | ] 952 | 953 | [[package]] 954 | name = "setuptools" 955 | version = "80.9.0" 956 | source = { registry = "https://pypi.org/simple" } 957 | sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958 } 958 | wheels = [ 959 | { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486 }, 960 | ] 961 | 962 | [[package]] 963 | name = "six" 964 | version = "1.17.0" 965 | source = { registry = "https://pypi.org/simple" } 966 | sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031 } 967 | wheels = [ 968 | { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, 969 | ] 970 | 971 | [[package]] 972 | name = "sparse-vggt" 973 | version = "0.1.0" 974 | source = { editable = "." 
} 975 | dependencies = [ 976 | { name = "einops" }, 977 | { name = "huggingface-hub" }, 978 | { name = "ipykernel" }, 979 | { name = "ninja" }, 980 | { name = "numpy" }, 981 | { name = "pillow" }, 982 | { name = "plyfile" }, 983 | { name = "safetensors" }, 984 | { name = "torch" }, 985 | { name = "torchvision" }, 986 | { name = "vggt" }, 987 | ] 988 | 989 | [package.dev-dependencies] 990 | dev = [ 991 | { name = "pre-commit" }, 992 | { name = "ruff" }, 993 | ] 994 | 995 | [package.metadata] 996 | requires-dist = [ 997 | { name = "einops", specifier = ">=0.8.1" }, 998 | { name = "huggingface-hub", specifier = ">=0.34.4" }, 999 | { name = "ipykernel", specifier = ">=6.30.1" }, 1000 | { name = "ninja", specifier = ">=1.13.0" }, 1001 | { name = "numpy", specifier = "==1.26.4" }, 1002 | { name = "pillow", specifier = ">=10.3.0" }, 1003 | { name = "plyfile", specifier = ">=1.1.2" }, 1004 | { name = "safetensors", specifier = ">=0.6.2" }, 1005 | { name = "torch", specifier = ">=2.8.0" }, 1006 | { name = "torchvision", specifier = ">=0.23.0" }, 1007 | { name = "vggt", editable = "external/vggt" }, 1008 | ] 1009 | 1010 | [package.metadata.requires-dev] 1011 | dev = [ 1012 | { name = "pre-commit" }, 1013 | { name = "ruff" }, 1014 | ] 1015 | 1016 | [[package]] 1017 | name = "stack-data" 1018 | version = "0.6.3" 1019 | source = { registry = "https://pypi.org/simple" } 1020 | dependencies = [ 1021 | { name = "asttokens" }, 1022 | { name = "executing" }, 1023 | { name = "pure-eval" }, 1024 | ] 1025 | sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707 } 1026 | wheels = [ 1027 | { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = 
"sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, 1028 | ] 1029 | 1030 | [[package]] 1031 | name = "sympy" 1032 | version = "1.14.0" 1033 | source = { registry = "https://pypi.org/simple" } 1034 | dependencies = [ 1035 | { name = "mpmath" }, 1036 | ] 1037 | sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921 } 1038 | wheels = [ 1039 | { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353 }, 1040 | ] 1041 | 1042 | [[package]] 1043 | name = "torch" 1044 | version = "2.8.0" 1045 | source = { registry = "https://pypi.org/simple" } 1046 | dependencies = [ 1047 | { name = "filelock" }, 1048 | { name = "fsspec" }, 1049 | { name = "jinja2" }, 1050 | { name = "networkx" }, 1051 | { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1052 | { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1053 | { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1054 | { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1055 | { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1056 | { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1057 | { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1058 | { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1059 | { name = 
"nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1060 | { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1061 | { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1062 | { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1063 | { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1064 | { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1065 | { name = "setuptools" }, 1066 | { name = "sympy" }, 1067 | { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 1068 | { name = "typing-extensions" }, 1069 | ] 1070 | wheels = [ 1071 | { url = "https://files.pythonhosted.org/packages/49/0c/2fd4df0d83a495bb5e54dca4474c4ec5f9c62db185421563deeb5dabf609/torch-2.8.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e2fab4153768d433f8ed9279c8133a114a034a61e77a3a104dcdf54388838705", size = 101906089 }, 1072 | { url = "https://files.pythonhosted.org/packages/99/a8/6acf48d48838fb8fe480597d98a0668c2beb02ee4755cc136de92a0a956f/torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2aca0939fb7e4d842561febbd4ffda67a8e958ff725c1c27e244e85e982173c", size = 887913624 }, 1073 | { url = "https://files.pythonhosted.org/packages/af/8a/5c87f08e3abd825c7dfecef5a0f1d9aa5df5dd0e3fd1fa2f490a8e512402/torch-2.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:2f4ac52f0130275d7517b03a33d2493bab3693c83dcfadf4f81688ea82147d2e", size = 241326087 }, 1074 | { url = "https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478 }, 1075 | { url = 
"https://files.pythonhosted.org/packages/10/4e/469ced5a0603245d6a19a556e9053300033f9c5baccf43a3d25ba73e189e/torch-2.8.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b2f96814e0345f5a5aed9bf9734efa913678ed19caf6dc2cddb7930672d6128", size = 101936856 }, 1076 | { url = "https://files.pythonhosted.org/packages/16/82/3948e54c01b2109238357c6f86242e6ecbf0c63a1af46906772902f82057/torch-2.8.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:65616ca8ec6f43245e1f5f296603e33923f4c30f93d65e103d9e50c25b35150b", size = 887922844 }, 1077 | { url = "https://files.pythonhosted.org/packages/e3/54/941ea0a860f2717d86a811adf0c2cd01b3983bdd460d0803053c4e0b8649/torch-2.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:659df54119ae03e83a800addc125856effda88b016dfc54d9f65215c3975be16", size = 241330968 }, 1078 | { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128 }, 1079 | { url = "https://files.pythonhosted.org/packages/15/0e/8a800e093b7f7430dbaefa80075aee9158ec22e4c4fc3c1a66e4fb96cb4f/torch-2.8.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:83c13411a26fac3d101fe8035a6b0476ae606deb8688e904e796a3534c197def", size = 102020139 }, 1080 | { url = "https://files.pythonhosted.org/packages/4a/15/5e488ca0bc6162c86a33b58642bc577c84ded17c7b72d97e49b5833e2d73/torch-2.8.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8f0a9d617a66509ded240add3754e462430a6c1fc5589f86c17b433dd808f97a", size = 887990692 }, 1081 | { url = "https://files.pythonhosted.org/packages/b4/a8/6a04e4b54472fc5dba7ca2341ab219e529f3c07b6941059fbf18dccac31f/torch-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a7242b86f42be98ac674b88a4988643b9bc6145437ec8f048fea23f72feb5eca", size = 241603453 }, 1082 | { url = 
"https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395 }, 1083 | ] 1084 | 1085 | [[package]] 1086 | name = "torchvision" 1087 | version = "0.23.0" 1088 | source = { registry = "https://pypi.org/simple" } 1089 | dependencies = [ 1090 | { name = "numpy" }, 1091 | { name = "pillow" }, 1092 | { name = "torch" }, 1093 | ] 1094 | wheels = [ 1095 | { url = "https://files.pythonhosted.org/packages/df/1d/0ea0b34bde92a86d42620f29baa6dcbb5c2fc85990316df5cb8f7abb8ea2/torchvision-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0e2c04a91403e8dd3af9756c6a024a1d9c0ed9c0d592a8314ded8f4fe30d440", size = 1856885 }, 1096 | { url = "https://files.pythonhosted.org/packages/e2/00/2f6454decc0cd67158c7890364e446aad4b91797087a57a78e72e1a8f8bc/torchvision-0.23.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6dd7c4d329a0e03157803031bc856220c6155ef08c26d4f5bbac938acecf0948", size = 2396614 }, 1097 | { url = "https://files.pythonhosted.org/packages/e4/b5/3e580dcbc16f39a324f3dd71b90edbf02a42548ad44d2b4893cc92b1194b/torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4e7d31c43bc7cbecbb1a5652ac0106b436aa66e26437585fc2c4b2cf04d6014c", size = 8627108 }, 1098 | { url = "https://files.pythonhosted.org/packages/82/c1/c2fe6d61e110a8d0de2f94276899a2324a8f1e6aee559eb6b4629ab27466/torchvision-0.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:a2e45272abe7b8bf0d06c405e78521b5757be1bd0ed7e5cd78120f7fdd4cbf35", size = 1600723 }, 1099 | { url = "https://files.pythonhosted.org/packages/91/37/45a5b9407a7900f71d61b2b2f62db4b7c632debca397f205fdcacb502780/torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1c37e325e09a184b730c3ef51424f383ec5745378dc0eca244520aca29722600", size = 1856886 }, 1100 | { url = 
"https://files.pythonhosted.org/packages/ac/da/a06c60fc84fc849377cf035d3b3e9a1c896d52dbad493b963c0f1cdd74d0/torchvision-0.23.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2f7fd6c15f3697e80627b77934f77705f3bc0e98278b989b2655de01f6903e1d", size = 2353112 }, 1101 | { url = "https://files.pythonhosted.org/packages/a0/27/5ce65ba5c9d3b7d2ccdd79892ab86a2f87ac2ca6638f04bb0280321f1a9c/torchvision-0.23.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a76fafe113b2977be3a21bf78f115438c1f88631d7a87203acb3dd6ae55889e6", size = 8627658 }, 1102 | { url = "https://files.pythonhosted.org/packages/1f/e4/028a27b60aa578a2fa99d9d7334ff1871bb17008693ea055a2fdee96da0d/torchvision-0.23.0-cp313-cp313-win_amd64.whl", hash = "sha256:07d069cb29691ff566e3b7f11f20d91044f079e1dbdc9d72e0655899a9b06938", size = 1600749 }, 1103 | { url = "https://files.pythonhosted.org/packages/05/35/72f91ad9ac7c19a849dedf083d347dc1123f0adeb401f53974f84f1d04c8/torchvision-0.23.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:2df618e1143805a7673aaf82cb5720dd9112d4e771983156aaf2ffff692eebf9", size = 2047192 }, 1104 | { url = "https://files.pythonhosted.org/packages/1d/9d/406cea60a9eb9882145bcd62a184ee61e823e8e1d550cdc3c3ea866a9445/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a3299d2b1d5a7aed2d3b6ffb69c672ca8830671967eb1cee1497bacd82fe47b", size = 2359295 }, 1105 | { url = "https://files.pythonhosted.org/packages/2b/f4/34662f71a70fa1e59de99772142f22257ca750de05ccb400b8d2e3809c1d/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:76bc4c0b63d5114aa81281390f8472a12a6a35ce9906e67ea6044e5af4cab60c", size = 8800474 }, 1106 | { url = "https://files.pythonhosted.org/packages/6e/f5/b5a2d841a8d228b5dbda6d524704408e19e7ca6b7bb0f24490e081da1fa1/torchvision-0.23.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b9e2dabf0da9c8aa9ea241afb63a8f3e98489e706b22ac3f30416a1be377153b", size = 1527667 }, 1107 | ] 1108 | 1109 | [[package]] 1110 | name = "tornado" 
1111 | version = "6.5.2" 1112 | source = { registry = "https://pypi.org/simple" } 1113 | sdist = { url = "https://files.pythonhosted.org/packages/09/ce/1eb500eae19f4648281bb2186927bb062d2438c2e5093d1360391afd2f90/tornado-6.5.2.tar.gz", hash = "sha256:ab53c8f9a0fa351e2c0741284e06c7a45da86afb544133201c5cc8578eb076a0", size = 510821 } 1114 | wheels = [ 1115 | { url = "https://files.pythonhosted.org/packages/f6/48/6a7529df2c9cc12efd2e8f5dd219516184d703b34c06786809670df5b3bd/tornado-6.5.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2436822940d37cde62771cff8774f4f00b3c8024fe482e16ca8387b8a2724db6", size = 442563 }, 1116 | { url = "https://files.pythonhosted.org/packages/f2/b5/9b575a0ed3e50b00c40b08cbce82eb618229091d09f6d14bce80fc01cb0b/tornado-6.5.2-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:583a52c7aa94ee046854ba81d9ebb6c81ec0fd30386d96f7640c96dad45a03ef", size = 440729 }, 1117 | { url = "https://files.pythonhosted.org/packages/1b/4e/619174f52b120efcf23633c817fd3fed867c30bff785e2cd5a53a70e483c/tornado-6.5.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0fe179f28d597deab2842b86ed4060deec7388f1fd9c1b4a41adf8af058907e", size = 444295 }, 1118 | { url = "https://files.pythonhosted.org/packages/95/fa/87b41709552bbd393c85dd18e4e3499dcd8983f66e7972926db8d96aa065/tornado-6.5.2-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b186e85d1e3536d69583d2298423744740986018e393d0321df7340e71898882", size = 443644 }, 1119 | { url = "https://files.pythonhosted.org/packages/f9/41/fb15f06e33d7430ca89420283a8762a4e6b8025b800ea51796ab5e6d9559/tornado-6.5.2-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e792706668c87709709c18b353da1f7662317b563ff69f00bab83595940c7108", size = 443878 }, 1120 | { url = 
"https://files.pythonhosted.org/packages/11/92/fe6d57da897776ad2e01e279170ea8ae726755b045fe5ac73b75357a5a3f/tornado-6.5.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:06ceb1300fd70cb20e43b1ad8aaee0266e69e7ced38fa910ad2e03285009ce7c", size = 444549 }, 1121 | { url = "https://files.pythonhosted.org/packages/9b/02/c8f4f6c9204526daf3d760f4aa555a7a33ad0e60843eac025ccfd6ff4a93/tornado-6.5.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:74db443e0f5251be86cbf37929f84d8c20c27a355dd452a5cfa2aada0d001ec4", size = 443973 }, 1122 | { url = "https://files.pythonhosted.org/packages/ae/2d/f5f5707b655ce2317190183868cd0f6822a1121b4baeae509ceb9590d0bd/tornado-6.5.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b5e735ab2889d7ed33b32a459cac490eda71a1ba6857b0118de476ab6c366c04", size = 443954 }, 1123 | { url = "https://files.pythonhosted.org/packages/e8/59/593bd0f40f7355806bf6573b47b8c22f8e1374c9b6fd03114bd6b7a3dcfd/tornado-6.5.2-cp39-abi3-win32.whl", hash = "sha256:c6f29e94d9b37a95013bb669616352ddb82e3bfe8326fccee50583caebc8a5f0", size = 445023 }, 1124 | { url = "https://files.pythonhosted.org/packages/c7/2a/f609b420c2f564a748a2d80ebfb2ee02a73ca80223af712fca591386cafb/tornado-6.5.2-cp39-abi3-win_amd64.whl", hash = "sha256:e56a5af51cc30dd2cae649429af65ca2f6571da29504a07995175df14c18f35f", size = 445427 }, 1125 | { url = "https://files.pythonhosted.org/packages/5e/4f/e1f65e8f8c76d73658b33d33b81eed4322fb5085350e4328d5c956f0c8f9/tornado-6.5.2-cp39-abi3-win_arm64.whl", hash = "sha256:d6c33dc3672e3a1f3618eb63b7ef4683a7688e7b9e6e8f0d9aa5726360a004af", size = 444456 }, 1126 | ] 1127 | 1128 | [[package]] 1129 | name = "tqdm" 1130 | version = "4.67.1" 1131 | source = { registry = "https://pypi.org/simple" } 1132 | dependencies = [ 1133 | { name = "colorama", marker = "sys_platform == 'win32'" }, 1134 | ] 1135 | sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = 
"sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } 1136 | wheels = [ 1137 | { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 }, 1138 | ] 1139 | 1140 | [[package]] 1141 | name = "traitlets" 1142 | version = "5.14.3" 1143 | source = { registry = "https://pypi.org/simple" } 1144 | sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621 } 1145 | wheels = [ 1146 | { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 }, 1147 | ] 1148 | 1149 | [[package]] 1150 | name = "triton" 1151 | version = "3.4.0" 1152 | source = { registry = "https://pypi.org/simple" } 1153 | dependencies = [ 1154 | { name = "setuptools", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, 1155 | ] 1156 | wheels = [ 1157 | { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068 }, 1158 | { url = "https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223 }, 1159 | { 
url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780 }, 1160 | ] 1161 | 1162 | [[package]] 1163 | name = "typing-extensions" 1164 | version = "4.15.0" 1165 | source = { registry = "https://pypi.org/simple" } 1166 | sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391 } 1167 | wheels = [ 1168 | { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614 }, 1169 | ] 1170 | 1171 | [[package]] 1172 | name = "urllib3" 1173 | version = "2.5.0" 1174 | source = { registry = "https://pypi.org/simple" } 1175 | sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185 } 1176 | wheels = [ 1177 | { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795 }, 1178 | ] 1179 | 1180 | [[package]] 1181 | name = "vggt" 1182 | version = "0.0.1" 1183 | source = { editable = "external/vggt" } 1184 | dependencies = [ 1185 | { name = "einops" }, 1186 | { name = "huggingface-hub" }, 1187 | { name = "numpy" }, 1188 | { name = "opencv-python" }, 1189 | { name = "pillow" }, 1190 | { name = "safetensors" }, 
1191 | ] 1192 | 1193 | [package.metadata] 1194 | requires-dist = [ 1195 | { name = "einops" }, 1196 | { name = "gradio", marker = "extra == 'demo'", specifier = "==5.17.1" }, 1197 | { name = "huggingface-hub" }, 1198 | { name = "hydra-core", marker = "extra == 'demo'" }, 1199 | { name = "matplotlib", marker = "extra == 'demo'" }, 1200 | { name = "numpy", specifier = "<2" }, 1201 | { name = "omegaconf", marker = "extra == 'demo'" }, 1202 | { name = "onnxruntime", marker = "extra == 'demo'" }, 1203 | { name = "opencv-python" }, 1204 | { name = "opencv-python", marker = "extra == 'demo'" }, 1205 | { name = "pillow" }, 1206 | { name = "requests", marker = "extra == 'demo'" }, 1207 | { name = "safetensors" }, 1208 | { name = "scipy", marker = "extra == 'demo'" }, 1209 | { name = "tqdm", marker = "extra == 'demo'" }, 1210 | { name = "trimesh", marker = "extra == 'demo'" }, 1211 | { name = "viser", marker = "extra == 'demo'", specifier = "==0.2.23" }, 1212 | ] 1213 | provides-extras = ["demo"] 1214 | 1215 | [[package]] 1216 | name = "virtualenv" 1217 | version = "20.34.0" 1218 | source = { registry = "https://pypi.org/simple" } 1219 | dependencies = [ 1220 | { name = "distlib" }, 1221 | { name = "filelock" }, 1222 | { name = "platformdirs" }, 1223 | ] 1224 | sdist = { url = "https://files.pythonhosted.org/packages/1c/14/37fcdba2808a6c615681cd216fecae00413c9dab44fb2e57805ecf3eaee3/virtualenv-20.34.0.tar.gz", hash = "sha256:44815b2c9dee7ed86e387b842a84f20b93f7f417f95886ca1996a72a4138eb1a", size = 6003808 } 1225 | wheels = [ 1226 | { url = "https://files.pythonhosted.org/packages/76/06/04c8e804f813cf972e3262f3f8584c232de64f0cde9f703b46cf53a45090/virtualenv-20.34.0-py3-none-any.whl", hash = "sha256:341f5afa7eee943e4984a9207c025feedd768baff6753cd660c857ceb3e36026", size = 5983279 }, 1227 | ] 1228 | 1229 | [[package]] 1230 | name = "wcwidth" 1231 | version = "0.2.13" 1232 | source = { registry = "https://pypi.org/simple" } 1233 | sdist = { url = 
"https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 } 1234 | wheels = [ 1235 | { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, 1236 | ] 1237 | --------------------------------------------------------------------------------