├── tests ├── .gitignore └── simple_test.py ├── src └── olah │ ├── __init__.py │ ├── auth │ └── __init__.py │ ├── cache │ ├── __init__.py │ ├── stat.py │ ├── bitset.py │ └── olah_cache.py │ ├── mirror │ ├── __init__.py │ ├── meta.py │ └── repos.py │ ├── proxy │ ├── __init__.py │ ├── lfs.py │ ├── commits.py │ ├── meta.py │ ├── tree.py │ ├── pathsinfo.py │ └── files.py │ ├── utils │ ├── __init__.py │ ├── file_utils.py │ ├── olah_utils.py │ ├── rule_utils.py │ ├── cache_utils.py │ ├── disk_utils.py │ ├── logging.py │ ├── zip_utils.py │ ├── url_utils.py │ └── repo_utils.py │ ├── database │ ├── __init__.py │ └── models.py │ ├── constants.py │ ├── static │ ├── repos.html │ └── index.html │ ├── errors.py │ └── configs.py ├── docker ├── up │ ├── repos │ │ └── .gitignore │ ├── mirrors │ │ └── .gitignore │ ├── docker-compose.yml │ └── configs.toml ├── build@source │ └── dockerfile └── build@pypi │ └── dockerfile ├── environment.yml ├── .dockerignore ├── docs ├── zh │ ├── main.md │ └── quickstart.md └── en │ ├── main.md │ └── quickstart.md ├── requirements.txt ├── .vscode ├── launch.json └── tasks.json ├── .github ├── dependabot.yml └── workflows │ ├── dev.yml │ ├── release.yml │ ├── docker-image-tag-commit.yml │ └── docker-image-tag-version.yml ├── CONTRIBUTING.md ├── LICENSE ├── assets └── full_configs.toml ├── pyproject.toml ├── .gitignore ├── README_zh.md └── README.md /tests/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olah/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olah/auth/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olah/cache/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olah/mirror/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olah/proxy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olah/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/olah/database/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docker/up/repos/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /docker/up/mirrors/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: olah 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.12 6 | 7 | - 
pip: 8 | - -e . 9 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git/ 2 | .github/ 3 | .vscode/ 4 | assets/ 5 | benchmark/ 6 | cache/ 7 | docker/ 8 | docs/ 9 | dist/ 10 | logs/ 11 | mirrors_dir/ 12 | olah.egg-info/ 13 | playground/ 14 | scripts/ 15 | tests/ 16 | repos/ 17 | .dockerignore 18 | .gitignore 19 | environment.yml -------------------------------------------------------------------------------- /docs/zh/main.md: -------------------------------------------------------------------------------- 1 |

Olah 文档


5 | 自托管的轻量级HuggingFace镜像服务 6 | 7 | Olah是开源的自托管轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。 8 | Olah真正地实现了huggingface资源的`镜像`功能,而不仅仅是一个简单的`反向代理`。 9 | Olah并不会立刻对huggingface全站进行镜像,而是在用户下载的同时在文件块级别对资源进行镜像(或者我们可以说是缓存)。 10 | -------------------------------------------------------------------------------- /docker/build@source/dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.12 2 | 3 | WORKDIR /app 4 | 5 | RUN pip3 install --upgrade pip 6 | 7 | COPY . /app 8 | RUN pip3 install --no-cache-dir -e . 9 | 10 | EXPOSE 8090 11 | 12 | VOLUME /data/repos 13 | VOLUME /data/mirrors 14 | 15 | ENTRYPOINT [ "olah-cli" ] 16 | 17 | CMD ["--repos-path", "/repos"] 18 | -------------------------------------------------------------------------------- /docker/up/docker-compose.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | services: 4 | olah: 5 | image: xiahan2019/olah:lastet 6 | container_name: olah 7 | command: -c /app/configs.toml 8 | ports: 9 | - 8090:8090 10 | volumes: 11 | - ./configs.toml:/app/configs.toml 12 | - ./mirrors:/data/mirrors 13 | - ./repos:/data/repos 14 | -------------------------------------------------------------------------------- /docker/build@pypi/dockerfile: -------------------------------------------------------------------------------- 1 | ARG OLAH_SOURCE=olah==0.3.3 2 | 3 | FROM python:3.12 4 | 5 | ARG OLAH_SOURCE 6 | 7 | WORKDIR /app 8 | 9 | RUN pip3 install --upgrade pip 10 | 11 | RUN pip install --no-cache-dir ${OLAH_SOURCE} 12 | 13 | EXPOSE 8090 14 | 15 | VOLUME /data/repos 16 | VOLUME /data/mirrors 17 | 18 | ENTRYPOINT [ "olah-cli" ] 19 | 20 | CMD ["--repos-path", "/repos"] 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.115.2 2 | fastapi-utils==0.7.0 3 | GitPython==3.1.43 4 | httpx==0.27.0 5 | pydantic==2.8.2 6 | pydantic-settings==2.4.0 7 | toml==0.10.2 8 | huggingface_hub==0.26.0 9 | pytest==8.3.3 10 | cachetools==5.4.0 11 | PyYAML==6.0.1 12 | tenacity==8.5.0 13 | peewee==3.17.6 14 | typing_inspect==0.9.0 15 | jinja2==3.1.4 16 | python-multipart==0.0.9 17 | portalocker==3.1.1 18 | aiofiles==24.1.0 -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "debugpy: server", 6 | "type": "debugpy", 7 | "request": "launch", 8 | "program": "${workspaceFolder}/olah/server.py", 9 | "console": "integratedTerminal", 10 | "args": ["-c", "${workspaceFolder}/assets/full_configs.toml"], 11 | "justMyCode": false 12 | } 13 | ] 14 | } -------------------------------------------------------------------------------- /src/olah/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
7 | 8 | import os 9 | 10 | 11 | def make_dirs(path: str): 12 | if os.path.isdir(path): 13 | save_dir = path 14 | else: 15 | save_dir = os.path.dirname(path) 16 | if not os.path.exists(save_dir): 17 | os.makedirs(save_dir, exist_ok=True) 18 | -------------------------------------------------------------------------------- /src/olah/utils/olah_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import platform 9 | import os 10 | 11 | 12 | def get_olah_path() -> str: 13 | if platform.system() == "Windows": 14 | olah_path = os.path.expanduser("~\\.olah") 15 | else: 16 | olah_path = os.path.expanduser("~/.olah") 17 | return olah_path 18 | -------------------------------------------------------------------------------- /docs/en/main.md: -------------------------------------------------------------------------------- 1 |

Olah Document


4 | Self-hosted Lightweight Huggingface Mirror Service 5 | 6 | Olah is a self-hosted lightweight huggingface mirror service. `Olah` means `hello` in Hilichurlian. 7 | Olah implemented the `mirroring` feature for huggingface resources, rather than just a simple `reverse proxy`. 8 | Olah does not immediately mirror the entire huggingface website but mirrors the resources at the file block level when users download them (or we can say cache them). 9 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /src/olah/constants.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import os 9 | 10 | 11 | WORKER_API_TIMEOUT = 15 12 | CHUNK_SIZE = 4096 13 | LFS_FILE_BLOCK = 64 * 1024 * 1024 14 | 15 | DEFAULT_LOGGER_DIR = "./logs" 16 | OLAH_CODE_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | 18 | ORIGINAL_LOC = "oriloc" 19 | 20 | from huggingface_hub.constants import ( 21 | REPO_TYPES_MAPPING, 22 | HUGGINGFACE_CO_URL_TEMPLATE, 23 | HUGGINGFACE_HEADER_X_REPO_COMMIT, 24 | HUGGINGFACE_HEADER_X_LINKED_ETAG, 25 | HUGGINGFACE_HEADER_X_LINKED_SIZE, 26 | ) 27 | -------------------------------------------------------------------------------- /tests/simple_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | 5 | from huggingface_hub import snapshot_download 6 | 7 | def test_dataset(): 8 | process = subprocess.Popen(['python', '-m', 'olah.server']) 9 | 10 | os.environ['HF_ENDPOINT'] = 'http://localhost:8090' 11 | snapshot_download(repo_id='Nerfgun3/bad_prompt', repo_type='dataset', 12 | local_dir='./dataset_dir', max_workers=8) 13 | 14 | process.terminate() 15 | 16 | def test_model(): 17 | process = subprocess.Popen(['python', '-m', 'olah.server']) 18 | 19 | os.environ['HF_ENDPOINT'] = 'http://localhost:8090' 20 | snapshot_download(repo_id='prajjwal1/bert-tiny', repo_type='model', 21 | local_dir='./model_dir', max_workers=8) 22 | 23 | process.terminate() -------------------------------------------------------------------------------- /docs/zh/quickstart.md: -------------------------------------------------------------------------------- 1 | ## 快速开始 2 | 在控制台运行以下命令: 3 | ```bash 4 | python -m olah.server 5 | ``` 6 | 7 | 然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090/)。 8 | 9 | Linux: 10 | ```bash 11 | export HF_ENDPOINT=http://localhost:8090 12 | ``` 13 | 14 | Windows Powershell: 15 | ```bash 16 | $env:HF_ENDPOINT = "http://localhost:8090" 17 | ``` 18 | 19 | 从现在开始,HuggingFace库中的所有下载操作都将通过此镜像站点代理进行。 20 | ```bash 21 | pip install 
-U huggingface_hub 22 | ``` 23 | 24 | ```python 25 | from huggingface_hub import snapshot_download 26 | 27 | snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model', 28 | local_dir='./model_dir', resume_download=True, 29 | max_workers=8) 30 | 31 | ``` 32 | 33 | 或者你也可以使用huggingface cli直接下载模型和数据集. 34 | 35 | 下载GPT2: 36 | ```bash 37 | huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2 38 | ``` 39 | 40 | 下载WikiText: 41 | ```bash 42 | huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext 43 | ``` 44 | 45 | 您可以查看路径`./repos`,其中存储了所有数据集和模型的缓存。 46 | -------------------------------------------------------------------------------- /src/olah/utils/rule_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | 9 | from typing import Dict, Literal, Optional, Tuple, Union 10 | 11 | from fastapi import FastAPI 12 | from olah.configs import OlahConfig 13 | from .repo_utils import get_org_repo 14 | 15 | 16 | async def check_proxy_rules_hf( 17 | app: FastAPI, 18 | repo_type: Optional[Literal["models", "datasets", "spaces"]], 19 | org: Optional[str], 20 | repo: str, 21 | ) -> bool: 22 | config: OlahConfig = app.state.app_settings.config 23 | org_repo = get_org_repo(org, repo) 24 | return config.proxy.allow(org_repo) 25 | 26 | 27 | async def check_cache_rules_hf( 28 | app: FastAPI, 29 | repo_type: Optional[Literal["models", "datasets", "spaces"]], 30 | org: Optional[str], 31 | repo: str, 32 | ) -> bool: 33 | config: OlahConfig = app.state.app_settings.config 34 | org_repo = get_org_repo(org, repo) 35 | return config.cache.allow(org_repo) 36 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How can I contribute to Olah? 2 | 3 | Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable. 4 | 5 | It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you. 6 | 7 | However you choose to contribute, please be mindful and respect our code of conduct. 
8 | 9 | ## Ways to contribute 10 | 11 | There are lots of ways you can contribute to Olah: 12 | * Submitting issues on Github to report bugs or make feature requests 13 | * Fixing outstanding issues with the existing code 14 | * Implementing new features 15 | * Contributing to the examples or to the documentation 16 | 17 | *All are equally valuable to the community.* 18 | 19 | #### This guide was heavily inspired by the awesome [transformers guide to contributing](https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md) 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Vtuber Plan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/olah/database/models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
7 | 8 | import os 9 | from peewee import * 10 | import datetime 11 | 12 | from olah.utils.olah_utils import get_olah_path 13 | 14 | db_path = os.path.join(get_olah_path(), "database.db") 15 | db = SqliteDatabase(db_path) 16 | 17 | class BaseModel(Model): 18 | class Meta: 19 | database = db 20 | 21 | class Token(BaseModel): 22 | token = CharField(unique=True) 23 | first_dt = DateTimeField() 24 | last_dt = DateTimeField() 25 | 26 | class DownloadLogs(BaseModel): 27 | id = CharField(unique=True) 28 | org = CharField() 29 | repo = CharField() 30 | path = CharField() 31 | range_start = BigIntegerField() 32 | range_end = BigIntegerField() 33 | datetime = DateTimeField() 34 | token = CharField() 35 | 36 | class FileLevelLRU(BaseModel): 37 | org = CharField() 38 | repo = CharField() 39 | path = CharField() 40 | datetime = DateTimeField(default=datetime.datetime.now) 41 | 42 | db.connect() 43 | db.create_tables([ 44 | Token, 45 | DownloadLogs, 46 | FileLevelLRU, 47 | ]) 48 | -------------------------------------------------------------------------------- /assets/full_configs.toml: -------------------------------------------------------------------------------- 1 | [basic] 2 | host = "localhost" 3 | port = 8090 4 | ssl-key = "" 5 | ssl-cert = "" 6 | repos-path = "./repos" 7 | cache-size-limit = "" 8 | cache-clean-strategy = "LRU" 9 | hf-scheme = "https" 10 | hf-netloc = "huggingface.co" 11 | hf-lfs-netloc = "cdn-lfs.huggingface.co" 12 | mirror-scheme = "http" 13 | mirror-netloc = "localhost:8090" 14 | mirror-lfs-netloc = "localhost:8090" 15 | mirrors-path = ["./mirrors_dir"] 16 | 17 | [accessibility] 18 | offline = false 19 | 20 | # allow other or will be in whitelist mode. 21 | [[accessibility.proxy]] 22 | repo = "*" 23 | allow = true 24 | 25 | [[accessibility.proxy]] 26 | repo = "*/*" 27 | allow = true 28 | 29 | [[accessibility.proxy]] 30 | repo = "cais/mmlu" 31 | allow = true 32 | 33 | [[accessibility.proxy]] 34 | repo = "adept/fuyu-8b" 35 | allow = false 36 | 37 | [[accessibility.proxy]] 38 | repo = "mistralai/*" 39 | allow = true 40 | 41 | [[accessibility.proxy]] 42 | repo = "mistralai/Mistral.*" 43 | allow = false 44 | use_re = true 45 | 46 | # allow other or will be in whitelist mode. 47 | [[accessibility.cache]] 48 | repo = "*" 49 | allow = true 50 | 51 | [[accessibility.cache]] 52 | repo = "*/*" 53 | allow = true 54 | 55 | [[accessibility.cache]] 56 | repo = "cais/mmlu" 57 | allow = true 58 | 59 | [[accessibility.cache]] 60 | repo = "adept/fuyu-8b" 61 | allow = false 62 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "olah" 7 | version = "0.4.1" 8 | description = "Self-hosted lightweight huggingface mirror." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: MIT License", 14 | ] 15 | dependencies = [ 16 | "fastapi", "fastapi-utils", "httpx", "numpy", "pydantic<=2.8.2", "pydantic-settings<=2.4.0", "requests", "toml", 17 | "rich>=10.0.0", "shortuuid", "uvicorn", "tenacity>=8.2.2", "pytz", "cachetools", "GitPython", 18 | "PyYAML", "typing_inspect>=0.9.0", "huggingface_hub", "jinja2", "python-multipart", "portalocker", 19 | "aiofiles", "brotli" 20 | ] 21 | 22 | [project.optional-dependencies] 23 | dev = ["black==24.10.0", "pylint==3.3.1", "pytest==8.3.3"] 24 | 25 | [project.urls] 26 | "Homepage" = "https://github.com/vtuber-plan/olah" 27 | "Bug Tracker" = "https://github.com/vtuber-plan/olah/issues" 28 | 29 | [project.scripts] 30 | olah-cli = "olah.server:cli" 31 | 32 | [tool.setuptools.packages.find] 33 | where = ["src"] 34 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 35 | 36 | [tool.wheel] 37 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 38 | -------------------------------------------------------------------------------- /.github/workflows/dev.yml: -------------------------------------------------------------------------------- 1 | name: Olah GitHub Actions for Development 2 | run-name: Olah GitHub Actions for Development 3 | on: 4 | push: 5 | branches: [ "dev" ] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: ["3.9", "3.10", "3.11", "3.12"] 13 | 14 | steps: 15 | - name: Check out repository code 16 | uses: actions/checkout@v4 17 | - name: Set up Apache Arrow 18 | run: | 19 | sudo apt update 20 | sudo apt install -y -V ca-certificates lsb-release wget 21 | wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb 22 | sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb 23 | sudo apt update 24 | sudo apt install -y -V libarrow-dev libarrow-glib-dev libparquet-dev libparquet-glib-dev 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - name: Install Olah 30 | run: | 31 | cd ${{ github.workspace }} 32 | pip install --upgrade pip 33 | pip install -e . 34 | pip install -r requirements.txt 35 | 36 | - name: Test Olah 37 | run: | 38 | cd ${{ github.workspace }} 39 | python -m pytest tests 40 | -------------------------------------------------------------------------------- /src/olah/mirror/meta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
7 | 8 | 9 | from typing import Any, Dict 10 | 11 | 12 | class RepoMeta(object): 13 | def __init__(self) -> None: 14 | self._id = None 15 | self.id = None 16 | self.author = None 17 | self.sha = None 18 | self.lastModified = None 19 | self.private = False 20 | self.gated = False 21 | self.disabled = False 22 | self.tags = [] 23 | self.description = "" 24 | self.paperswithcode_id = None 25 | self.downloads = 0 26 | self.likes = 0 27 | self.cardData = None 28 | self.siblings = None 29 | self.createdAt = None 30 | 31 | def to_dict(self) -> Dict[str, Any]: 32 | return { 33 | "_id": self._id, 34 | "id": self.id, 35 | "author": self.author, 36 | "sha": self.sha, 37 | "lastModified": self.lastModified, 38 | "private": self.private, 39 | "gated": self.gated, 40 | "disabled": self.disabled, 41 | "tags": self.tags, 42 | "description": self.description, 43 | "paperswithcode_id": self.paperswithcode_id, 44 | "downloads": self.downloads, 45 | "likes": self.likes, 46 | "cardData": self.cardData, 47 | "siblings": self.siblings, 48 | "createdAt": self.createdAt, 49 | } 50 | -------------------------------------------------------------------------------- /docker/up/configs.toml: -------------------------------------------------------------------------------- 1 | [basic] 2 | host = "0.0.0.0" 3 | port = 8090 4 | ssl-key = "" 5 | ssl-cert = "" 6 | repos-path = "/data/repos" 7 | cache-size-limit = "" 8 | cache-clean-strategy = "LRU" 9 | hf-scheme = "https" 10 | hf-netloc = "huggingface.co" 11 | hf-lfs-netloc = "cdn-lfs.huggingface.co" 12 | mirror-scheme = "http" 13 | mirror-netloc = "localhost:8090" 14 | mirror-lfs-netloc = "localhost:8090" 15 | mirrors-path = ["/data/mirrors"] 16 | 17 | 18 | [accessibility] 19 | offline = false 20 | 21 | 22 | [[accessibility.proxy]] 23 | repo = "*" 24 | allow = true 25 | 26 | [[accessibility.proxy]] 27 | repo = "*/*" 28 | allow = true 29 | 30 | [[accessibility.proxy]] 31 | repo = "vikp/surya_det3" 32 | allow = false 33 | 34 | # [[accessibility.proxy]] 35 | # repo = "vikp/surya_layout3" 36 | # allow = true 37 | 38 | # [[accessibility.proxy]] 39 | # repo = "vikp/surya_order" 40 | # allow = true 41 | 42 | # [[accessibility.proxy]] 43 | # repo = "vikp/surya_rec2" 44 | # allow = true 45 | 46 | # [[accessibility.proxy]] 47 | # repo = "vikp/surya_tablerec" 48 | # allow = true 49 | 50 | 51 | [[accessibility.cache]] 52 | repo = "*" 53 | allow = true 54 | 55 | [[accessibility.cache]] 56 | repo = "*/*" 57 | allow = true 58 | 59 | [[accessibility.cache]] 60 | repo = "vikp/surya_det3" 61 | allow = false 62 | 63 | # [[accessibility.cache]] 64 | # repo = "vikp/surya_layout3" 65 | # allow = true 66 | 67 | # [[accessibility.cache]] 68 | # repo = "vikp/surya_order" 69 | # allow = true 70 | 71 | # [[accessibility.cache]] 72 | # repo = "vikp/surya_rec2" 73 | # allow = true 74 | 75 | # [[accessibility.cache]] 76 | # repo = "vikp/surya_tablerec" 77 | # allow = true 78 | -------------------------------------------------------------------------------- /src/olah/utils/cache_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
7 | 8 | 9 | import json 10 | from typing import Dict, Mapping, Union 11 | 12 | 13 | async def write_cache_request( 14 | save_path: str, 15 | status_code: int, 16 | headers: Union[Dict[str, str], Mapping], 17 | content: bytes, 18 | ) -> None: 19 | """ 20 | Write the request's status code, headers, and content to a cache file. 21 | 22 | Args: 23 | head_path (str): The path to the cache file. 24 | status_code (int): The status code of the request. 25 | headers (Dict[str, str]): The dictionary of response headers. 26 | content (bytes): The content of the request. 27 | 28 | Returns: 29 | None 30 | """ 31 | if not isinstance(headers, dict): 32 | headers = {k.lower(): v for k, v in headers.items()} 33 | rq = { 34 | "status_code": status_code, 35 | "headers": headers, 36 | "content": content.hex(), 37 | } 38 | with open(save_path, "w", encoding="utf-8") as f: 39 | f.write(json.dumps(rq, ensure_ascii=False)) 40 | 41 | 42 | async def read_cache_request(save_path: str) -> Dict[str, str]: 43 | """ 44 | Read the request's status code, headers, and content from a cache file. 45 | 46 | Args: 47 | save_path (str): The path to the cache file. 48 | 49 | Returns: 50 | Dict[str, str]: A dictionary containing the status code, headers, and content of the request. 51 | """ 52 | with open(save_path, "r", encoding="utf-8") as f: 53 | rq = json.loads(f.read()) 54 | 55 | rq["content"] = bytes.fromhex(rq["content"]) 56 | return rq 57 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Olah GitHub Actions to release 2 | run-name: Olah GitHub Actions to release 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.12"] 14 | 15 | steps: 16 | - name: Check out repository code 17 | uses: actions/checkout@v4 18 | - name: Set up Apache Arrow 19 | run: | 20 | sudo apt update 21 | sudo apt install -y -V ca-certificates lsb-release wget 22 | wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb 23 | sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb 24 | sudo apt update 25 | sudo apt install -y -V libarrow-dev libarrow-glib-dev libparquet-dev libparquet-glib-dev 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install Olah 31 | run: | 32 | cd ${{ github.workspace }} 33 | pip install --upgrade pip 34 | pip install -e . 
35 | pip install -r requirements.txt 36 | 37 | - name: Test Olah 38 | run: | 39 | cd ${{ github.workspace }} 40 | python -m pytest tests 41 | 42 | - name: Build Olah 43 | run: | 44 | cd ${{ github.workspace }} 45 | pip install build 46 | python -m build 47 | 48 | - name: Release 49 | uses: "marvinpinto/action-automatic-releases@latest" 50 | with: 51 | repo_token: "${{ secrets.GITHUB_TOKEN }}" 52 | prerelease: true 53 | files: | 54 | dist/*.tar.gz 55 | dist/*.whl -------------------------------------------------------------------------------- /docs/en/quickstart.md: -------------------------------------------------------------------------------- 1 | 2 | ## Quick Start 3 | Run the command in the console: 4 | ```bash 5 | python -m olah.server 6 | ``` 7 | 8 | Then set the Environment Variable `HF_ENDPOINT` to the mirror site (Here is http://localhost:8090). 9 | 10 | Linux: 11 | ```bash 12 | export HF_ENDPOINT=http://localhost:8090 13 | ``` 14 | 15 | Windows Powershell: 16 | ```bash 17 | $env:HF_ENDPOINT = "http://localhost:8090" 18 | ``` 19 | 20 | Starting from now on, all download operations in the HuggingFace library will be proxied through this mirror site. 21 | ```bash 22 | pip install -U huggingface_hub 23 | ``` 24 | 25 | ```python 26 | from huggingface_hub import snapshot_download 27 | 28 | snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model', 29 | local_dir='./model_dir', resume_download=True, 30 | max_workers=8) 31 | ``` 32 | 33 | Or you can download models and datasets by using huggingface cli. 34 | 35 | Download GPT2: 36 | ```bash 37 | huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2 38 | ``` 39 | 40 | Download WikiText: 41 | ```bash 42 | huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext 43 | ``` 44 | 45 | You can check the path `./repos`, in which olah stores all cached datasets and models. 46 | 47 | ## Start the server 48 | Run the command in the console: 49 | ```bash 50 | python -m olah.server 51 | ``` 52 | 53 | Or you can specify the host address and listening port: 54 | ```bash 55 | python -m olah.server --host localhost --port 8090 56 | ``` 57 | **Note: Please change --mirror-netloc and --mirror-lfs-netloc to the actual URLs of the mirror sites when modifying the host and port.** 58 | ```bash 59 | python -m olah.server --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090 60 | ``` 61 | 62 | The default mirror cache path is `./repos`, you can change it by `--repos-path` parameter: 63 | ```bash 64 | python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors 65 | ``` 66 | 67 | **Note that the cached data between different versions cannot be migrated. Please delete the cache folder before upgrading to the latest version of Olah.** 68 | -------------------------------------------------------------------------------- /src/olah/proxy/lfs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
7 | 8 | import os 9 | from typing import Literal 10 | from fastapi import FastAPI, Header, Request 11 | 12 | from olah.proxy.files import _file_realtime_stream 13 | from olah.utils.file_utils import make_dirs 14 | 15 | 16 | async def lfs_head_generator( 17 | app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request 18 | ): 19 | # save 20 | repos_path = app.state.app_settings.config.repos_path 21 | head_path = os.path.join( 22 | repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}" 23 | ) 24 | save_path = os.path.join( 25 | repos_path, f"lfs/files/{dir1}/{dir2}/{hash_repo}/{hash_file}" 26 | ) 27 | make_dirs(head_path) 28 | make_dirs(save_path) 29 | 30 | # use_cache = os.path.exists(head_path) and os.path.exists(save_path) 31 | allow_cache = True 32 | 33 | # proxy 34 | return _file_realtime_stream( 35 | app=app, 36 | save_path=save_path, 37 | head_path=head_path, 38 | url=str(request.url), 39 | request=request, 40 | method="HEAD", 41 | allow_cache=allow_cache, 42 | commit=None, 43 | ) 44 | 45 | 46 | async def lfs_get_generator( 47 | app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request 48 | ): 49 | # save 50 | repos_path = app.state.app_settings.config.repos_path 51 | head_path = os.path.join( 52 | repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}" 53 | ) 54 | save_path = os.path.join( 55 | repos_path, f"lfs/files/{dir1}/{dir2}/{hash_repo}/{hash_file}" 56 | ) 57 | make_dirs(head_path) 58 | make_dirs(save_path) 59 | 60 | # use_cache = os.path.exists(head_path) and os.path.exists(save_path) 61 | allow_cache = True 62 | 63 | # proxy 64 | return _file_realtime_stream( 65 | app=app, 66 | save_path=save_path, 67 | head_path=head_path, 68 | url=str(request.url), 69 | request=request, 70 | method="GET", 71 | allow_cache=allow_cache, 72 | commit=None, 73 | ) 74 | -------------------------------------------------------------------------------- /src/olah/cache/stat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
7 | 8 | import argparse 9 | import os 10 | import sys 11 | from olah.cache.olah_cache import OlahCache 12 | 13 | def get_size_human(size: int) -> str: 14 | if size > 1024 * 1024 * 1024: 15 | return f"{int(size / (1024 * 1024 * 1024)):.4f}GB" 16 | elif size > 1024 * 1024: 17 | return f"{int(size / (1024 * 1024)):.4f}MB" 18 | elif size > 1024: 19 | return f"{int(size / (1024)):.4f}KB" 20 | else: 21 | return f"{size:.4f}B" 22 | 23 | def insert_newlines(input_str, every=10): 24 | return '\n'.join(input_str[i:i+every] for i in range(0, len(input_str), every)) 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser(description="Olah Cache Visualization Tool.") 28 | parser.add_argument("--file", "-f", type=str, required=True, help="The path of Olah cache file") 29 | parser.add_argument("--export", "-e", type=str, default="", help="Export the cached file if all blocks are cached") 30 | args = parser.parse_args() 31 | print(args) 32 | 33 | with open(args.file, "rb") as f: 34 | f.seek(0, os.SEEK_END) 35 | bin_size = f.tell() 36 | 37 | try: 38 | cache = OlahCache(args.file) 39 | except Exception as e: 40 | print(e) 41 | sys.exit(1) 42 | print(f"File: {args.file}") 43 | print(f"Olah Cache Version: {cache.header.version}") 44 | print(f"File Size: {get_size_human(cache.header.file_size)}") 45 | print(f"Cache Total Size: {get_size_human(bin_size)}") 46 | print(f"Block Size: {cache.header.block_size}") 47 | print(f"Block Number: {cache.header.block_number}") 48 | print(f"Cache Status: ") 49 | cache_status = cache.header.block_mask.__str__()[:cache.header._block_number] 50 | print(insert_newlines(cache_status, every=50)) 51 | 52 | if args.export != "": 53 | if all([c == "1" for c in cache_status]): 54 | with open(args.file, "rb") as f: 55 | f.seek(cache._get_header_size(), os.SEEK_SET) 56 | with open(args.export, "wb") as fout: 57 | fout.write(f.read()) 58 | else: 59 | print("Some blocks are not cached, so the export is skipped.") -------------------------------------------------------------------------------- /src/olah/static/repos.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 24 | 25 | 26 | 27 |

[repos.html: the HTML markup was lost in extraction. Recoverable content — a Jinja2 template for a page titled "Huggingface Mirror Repositories" with three sections:
 - "Data Sets": {% for dataset_repo in datasets_repos %} ... {{ dataset_repo }} ... {% endfor %}
 - "Models": {% for model_repo in models_repos %} ... {{ model_repo }} ... {% endfor %}
 - "Spaces": {% for space_repo in spaces_repos %} ... {{ space_repo }} ... {% endfor %}]
69 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /src/olah/errors.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | from fastapi import Response 9 | from fastapi.responses import JSONResponse 10 | 11 | 12 | def error_repo_not_found() -> JSONResponse: 13 | return JSONResponse( 14 | content={"error": "Repository not found"}, 15 | headers={ 16 | "x-error-code": "RepoNotFound", 17 | "x-error-message": "Repository not found", 18 | }, 19 | status_code=401, 20 | ) 21 | 22 | 23 | def error_page_not_found() -> JSONResponse: 24 | return JSONResponse( 25 | content={"error": "Sorry, we can't find the page you are looking for."}, 26 | headers={ 27 | "x-error-code": "RepoNotFound", 28 | "x-error-message": "Sorry, we can't find the page you are looking for.", 29 | }, 30 | status_code=404, 31 | ) 32 | 33 | 34 | def error_entry_not_found_branch(branch: str, path: str) -> Response: 35 | return Response( 36 | headers={ 37 | "x-error-code": "EntryNotFound", 38 | "x-error-message": f'{path} does not exist on "{branch}"', 39 | }, 40 | status_code=404, 41 | ) 42 | 43 | 44 | def error_entry_not_found() -> Response: 45 | return Response( 46 | headers={ 47 | "x-error-code": "EntryNotFound", 48 | "x-error-message": "Entry not found", 49 | }, 50 | status_code=404, 51 | ) 52 | 53 | 54 | def error_revision_not_found(revision: str) -> Response: 55 | return JSONResponse( 56 | content={"error": f"Invalid rev id: {revision}"}, 57 | headers={ 58 | "x-error-code": "RevisionNotFound", 59 | "x-error-message": f"Invalid rev id: {revision}", 60 | }, 61 | status_code=404, 62 | ) 63 | 64 | 65 | # Olah Custom Messages 66 | def error_proxy_timeout() -> Response: 67 | return Response( 68 | headers={ 69 | "x-error-code": "ProxyTimeout", 70 | "x-error-message": "Proxy Timeout", 71 | }, 72 | status_code=504, 73 | ) 74 | 75 | 76 | def error_proxy_invalid_data() -> Response: 77 | return Response( 78 | headers={ 79 | "x-error-code": "ProxyInvalidData", 80 | "x-error-message": "Proxy Invalid Data", 81 | }, 82 | status_code=504, 83 | ) 84 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=733558 3 | // for the documentation about the tasks.json format 4 | "version": "2.0.0", 5 | "tasks": [ 6 | { 7 | "label": "docker: compose up", 8 | "type": "shell", 9 | "options": { 10 | "cwd": "${workspaceFolder}/docker/up/" 11 | }, 12 | "command": "docker compose -p olah up", 13 | }, 14 | { 15 | "label": "docker: build main", 16 | "type": "shell", 17 | "command": "docker build -t xiahan2019/olah:main -f ./docker/build@source/dockerfile .", 18 | }, 19 | { 20 | "label": "docker: build 0.3.3", 21 | "type": "shell", 22 | "command": "docker build -t xiahan2019/olah:0.3.3 -f ./docker/build@pypi/dockerfile .", 23 | }, 24 | { 25 | "label": "huggingface-cli: download bert-base-uncased", 26 | "type": "shell", 27 | "options": { 28 | "cwd": "${workspaceFolder}", 29 | "env": { 30 | "HF_ENDPOINT": "http://localhost:18090", 31 | "HF_HUB_ETAG_TIMEOUT": "100", 32 | "HF_HUB_DOWNLOAD_TIMEOUT": "100" 33 | } 34 | }, 35 | "command": "huggingface-cli download bert-base-uncased 
--revision main --cache-dir ./cache/huggingface/hub" 36 | }, 37 | { 38 | "label": "huggingface-cli: download cais/mmlu", 39 | "type": "shell", 40 | "options": { 41 | "cwd": "${workspaceFolder}", 42 | "env": { 43 | "HF_ENDPOINT": "http://localhost:8090", 44 | "HF_HUB_ETAG_TIMEOUT": "100", 45 | "HF_HUB_DOWNLOAD_TIMEOUT": "100" 46 | } 47 | }, 48 | "command": "huggingface-cli download cais/mmlu --repo-type dataset --revision main --cache-dir ./cache/huggingface/hub" 49 | }, 50 | { 51 | "label": "conda: run olah-cli", 52 | "type": "shell", 53 | "command": [ 54 | "conda run --no-capture-output -n olah olah-cli -c ./assets/full_configs.toml" 55 | ], 56 | "problemMatcher": [] 57 | }, 58 | { 59 | "label": "conda: create env", 60 | "type": "shell", 61 | "command": [ 62 | "conda env create -f ./environment.yml" 63 | ] 64 | } 65 | ] 66 | } -------------------------------------------------------------------------------- /src/olah/cache/bitset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | 9 | class Bitset: 10 | def __init__(self, size) -> None: 11 | """ 12 | Initializes a Bitset object with a given size. 13 | 14 | Args: 15 | size (int): The number of bits in the Bitset. 16 | """ 17 | self.size = size 18 | self.bits = bytearray((0,) * ((size + 7) // 8)) 19 | 20 | def set(self, index: int) -> None: 21 | """ 22 | Sets the bit at the specified index to 1. 23 | 24 | Args: 25 | index (int): The index of the bit to be set. 26 | 27 | Raises: 28 | IndexError: If the index is out of range. 29 | """ 30 | if index < 0 or index >= self.size: 31 | raise IndexError("Index out of range") 32 | byte_index = index // 8 33 | bit_index = index % 8 34 | self.bits[byte_index] |= 1 << bit_index 35 | 36 | def clear(self, index: int) -> None: 37 | """ 38 | Sets the bit at the specified index to 0. 39 | 40 | Args: 41 | index (int): The index of the bit to be cleared. 42 | 43 | Raises: 44 | IndexError: If the index is out of range. 45 | """ 46 | if index < 0 or index >= self.size: 47 | raise IndexError("Index out of range") 48 | self._resize_if_needed(index) 49 | byte_index = index // 8 50 | bit_index = index % 8 51 | self.bits[byte_index] &= ~(1 << bit_index) 52 | 53 | def test(self, index: int) -> None: 54 | """ 55 | Checks the value of the bit at the specified index. 56 | 57 | Args: 58 | index (int): The index of the bit to be checked. 59 | 60 | Returns: 61 | bool: True if the bit is set (1), False if the bit is cleared (0). 62 | 63 | Raises: 64 | IndexError: If the index is out of range. 65 | """ 66 | if index < 0 or index >= self.size: 67 | raise IndexError("Index out of range") 68 | byte_index = index // 8 69 | bit_index = index % 8 70 | return bool(self.bits[byte_index] & (1 << bit_index)) 71 | 72 | def __str__(self): 73 | """ 74 | Returns a string representation of the Bitset. 75 | 76 | Returns: 77 | str: A string representation of the Bitset object, showing the binary representation of each byte. 
78 | """ 79 | return "".join(bin(byte)[2:].zfill(8)[::-1] for byte in self.bits) 80 | -------------------------------------------------------------------------------- /.github/workflows/docker-image-tag-commit.yml: -------------------------------------------------------------------------------- 1 | 2 | name: Docker Image Build/Publish tag with commit 3 | 4 | on: 5 | push: 6 | branches: 7 | - 'dev' 8 | paths: 9 | - docker/build@source/dockerfile 10 | - .github/workflows/docker-image-tag-commit.yml 11 | workflow_dispatch: 12 | inputs: 13 | commit_id: 14 | description: olah commit id(like 'main' 'db028a3b') 15 | required: true 16 | default: main 17 | 18 | jobs: 19 | build-and-push-docker-image: 20 | name: Build Docker image and push to repositories 21 | runs-on: ubuntu-latest 22 | 23 | strategy: 24 | matrix: 25 | BRANCH_CHECKOUT: 26 | - ${{ github.event.inputs.commit_id || 'main' }} 27 | platforms: 28 | - linux/amd64,linux/arm64 29 | 30 | steps: 31 | - name: Checkout code 32 | uses: actions/checkout@v4 33 | 34 | - name: Set up QEMU 35 | uses: docker/setup-qemu-action@v3 36 | 37 | - name: Set up Docker Buildx 38 | uses: docker/setup-buildx-action@v3 39 | 40 | - name: Login to Docker Hub 41 | uses: docker/login-action@v3 42 | with: 43 | username: ${{ secrets.DOCKERHUB_USERNAME }} 44 | password: ${{ secrets.DOCKERHUB_TOKEN }} 45 | 46 | - name: Login to GitHub Container Registry 47 | uses: docker/login-action@v3 48 | with: 49 | registry: ghcr.io 50 | username: ${{ github.repository_owner }} 51 | password: ${{ secrets.GITHUB_TOKEN }} 52 | 53 | - name: Checkout commit 54 | run: | 55 | git checkout ${{ matrix.BRANCH_CHECKOUT }} 56 | 57 | - name: Set env git short head 58 | working-directory: ./olah 59 | run: echo "COMMIT_SHORT=$(git rev-parse --short HEAD)" >> $GITHUB_ENV 60 | 61 | - name: Meta data image 62 | id: meta 63 | uses: docker/metadata-action@v5 64 | with: 65 | images: | 66 | ${{ secrets.DOCKERHUB_USERNAME }}/olah 67 | ghcr.io/${{ github.repository_owner }}/olah 68 | tags: | 69 | type=raw,value=${{ matrix.BRANCH_CHECKOUT }} 70 | type=raw,value=${{ env.COMMIT_SHORT }} 71 | flavor: | 72 | latest=false 73 | 74 | - name: Build push image 75 | id: build 76 | uses: docker/build-push-action@v5 77 | with: 78 | context: . 
79 | file: ./docker/build@source/dockerfile 80 | platforms: ${{ matrix.platforms }} 81 | push: true 82 | tags: ${{ steps.meta.outputs.tags }} 83 | labels: ${{ steps.meta.outputs.labels }} 84 | 85 | - name: Print image digest 86 | run: echo ${{ steps.build.outputs.digest }} 87 | -------------------------------------------------------------------------------- /.github/workflows/docker-image-tag-version.yml: -------------------------------------------------------------------------------- 1 | 2 | name: Docker Image Build/Publish tag with version 3 | 4 | on: 5 | push: 6 | tags: 7 | - "v*" 8 | branches: 9 | - 'main' 10 | paths: 11 | - docker/build@pypi/dockerfile 12 | - .github/workflows/docker-image-tag-version.yml 13 | workflow_dispatch: 14 | inputs: 15 | olah_version: 16 | description: olah version of pypi 17 | required: true 18 | default: 0.3.3 19 | 20 | jobs: 21 | build-and-push-docker-image: 22 | name: Build Docker image and push to repositories 23 | runs-on: ubuntu-latest 24 | 25 | strategy: 26 | matrix: 27 | platform: 28 | - linux/amd64,linux/arm64 29 | 30 | steps: 31 | - name: Set OLAH_SOURCE/OLAH_TAG variable for push or workflow_dispatch 32 | id: set_olah_source 33 | run: | 34 | if [[ "${{ github.event_name }}" == "push" ]]; then 35 | if [[ "${{ github.ref }}" == refs/tags/* ]]; then 36 | TAG_NAME=${GITHUB_REF##*/} 37 | VERSION=${TAG_NAME#v} 38 | OLAH_SOURCE="https://github.com/${{ github.repository }}/releases/download/${TAG_NAME}/olah-${VERSION}-py3-none-any.whl" 39 | OLAH_TAG="${VERSION}" 40 | elif [[ "${{ github.ref }}" == refs/heads/main ]]; then 41 | OLAH_SOURCE="olah" 42 | OLAH_TAG="lastet" 43 | fi 44 | elif [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then 45 | OLAH_SOURCE="olah==${{ github.event.inputs.olah_version }}" 46 | OLAH_TAG="${{ github.event.inputs.olah_version }}" 47 | fi 48 | echo "OLAH_SOURCE=$OLAH_SOURCE" >> $GITHUB_ENV 49 | echo "OLAH_TAG=$OLAH_TAG" >> $GITHUB_ENV 50 | 51 | - name: Checkout code 52 | uses: actions/checkout@v4 53 | 54 | - name: Set up QEMU 55 | uses: docker/setup-qemu-action@v3 56 | 57 | - name: Set up Docker Buildx 58 | uses: docker/setup-buildx-action@v3 59 | 60 | - name: Login to Docker Hub 61 | uses: docker/login-action@v3 62 | with: 63 | username: ${{ secrets.DOCKERHUB_USERNAME }} 64 | password: ${{ secrets.DOCKERHUB_TOKEN }} 65 | 66 | - name: Login to GitHub Container Registry 67 | uses: docker/login-action@v3 68 | with: 69 | registry: ghcr.io 70 | username: ${{ github.repository_owner }} 71 | password: ${{ secrets.GITHUB_TOKEN }} 72 | 73 | - name: Meta data image 74 | id: meta 75 | uses: docker/metadata-action@v5 76 | with: 77 | images: | 78 | ${{ secrets.DOCKERHUB_USERNAME }}/olah 79 | ghcr.io/${{ github.repository_owner }}/olah 80 | tags: | 81 | type=raw,value=${{ env.OLAH_TAG }} 82 | type=raw,value=lastet 83 | flavor: | 84 | latest=false 85 | 86 | - name: Build push image 87 | id: build 88 | uses: docker/build-push-action@v5 89 | with: 90 | context: . 
91 | file: ./docker/build@pypi/dockerfile 92 | build-args: | 93 | OLAH_SOURCE=${{ env.OLAH_SOURCE }} 94 | platforms: ${{ matrix.platform }} 95 | push: true 96 | tags: ${{ steps.meta.outputs.tags }} 97 | labels: ${{ steps.meta.outputs.labels }} 98 | 99 | - name: Print image digest 100 | run: echo ${{ steps.build.outputs.digest }} 101 | -------------------------------------------------------------------------------- /src/olah/proxy/commits.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import os 9 | from typing import Dict, Literal, Mapping, Optional 10 | from urllib.parse import urljoin 11 | from fastapi import FastAPI, Request 12 | 13 | import httpx 14 | from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT 15 | 16 | from olah.utils.cache_utils import read_cache_request, write_cache_request 17 | from olah.utils.rule_utils import check_cache_rules_hf 18 | from olah.utils.repo_utils import get_org_repo 19 | from olah.utils.file_utils import make_dirs 20 | 21 | 22 | async def _commits_cache_generator(save_path: str): 23 | cache_rq = await read_cache_request(save_path) 24 | yield cache_rq["status_code"] 25 | yield cache_rq["headers"] 26 | yield cache_rq["content"] 27 | 28 | 29 | async def _commits_proxy_generator( 30 | app: FastAPI, 31 | headers: Dict[str, str], 32 | commits_url: str, 33 | method: str, 34 | params: Mapping[str, str], 35 | allow_cache: bool, 36 | save_path: str, 37 | ): 38 | async with httpx.AsyncClient(follow_redirects=True) as client: 39 | content_chunks = [] 40 | async with client.stream( 41 | method=method, 42 | url=commits_url, 43 | params=params, 44 | headers=headers, 45 | timeout=WORKER_API_TIMEOUT, 46 | ) as response: 47 | response_status_code = response.status_code 48 | response_headers = response.headers 49 | yield response_status_code 50 | yield response_headers 51 | 52 | async for raw_chunk in response.aiter_raw(): 53 | if not raw_chunk: 54 | continue 55 | content_chunks.append(raw_chunk) 56 | yield raw_chunk 57 | 58 | content = bytearray() 59 | for chunk in content_chunks: 60 | content += chunk 61 | 62 | if allow_cache and response_status_code == 200: 63 | make_dirs(save_path) 64 | await write_cache_request( 65 | save_path, response_status_code, response_headers, bytes(content) 66 | ) 67 | 68 | 69 | async def commits_generator( 70 | app: FastAPI, 71 | repo_type: Literal["models", "datasets", "spaces"], 72 | org: str, 73 | repo: str, 74 | commit: str, 75 | override_cache: bool, 76 | method: str, 77 | authorization: Optional[str], 78 | ): 79 | headers = {} 80 | if authorization is not None: 81 | headers["authorization"] = authorization 82 | 83 | org_repo = get_org_repo(org, repo) 84 | # save 85 | repos_path = app.state.app_settings.config.repos_path 86 | save_dir = os.path.join( 87 | repos_path, f"api/{repo_type}/{org_repo}/commits/{commit}" 88 | ) 89 | save_path = os.path.join(save_dir, f"commits_{method}.json") 90 | 91 | use_cache = os.path.exists(save_path) 92 | allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) 93 | 94 | org_repo = get_org_repo(org, repo) 95 | commits_url = urljoin( 96 | app.state.app_settings.config.hf_url_base(), 97 | f"/api/{repo_type}/{org_repo}/commits/{commit}", 98 | ) 99 | # proxy 100 | if use_cache and not override_cache: 101 | async for item in 
_commits_cache_generator(save_path): 102 | yield item 103 | else: 104 | async for item in _commits_proxy_generator( 105 | app, headers, commits_url, method, {}, allow_cache, save_path 106 | ): 107 | yield item 108 | -------------------------------------------------------------------------------- /src/olah/proxy/meta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import os 9 | import shutil 10 | import tempfile 11 | from typing import Dict, Literal, Optional, AsyncGenerator, Union 12 | from urllib.parse import urljoin 13 | from fastapi import FastAPI, Request 14 | 15 | import httpx 16 | from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT 17 | 18 | from olah.utils.cache_utils import read_cache_request, write_cache_request 19 | from olah.utils.rule_utils import check_cache_rules_hf 20 | from olah.utils.repo_utils import get_org_repo 21 | from olah.utils.file_utils import make_dirs 22 | 23 | async def _meta_cache_generator(save_path: str) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: 24 | cache_rq = await read_cache_request(save_path) 25 | yield cache_rq["headers"] 26 | yield cache_rq["content"] 27 | 28 | 29 | async def _meta_proxy_generator( 30 | app: FastAPI, 31 | headers: Dict[str, str], 32 | meta_url: str, 33 | method: str, 34 | allow_cache: bool, 35 | save_path: str, 36 | ) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: 37 | async with httpx.AsyncClient(follow_redirects=True) as client: 38 | content_chunks = [] 39 | async with client.stream( 40 | method=method, 41 | url=meta_url, 42 | headers=headers, 43 | timeout=WORKER_API_TIMEOUT, 44 | ) as response: 45 | response_status_code = response.status_code 46 | response_headers = response.headers 47 | yield response_headers 48 | 49 | async for raw_chunk in response.aiter_raw(): 50 | if not raw_chunk: 51 | continue 52 | content_chunks.append(raw_chunk) 53 | yield raw_chunk 54 | 55 | content = bytearray() 56 | for chunk in content_chunks: 57 | content += chunk 58 | 59 | if allow_cache and response_status_code == 200: 60 | await write_cache_request( 61 | save_path, response_status_code, response_headers, bytes(content) 62 | ) 63 | 64 | 65 | async def meta_generator( 66 | app: FastAPI, 67 | repo_type: Literal["models", "datasets", "spaces"], 68 | org: str, 69 | repo: str, 70 | commit: str, 71 | override_cache: bool, 72 | method: str, 73 | authorization: Optional[str], 74 | ) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: 75 | headers = {} 76 | if authorization is not None: 77 | headers["authorization"] = authorization 78 | 79 | org_repo = get_org_repo(org, repo) 80 | # save 81 | repos_path = app.state.app_settings.config.repos_path 82 | save_dir = os.path.join( 83 | repos_path, f"api/{repo_type}/{org_repo}/revision/{commit}" 84 | ) 85 | save_path = os.path.join(save_dir, f"meta_{method}.json") 86 | make_dirs(save_path) 87 | 88 | use_cache = os.path.exists(save_path) 89 | allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) 90 | 91 | org_repo = get_org_repo(org, repo) 92 | meta_url = urljoin( 93 | app.state.app_settings.config.hf_url_base(), 94 | f"/api/{repo_type}/{org_repo}/revision/{commit}", 95 | ) 96 | # proxy 97 | if use_cache and not override_cache: 98 | async for item in _meta_cache_generator(save_path): 99 | yield item 100 | 
else: 101 | async for item in _meta_proxy_generator( 102 | app, headers, meta_url, method, allow_cache, save_path 103 | ): 104 | yield item 105 | -------------------------------------------------------------------------------- /src/olah/utils/disk_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import datetime 9 | import os 10 | 11 | import time 12 | from typing import List, Optional, Tuple 13 | 14 | 15 | def get_folder_size(folder_path: str) -> int: 16 | total_size = 0 17 | for dirpath, dirnames, filenames in os.walk(folder_path): 18 | for f in filenames: 19 | fp = os.path.join(dirpath, f) 20 | total_size += os.path.getsize(fp) 21 | return total_size 22 | 23 | def sort_files_by_access_time(folder_path: str) -> List[Tuple[str, datetime.datetime]]: 24 | files = [] 25 | 26 | # Get all file paths and time 27 | for dirpath, dirnames, filenames in os.walk(folder_path): 28 | for f in filenames: 29 | file_path = os.path.join(dirpath, f) 30 | if not os.path.isfile(file_path): 31 | continue 32 | access_time = datetime.datetime.fromtimestamp(os.path.getatime(file_path)) 33 | files.append((file_path, access_time)) 34 | 35 | # Sort by accesstime 36 | sorted_files = sorted(files, key=lambda x: x[1]) 37 | 38 | return sorted_files 39 | 40 | def sort_files_by_modify_time(folder_path: str) -> List[Tuple[str, datetime.datetime]]: 41 | files = [] 42 | 43 | # Get all file paths and time 44 | for dirpath, dirnames, filenames in os.walk(folder_path): 45 | for f in filenames: 46 | file_path = os.path.join(dirpath, f) 47 | if not os.path.isfile(file_path): 48 | continue 49 | access_time = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)) 50 | files.append((file_path, access_time)) 51 | 52 | # Sort by modify time 53 | sorted_files = sorted(files, key=lambda x: x[1]) 54 | 55 | return sorted_files 56 | 57 | def sort_files_by_size(folder_path: str) -> List[Tuple[str, int]]: 58 | files = [] 59 | 60 | # Get all file paths and sizes 61 | for dirpath, dirnames, filenames in os.walk(folder_path): 62 | for f in filenames: 63 | file_path = os.path.join(dirpath, f) 64 | if not os.path.isfile(file_path): 65 | continue 66 | file_size = os.path.getsize(file_path) 67 | files.append((file_path, file_size)) 68 | 69 | # Sort by file size 70 | sorted_files = sorted(files, key=lambda x: x[1]) 71 | 72 | return sorted_files 73 | 74 | def touch_file_access_time(filename: str): 75 | if not os.path.exists(filename): 76 | return 77 | now = time.time() 78 | stat_info = os.stat(filename) 79 | atime = stat_info.st_atime 80 | mtime = stat_info.st_mtime 81 | 82 | os.utime(filename, times=(now, mtime)) 83 | 84 | def convert_to_bytes(size_str) -> Optional[int]: 85 | size_str = size_str.strip().upper() 86 | multipliers = { 87 | "K": 1024, 88 | "M": 1024**2, 89 | "G": 1024**3, 90 | "T": 1024**4, 91 | "KB": 1024, 92 | "MB": 1024**2, 93 | "GB": 1024**3, 94 | "TB": 1024**4, 95 | } 96 | 97 | for unit in multipliers: 98 | if size_str.endswith(unit): 99 | size = int(size_str[: -len(unit)]) 100 | return size * multipliers[unit] 101 | 102 | # Default use bytes 103 | try: 104 | return int(size_str) 105 | except ValueError: 106 | return None 107 | 108 | 109 | def convert_bytes_to_human_readable(bytes: int) -> str: 110 | suffixes = ["B", "KB", "MB", "GB", "TB"] 111 | index = 0 112 | while 
bytes >= 1024 and index < len(suffixes) - 1: 113 | bytes /= 1024 114 | index += 1 115 | return f"{bytes:.2f} {suffixes[index]}" 116 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | /mirrors_dir/ 163 | /model_dir/ 164 | /dataset_dir/ 165 | /repos/ 166 | /logs/ 167 | /cache/ 168 | 169 | *.key 170 | *.csr 171 | *.pem 172 | *.crt 173 | 174 | -------------------------------------------------------------------------------- /src/olah/proxy/tree.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import os 9 | from typing import Dict, Literal, Mapping, Optional, AsyncGenerator, Union 10 | from urllib.parse import urljoin 11 | from fastapi import FastAPI, Request 12 | 13 | import httpx 14 | from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT 15 | 16 | from olah.utils.cache_utils import read_cache_request, write_cache_request 17 | from olah.utils.rule_utils import check_cache_rules_hf 18 | from olah.utils.repo_utils import get_org_repo 19 | from olah.utils.file_utils import make_dirs 20 | 21 | 22 | async def _tree_cache_generator(save_path: str) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: 23 | cache_rq = await read_cache_request(save_path) 24 | yield cache_rq["status_code"] 25 | yield cache_rq["headers"] 26 | yield cache_rq["content"] 27 | 28 | async def _tree_proxy_generator( 29 | app: FastAPI, 30 | headers: Dict[str, str], 31 | tree_url: str, 32 | method: str, 33 | params: Mapping[str, str], 34 | allow_cache: bool, 35 | save_path: str, 36 | ) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: 37 | async with httpx.AsyncClient(follow_redirects=True) as client: 38 | content_chunks = [] 39 | async with client.stream( 40 | method=method, 41 | url=tree_url, 42 | params=params, 43 | headers=headers, 44 | timeout=WORKER_API_TIMEOUT, 45 | ) as response: 46 | response_status_code = response.status_code 47 | response_headers = response.headers 48 | yield response_status_code 49 | yield response_headers 50 | 51 | async for raw_chunk in response.aiter_raw(): 52 | if not raw_chunk: 53 | continue 54 | content_chunks.append(raw_chunk) 55 | yield raw_chunk 56 | 57 | content = bytearray() 58 | for chunk in content_chunks: 59 | content += chunk 60 | 61 | if allow_cache and response_status_code == 200: 62 | make_dirs(save_path) 63 | await write_cache_request( 64 | save_path, 
response_status_code, response_headers, bytes(content) 65 | ) 66 | 67 | 68 | async def tree_generator( 69 | app: FastAPI, 70 | repo_type: Literal["models", "datasets", "spaces"], 71 | org: str, 72 | repo: str, 73 | commit: str, 74 | path: str, 75 | recursive: bool, 76 | expand: bool, 77 | override_cache: bool, 78 | method: str, 79 | authorization: Optional[str], 80 | ) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: 81 | headers = {} 82 | if authorization is not None: 83 | headers["authorization"] = authorization 84 | 85 | org_repo = get_org_repo(org, repo) 86 | # save 87 | repos_path = app.state.app_settings.config.repos_path 88 | save_dir = os.path.join( 89 | repos_path, f"api/{repo_type}/{org_repo}/tree/{commit}/{path}" 90 | ) 91 | save_path = os.path.join(save_dir, f"tree_{method}_recursive_{recursive}_expand_{expand}.json") 92 | 93 | use_cache = os.path.exists(save_path) 94 | allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) 95 | 96 | org_repo = get_org_repo(org, repo) 97 | tree_url = urljoin( 98 | app.state.app_settings.config.hf_url_base(), 99 | f"/api/{repo_type}/{org_repo}/tree/{commit}/{path}", 100 | ) 101 | # proxy 102 | if use_cache and not override_cache: 103 | async for item in _tree_cache_generator(save_path): 104 | yield item 105 | else: 106 | async for item in _tree_proxy_generator( 107 | app, headers, tree_url, method, {"recursive": recursive, "expand": expand}, allow_cache, save_path 108 | ): 109 | yield item 110 | -------------------------------------------------------------------------------- /src/olah/proxy/pathsinfo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
7 | 8 | import json 9 | import os 10 | from typing import AsyncGenerator, Dict, List, Literal, Optional, Tuple, Union 11 | from urllib.parse import quote, urljoin 12 | from fastapi import FastAPI, Request 13 | 14 | import httpx 15 | from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT 16 | 17 | from olah.utils.cache_utils import read_cache_request, write_cache_request 18 | from olah.utils.rule_utils import check_cache_rules_hf 19 | from olah.utils.repo_utils import get_org_repo 20 | from olah.utils.file_utils import make_dirs 21 | 22 | 23 | async def _pathsinfo_cache(save_path: str) -> Tuple[int, Dict[str, str], bytes]: 24 | cache_rq = await read_cache_request(save_path) 25 | return cache_rq["status_code"], cache_rq["headers"], cache_rq["content"] 26 | 27 | 28 | async def _pathsinfo_proxy( 29 | app: FastAPI, 30 | headers: Dict[str, str], 31 | pathsinfo_url: str, 32 | method: str, 33 | path: str, 34 | allow_cache: bool, 35 | save_path: str, 36 | ) -> Tuple[int, Dict[str, str], bytes]: 37 | headers = {k: v for k, v in headers.items()} 38 | if "content-length" in headers: 39 | headers.pop("content-length") 40 | async with httpx.AsyncClient(follow_redirects=True) as client: 41 | response = await client.request( 42 | method=method, 43 | url=pathsinfo_url, 44 | headers=headers, 45 | data={"paths": path}, 46 | timeout=WORKER_API_TIMEOUT, 47 | ) 48 | 49 | if allow_cache and response.status_code == 200: 50 | make_dirs(save_path) 51 | await write_cache_request( 52 | save_path, 53 | response.status_code, 54 | response.headers, 55 | bytes(response.content), 56 | ) 57 | return response.status_code, response.headers, response.content 58 | 59 | 60 | async def pathsinfo_generator( 61 | app: FastAPI, 62 | repo_type: Literal["models", "datasets", "spaces"], 63 | org: str, 64 | repo: str, 65 | commit: str, 66 | paths: List[str], 67 | override_cache: bool, 68 | method: str, 69 | authorization: Optional[str], 70 | ) -> AsyncGenerator[Union[int, Dict[str, str], bytes], None]: 71 | headers = {} 72 | if authorization is not None: 73 | headers["authorization"] = authorization 74 | 75 | org_repo = get_org_repo(org, repo) 76 | # save 77 | repos_path = app.state.app_settings.config.repos_path 78 | 79 | final_content = [] 80 | for path in paths: 81 | save_dir = os.path.join( 82 | repos_path, f"api/{repo_type}/{org_repo}/paths-info/{commit}/{path}" 83 | ) 84 | 85 | save_path = os.path.join(save_dir, f"paths-info_{method}.json") 86 | 87 | use_cache = os.path.exists(save_path) 88 | allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) 89 | 90 | org_repo = get_org_repo(org, repo) 91 | pathsinfo_url = urljoin( 92 | app.state.app_settings.config.hf_url_base(), 93 | f"/api/{repo_type}/{org_repo}/paths-info/{commit}", 94 | ) 95 | # proxy 96 | if use_cache and not override_cache: 97 | status, headers, content = await _pathsinfo_cache(save_path) 98 | else: 99 | status, headers, content = await _pathsinfo_proxy( 100 | app, headers, pathsinfo_url, method, path, allow_cache, save_path 101 | ) 102 | 103 | try: 104 | content_json = json.loads(content) 105 | except json.JSONDecodeError: 106 | continue 107 | if status == 200 and isinstance(content_json, list): 108 | final_content.extend(content_json) 109 | 110 | yield 200 111 | yield {'content-type': 'application/json'} 112 | yield json.dumps(final_content, ensure_ascii=True) 113 | -------------------------------------------------------------------------------- /src/olah/static/index.html: 
--------------------------------------------------------------------------------
Olah HuggingFace Mirror

Welcome to Olah!

Use Mirror URL with huggingface-cli

Set the Environment Variable HF_ENDPOINT to the mirror site (Here is {{scheme}}://{{netloc}}).

Linux:
export HF_ENDPOINT={{scheme}}://{{netloc}}

Windows Powershell:
$env:HF_ENDPOINT = "{{scheme}}://{{netloc}}"

Starting from now on, all download operations in the HuggingFace library will be proxied through this mirror site.

from huggingface_hub import snapshot_download

snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model',
    local_dir='./model_dir', resume_download=True,
    max_workers=8)

Or you can download models and datasets by using huggingface cli.

pip install -U huggingface_hub

Download GPT2:
huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2

Download WikiText:
huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 |
Olah
2 | 3 | 4 |
5 | 自托管的轻量级HuggingFace镜像服务 6 | 7 | Olah是开源的自托管轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。 8 | Olah真正地实现了huggingface资源的`镜像`功能,而不仅仅是一个简单的`反向代理`。 9 | Olah并不会立刻对huggingface全站进行镜像,而是在用户下载的同时在文件块级别对资源进行镜像(或者我们可以说是缓存)。 10 | 11 | ## Olah的优势 12 | Olah能够在用户下载的同时分块缓存文件。当第二次下载时,直接从缓存中读取,极大地提升了下载速度并节约了流量。 13 | 同时Olah提供了丰富的缓存控制策略,管理员可以通过配置文件设置哪些仓库可以访问,哪些仓库可以缓存。 14 | 15 | ## 特性 16 | * 数据缓存,减少下载流量 17 | * 模型镜像 18 | * 数据集镜像 19 | * 空间镜像 20 | 21 | ## 安装 22 | 23 | ### 方法1:使用pip 24 | 25 | ```bash 26 | pip install olah 27 | ``` 28 | 29 | 或者: 30 | 31 | ```bash 32 | pip install git+https://github.com/vtuber-plan/olah.git 33 | ``` 34 | 35 | ### 方法2:从源代码安装 36 | 37 | 1. 克隆这个仓库 38 | ```bash 39 | git clone https://github.com/vtuber-plan/olah.git 40 | cd olah 41 | ``` 42 | 43 | 2. 安装包 44 | ```bash 45 | pip install --upgrade pip 46 | pip install -e . 47 | ``` 48 | 49 | ## 快速开始 50 | 在控制台运行以下命令: 51 | ```bash 52 | python -m olah.server 53 | ``` 54 | 55 | 然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090/)。 56 | 57 | Linux: 58 | ```bash 59 | export HF_ENDPOINT=http://localhost:8090 60 | ``` 61 | 62 | Windows Powershell: 63 | ```bash 64 | $env:HF_ENDPOINT = "http://localhost:8090" 65 | ``` 66 | 67 | 从现在开始,HuggingFace库中的所有下载操作都将通过此镜像站点代理进行。 68 | ```bash 69 | pip install -U huggingface_hub 70 | ``` 71 | 72 | ```python 73 | from huggingface_hub import snapshot_download 74 | 75 | snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model', 76 | local_dir='./model_dir', resume_download=True, 77 | max_workers=8) 78 | 79 | ``` 80 | 81 | 或者你也可以使用huggingface cli直接下载模型和数据集. 82 | 83 | 下载GPT2: 84 | ```bash 85 | huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2 86 | ``` 87 | 88 | 下载WikiText: 89 | ```bash 90 | huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext 91 | ``` 92 | 93 | 您可以查看路径`./repos`,其中存储了所有数据集和模型的缓存。 94 | 95 | ## 启动服务器 96 | 在控制台运行以下命令: 97 | ```bash 98 | olah-cli 99 | ``` 100 | 101 | 或者您可以指定主机地址和监听端口: 102 | ```bash 103 | olah-cli --host localhost --port 8090 104 | ``` 105 | **注意:请记得在修改主机和端口时将`--mirror-netloc`和`--mirror-lfs-netloc`更改为镜像站点的实际URL。** 106 | 107 | ```bash 108 | olah-cli --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090 109 | ``` 110 | 111 | 默认的镜像缓存路径是`./repos`,您可以通过`--repos-path`参数进行更改: 112 | ```bash 113 | olah-cli --host localhost --port 8090 --repos-path ./hf_mirrors 114 | ``` 115 | 116 | **注意,不同版本之间的缓存数据不能迁移,请删除缓存文件夹后再进行olah的升级** 117 | 118 | 在实际部署中可能出现下载并发量很大,导致新的连接出现Timeout错误。 119 | 可以设置uvicorn的WEB_CONCURRENCY变量以增加worker数量以提升产品场景的并发量。 120 | 例如Linux下: 121 | ```bash 122 | export WEB_CONCURRENCY=4 123 | ``` 124 | 125 | ## 更多配置 126 | 127 | 更多配置可以通过配置文件进行控制,通过命令参数传入`configs.toml`以设置配置文件路径: 128 | ```bash 129 | olah-cli -c configs.toml 130 | ``` 131 | 132 | 完整的配置文件内容见[assets/full_configs.toml](https://github.com/vtuber-plan/olah/blob/main/assets/full_configs.toml) 133 | 134 | ### 配置详解 135 | 第一部分basic字段用于对镜像站进行基本设置 136 | ```toml 137 | [basic] 138 | host = "localhost" 139 | port = 8090 140 | ssl-key = "" 141 | ssl-cert = "" 142 | repos-path = "./repos" 143 | cache-size-limit = "" 144 | cache-clean-strategy = "LRU" 145 | hf-scheme = "https" 146 | hf-netloc = "huggingface.co" 147 | hf-lfs-netloc = "cdn-lfs.huggingface.co" 148 | mirror-scheme = "http" 149 | mirror-netloc = "localhost:8090" 150 | mirror-lfs-netloc = "localhost:8090" 151 | mirrors-path = ["./mirrors_dir"] 152 | ``` 153 | 154 | - host: 设置olah监听的host地址 155 | - port: 设置olah监听的端口 156 | - ssl-key和ssl-cert: 当需要开启HTTPS时传入key和cert的文件路径 157 | - repos-path: 
用于保存缓存数据的目录 158 | - cache-size-limit: 指定缓存大小限制(例如,100G,500GB,2TB)。Olah会每小时扫描缓存文件夹的大小。如果超出限制,Olah会删除一些缓存文件 159 | - cache-clean-strategy: 指定缓存清理策略(可用策略:LRU,FIFO,LARGE_FIRST) 160 | - hf-scheme: huggingface官方站点的网络协议(一般不需要改动) 161 | - hf-netloc: huggingface官方站点的网络位置(一般不需要改动) 162 | - hf-lfs-netloc: huggingface官方站点LFS文件的网络位置(一般不需要改动) 163 | - mirror-scheme: Olah镜像站的网络协议(应当和上面的设置一致,当提供ssl-key和ssl-cert时,应改为https) 164 | - mirror-netloc: Olah镜像站的网络位置(应与host和port设置一致) 165 | - mirror-lfs-netloc: Olah镜像站LFS的网络位置(应与host和port设置一致) 166 | - mirrors-path: 额外的镜像文件目录。当你已经clone了一些git仓库时可以放入该目录下以供下载。此处例子目录为`./mirrors_dir`, 若要添加数据集`Salesforce/wikitext`,可将git仓库放置于`./mirrors_dir/datasets/Salesforce/wikitext`目录。同理,模型放置于`./mirrors_dir/models/organization/repository`下。 167 | 168 | 169 | 第二部分可以对可访问性进行限制 170 | ```toml 171 | 172 | [accessibility] 173 | offline = false 174 | 175 | [[accessibility.proxy]] 176 | repo = "cais/mmlu" 177 | allow = true 178 | 179 | [[accessibility.proxy]] 180 | repo = "adept/fuyu-8b" 181 | allow = false 182 | 183 | [[accessibility.proxy]] 184 | repo = "mistralai/*" 185 | allow = true 186 | 187 | [[accessibility.proxy]] 188 | repo = "mistralai/Mistral.*" 189 | allow = false 190 | use_re = true 191 | 192 | [[accessibility.cache]] 193 | repo = "cais/mmlu" 194 | allow = true 195 | 196 | [[accessibility.cache]] 197 | repo = "adept/fuyu-8b" 198 | allow = false 199 | ``` 200 | - offline: 设置Olah镜像站是否进入离线模式,不再向huggingface官方站点发出请求以进行数据更新,但已经缓存的仓库仍可以下载 201 | - proxy: 用于设置该仓库是否可以被代理,默认全部允许,`repo`用于匹配仓库名字; 可使用正则表达式和通配符两种模式,`use_re`用于控制是否使用正则表达式,默认使用通配符; `allow`控制该规则的属性是允许代理还是不允许代理。 202 | - cache: 用于设置该仓库是否会被缓存,默认全部允许,`repo`用于匹配仓库名字; 可使用正则表达式和通配符两种模式,`use_re`用于控制是否使用正则表达式,默认使用通配符; `allow`控制该规则的属性是允许代理还是不允许缓存。 203 | 204 | ## 许可证 205 | 206 | olah采用MIT许可证发布。 207 | 208 | ## 另请参阅 209 | 210 | - [olah-docs](https://github.com/vtuber-plan/olah/tree/main/docs) 211 | - [olah-source](https://github.com/vtuber-plan/olah) 212 | 213 | ## Star历史 214 | 215 | [![Star历史图表]()](https://star-history.com/#vtuber-plan/olah&Date) -------------------------------------------------------------------------------- /src/olah/utils/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
7 | 8 | from asyncio import AbstractEventLoop 9 | import json 10 | import logging 11 | import logging.handlers 12 | import os 13 | import platform 14 | import re 15 | import sys 16 | from typing import AsyncGenerator, Generator 17 | import warnings 18 | from olah.constants import DEFAULT_LOGGER_DIR 19 | 20 | handler = None 21 | 22 | 23 | # Define a custom formatter without color codes 24 | class NoColorFormatter(logging.Formatter): 25 | color_pattern = re.compile(r"\x1b[^m]*m") # Regex pattern to match color codes 26 | 27 | def format(self, record): 28 | message = super().format(record) 29 | # Remove color codes from the log message 30 | message = self.color_pattern.sub("", message) 31 | return message 32 | 33 | 34 | def build_logger( 35 | logger_name, logger_filename, logger_dir=DEFAULT_LOGGER_DIR 36 | ) -> logging.Logger: 37 | global handler 38 | 39 | formatter = logging.Formatter( 40 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 41 | datefmt="%Y-%m-%d %H:%M:%S", 42 | ) 43 | 44 | nocolor_formatter = NoColorFormatter( 45 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 46 | datefmt="%Y-%m-%d %H:%M:%S", 47 | ) 48 | 49 | # Set the format of root handlers 50 | if logging.getLogger().handlers is None or len(logging.getLogger().handlers) == 0: 51 | if sys.version_info[1] >= 9: 52 | # This is for windows 53 | logging.basicConfig(level=logging.INFO, encoding="utf-8") 54 | else: 55 | if platform.system() == "Windows": 56 | warnings.warn( 57 | "If you are running on Windows, " 58 | "we recommend you use Python >= 3.9 for UTF-8 encoding." 59 | ) 60 | logging.basicConfig(level=logging.INFO) 61 | logging.getLogger().handlers[0].setFormatter(formatter) 62 | 63 | # Redirect stdout and stderr to loggers 64 | stdout_logger = logging.getLogger("stdout") 65 | stdout_logger.setLevel(logging.DEBUG) 66 | sl = StreamToLogger(stdout_logger, logging.INFO) 67 | sys.stdout = sl 68 | 69 | stderr_logger = logging.getLogger("stderr") 70 | stderr_logger.setLevel(logging.ERROR) 71 | sl = StreamToLogger(stderr_logger, logging.ERROR) 72 | sys.stderr = sl 73 | 74 | # Get logger 75 | logger = logging.getLogger(logger_name) 76 | logger.setLevel(logging.DEBUG) 77 | 78 | # Add a file handler for all loggers 79 | if handler is None: 80 | os.makedirs(logger_dir, exist_ok=True) 81 | filename = os.path.join(logger_dir, logger_filename) 82 | handler = logging.handlers.TimedRotatingFileHandler( 83 | filename, when="H", utc=True, encoding="utf-8" 84 | ) 85 | handler.setFormatter(nocolor_formatter) 86 | handler.namer = lambda name: name.replace(".log", "") + ".log" 87 | 88 | for name, item in logging.root.manager.loggerDict.items(): 89 | if isinstance(item, logging.Logger): 90 | item.addHandler(handler) 91 | 92 | return logger 93 | 94 | 95 | class StreamToLogger(object): 96 | """ 97 | Fake file-like stream object that redirects writes to a logger instance. 98 | """ 99 | 100 | def __init__(self, logger, log_level=logging.INFO): 101 | self.terminal = sys.stdout 102 | self.logger = logger 103 | self.log_level = log_level 104 | self.linebuf = "" 105 | 106 | def __getattr__(self, attr): 107 | try: 108 | attr_value = getattr(self.terminal, attr) 109 | except: 110 | return None 111 | return attr_value 112 | 113 | def write(self, buf): 114 | temp_linebuf = self.linebuf + buf 115 | self.linebuf = "" 116 | for line in temp_linebuf.splitlines(True): 117 | # From the io.TextIOWrapper docs: 118 | # On output, if newline is None, any '\n' characters written 119 | # are translated to the system default line separator. 
120 | # By default sys.stdout.write() expects '\n' newlines and then 121 | # translates them so this is still cross platform. 122 | if line[-1] == "\n": 123 | encoded_message = line.encode("utf-8", "ignore").decode("utf-8") 124 | self.logger.log(self.log_level, encoded_message.rstrip()) 125 | else: 126 | self.linebuf += line 127 | 128 | def flush(self): 129 | if self.linebuf != "": 130 | encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") 131 | self.logger.log(self.log_level, encoded_message.rstrip()) 132 | self.linebuf = "" 133 | 134 | 135 | def iter_over_async( 136 | async_gen: AsyncGenerator, event_loop: AbstractEventLoop 137 | ) -> Generator: 138 | """ 139 | Convert async generator to sync generator 140 | 141 | :param async_gen: the AsyncGenerator to convert 142 | :param event_loop: the event loop to run on 143 | :returns: Sync generator 144 | """ 145 | ait = async_gen.__aiter__() 146 | 147 | async def get_next(): 148 | try: 149 | obj = await ait.__anext__() 150 | return False, obj 151 | except StopAsyncIteration: 152 | return True, None 153 | 154 | while True: 155 | done, obj = event_loop.run_until_complete(get_next()) 156 | if done: 157 | break 158 | yield obj 159 | -------------------------------------------------------------------------------- /src/olah/configs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | from typing import Any, Dict, List, Literal, Optional, Union 9 | import toml 10 | import re 11 | import fnmatch 12 | 13 | from olah.utils.disk_utils import convert_to_bytes 14 | 15 | DEFAULT_PROXY_RULES = [ 16 | {"repo": "*", "allow": True, "use_re": False}, 17 | {"repo": "*/*", "allow": True, "use_re": False}, 18 | ] 19 | 20 | DEFAULT_CACHE_RULES = [ 21 | {"repo": "*", "allow": True, "use_re": False}, 22 | {"repo": "*/*", "allow": True, "use_re": False}, 23 | ] 24 | 25 | 26 | class OlahRule(object): 27 | def __init__(self, repo: str = "", type: str = "*", allow: bool = False, use_re: bool = False) -> None: 28 | self.repo = repo 29 | self.type = type 30 | self.allow = allow 31 | self.use_re = use_re 32 | 33 | @staticmethod 34 | def from_dict(data: Dict[str, Any]) -> "OlahRule": 35 | out = OlahRule() 36 | if "repo" in data: 37 | out.repo = data["repo"] 38 | if "allow" in data: 39 | out.allow = data["allow"] 40 | if "use_re" in data: 41 | out.use_re = data["use_re"] 42 | return out 43 | 44 | def match(self, repo_name: str) -> bool: 45 | if self.use_re: 46 | return self.match_re(repo_name) 47 | else: 48 | return self.match_fn(repo_name) 49 | 50 | def match_fn(self, repo_name: str) -> bool: 51 | return fnmatch.fnmatch(repo_name, self.repo) 52 | 53 | def match_re(self, repo_name: str) -> bool: 54 | return re.match(self.repo, repo_name) is not None 55 | 56 | 57 | class OlahRuleList(object): 58 | def __init__(self) -> None: 59 | self.rules: List[OlahRule] = [] 60 | 61 | @staticmethod 62 | def from_list(data: List[Dict[str, Any]]) -> "OlahRuleList": 63 | out = OlahRuleList() 64 | for item in data: 65 | out.rules.append(OlahRule.from_dict(item)) 66 | return out 67 | 68 | def clear(self): 69 | self.rules.clear() 70 | 71 | def allow(self, repo_name: str) -> bool: 72 | allow = False 73 | for rule in self.rules: 74 | if rule.match(repo_name): 75 | allow = rule.allow 76 | return allow 77 | 78 | 79 | class 
OlahConfig(object): 80 | def __init__(self, path: Optional[str] = None) -> None: 81 | 82 | # basic 83 | self.host: Union[List[str], str] = "localhost" 84 | self.port = 8090 85 | self.ssl_key = None 86 | self.ssl_cert = None 87 | self.repos_path = "./repos" 88 | self.cache_size_limit: Optional[int] = None 89 | self.cache_clean_strategy: Literal["LRU", "FIFO", "LARGE_FIRST"] = "LRU" 90 | 91 | self.hf_scheme: str = "https" 92 | self.hf_netloc: str = "huggingface.co" 93 | self.hf_lfs_netloc: str = "cdn-lfs.huggingface.co" 94 | 95 | self.mirror_scheme: str = "http" if self.ssl_key is None else "https" 96 | self.mirror_netloc: str = ( 97 | f"{self.host if self._is_specific_addr(self.host) else 'localhost'}:{self.port}" 98 | ) 99 | self.mirror_lfs_netloc: str = ( 100 | f"{self.host if self._is_specific_addr(self.host) else 'localhost'}:{self.port}" 101 | ) 102 | 103 | self.mirrors_path: List[str] = [] 104 | 105 | # accessibility 106 | self.offline = False 107 | self.proxy = OlahRuleList.from_list(DEFAULT_PROXY_RULES) 108 | self.cache = OlahRuleList.from_list(DEFAULT_CACHE_RULES) 109 | 110 | if path is not None: 111 | self.read_toml(path) 112 | 113 | def _is_specific_addr(self, host: Union[List[str], str]) -> bool: 114 | if isinstance(host, str): 115 | return host not in ['0.0.0.0', '::'] 116 | else: 117 | return False 118 | 119 | def hf_url_base(self) -> str: 120 | return f"{self.hf_scheme}://{self.hf_netloc}" 121 | 122 | def hf_lfs_url_base(self) -> str: 123 | return f"{self.hf_scheme}://{self.hf_lfs_netloc}" 124 | 125 | def mirror_url_base(self) -> str: 126 | return f"{self.mirror_scheme}://{self.mirror_netloc}" 127 | 128 | def mirror_lfs_url_base(self) -> str: 129 | return f"{self.mirror_scheme}://{self.mirror_lfs_netloc}" 130 | 131 | def empty_str(self, s: str) -> Optional[str]: 132 | if s == "": 133 | return None 134 | else: 135 | return s 136 | 137 | def read_toml(self, path: str) -> None: 138 | config = toml.load(path) 139 | 140 | if "basic" in config: 141 | basic = config["basic"] 142 | self.host = basic.get("host", self.host) 143 | self.port = basic.get("port", self.port) 144 | self.ssl_key = self.empty_str(basic.get("ssl-key", self.ssl_key)) 145 | self.ssl_cert = self.empty_str(basic.get("ssl-cert", self.ssl_cert)) 146 | self.repos_path = basic.get("repos-path", self.repos_path) 147 | self.cache_size_limit = convert_to_bytes(basic.get("cache-size-limit", self.cache_size_limit)) 148 | self.cache_clean_strategy = basic.get("cache-clean-strategy", self.cache_clean_strategy) 149 | 150 | self.hf_scheme = basic.get("hf-scheme", self.hf_scheme) 151 | self.hf_netloc = basic.get("hf-netloc", self.hf_netloc) 152 | self.hf_lfs_netloc = basic.get("hf-lfs-netloc", self.hf_lfs_netloc) 153 | 154 | self.mirror_scheme = basic.get("mirror-scheme", self.mirror_scheme) 155 | self.mirror_netloc = basic.get("mirror-netloc", self.mirror_netloc) 156 | self.mirror_lfs_netloc = basic.get( 157 | "mirror-lfs-netloc", self.mirror_lfs_netloc 158 | ) 159 | 160 | self.mirrors_path = basic.get("mirrors-path", self.mirrors_path) 161 | 162 | if "accessibility" in config: 163 | accessibility = config["accessibility"] 164 | self.offline = accessibility.get("offline", self.offline) 165 | self.proxy = OlahRuleList.from_list(accessibility.get("proxy", DEFAULT_PROXY_RULES)) 166 | self.cache = OlahRuleList.from_list(accessibility.get("cache", DEFAULT_CACHE_RULES)) 167 | -------------------------------------------------------------------------------- /src/olah/utils/zip_utils.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Handlers for Content-Encoding. 3 | 4 | See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding 5 | """ 6 | 7 | import codecs 8 | import io 9 | import typing 10 | from typing import List, Optional, Union 11 | import zlib 12 | import brotli 13 | import httpx 14 | 15 | 16 | class DecodingError(httpx.RequestError): 17 | """ 18 | Decoding of the response failed, due to a malformed encoding. 19 | """ 20 | 21 | 22 | class ContentDecoder: 23 | def decode(self, data: bytes) -> bytes: 24 | raise NotImplementedError() # pragma: no cover 25 | 26 | def flush(self) -> bytes: 27 | raise NotImplementedError() # pragma: no cover 28 | 29 | 30 | class IdentityDecoder(ContentDecoder): 31 | """ 32 | Handle unencoded data. 33 | """ 34 | 35 | def decode(self, data: bytes) -> bytes: 36 | return data 37 | 38 | def flush(self) -> bytes: 39 | return b"" 40 | 41 | 42 | class DeflateDecoder(ContentDecoder): 43 | """ 44 | Handle 'deflate' decoding. 45 | 46 | See: https://stackoverflow.com/questions/1838699 47 | """ 48 | 49 | def __init__(self) -> None: 50 | self.first_attempt = True 51 | self.decompressor = zlib.decompressobj() 52 | 53 | def decode(self, data: bytes) -> bytes: 54 | was_first_attempt = self.first_attempt 55 | self.first_attempt = False 56 | try: 57 | return self.decompressor.decompress(data) 58 | except zlib.error as exc: 59 | if was_first_attempt: 60 | self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS) 61 | return self.decode(data) 62 | raise DecodingError(str(exc)) from exc 63 | 64 | def flush(self) -> bytes: 65 | try: 66 | return self.decompressor.flush() 67 | except zlib.error as exc: # pragma: no cover 68 | raise DecodingError(str(exc)) from exc 69 | 70 | 71 | class GZipDecoder(ContentDecoder): 72 | """ 73 | Handle 'gzip' decoding. 74 | 75 | See: https://stackoverflow.com/questions/1838699 76 | """ 77 | 78 | def __init__(self) -> None: 79 | self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16) 80 | 81 | def decode(self, data: bytes) -> bytes: 82 | try: 83 | return self.decompressor.decompress(data) 84 | except zlib.error as exc: 85 | raise DecodingError(str(exc)) from exc 86 | 87 | def flush(self) -> bytes: 88 | try: 89 | return self.decompressor.flush() 90 | except zlib.error as exc: # pragma: no cover 91 | raise DecodingError(str(exc)) from exc 92 | 93 | 94 | class BrotliDecoder(ContentDecoder): 95 | """ 96 | Handle 'brotli' decoding. 97 | 98 | Requires `pip install brotlipy`. See: https://brotlipy.readthedocs.io/ 99 | or `pip install brotli`. See https://github.com/google/brotli 100 | Supports both 'brotlipy' and 'Brotli' packages since they share an import 101 | name. The top branches are for 'brotlipy' and bottom branches for 'Brotli' 102 | """ 103 | 104 | def __init__(self) -> None: 105 | if brotli is None: # pragma: no cover 106 | raise ImportError( 107 | "Using 'BrotliDecoder', but neither of the 'brotlicffi' or 'brotli' " 108 | "packages have been installed. " 109 | "Make sure to install httpx using `pip install httpx[brotli]`." 110 | ) from None 111 | 112 | self.decompressor = brotli.Decompressor() 113 | self.seen_data = False 114 | self._decompress: typing.Callable[[bytes], bytes] 115 | if hasattr(self.decompressor, "decompress"): 116 | # The 'brotlicffi' package. 117 | self._decompress = self.decompressor.decompress # pragma: no cover 118 | else: 119 | # The 'brotli' package. 
120 | self._decompress = self.decompressor.process # pragma: no cover 121 | 122 | def decode(self, data: bytes) -> bytes: 123 | if not data: 124 | return b"" 125 | self.seen_data = True 126 | try: 127 | return self._decompress(data) 128 | except brotli.error as exc: 129 | raise DecodingError(str(exc)) from exc 130 | 131 | def flush(self) -> bytes: 132 | if not self.seen_data: 133 | return b"" 134 | try: 135 | if hasattr(self.decompressor, "finish"): 136 | # Only available in the 'brotlicffi' package. 137 | 138 | # As the decompressor decompresses eagerly, this 139 | # will never actually emit any data. However, it will potentially throw 140 | # errors if a truncated or damaged data stream has been used. 141 | self.decompressor.finish() # pragma: no cover 142 | return b"" 143 | except brotli.error as exc: # pragma: no cover 144 | raise DecodingError(str(exc)) from exc 145 | 146 | 147 | class MultiDecoder(ContentDecoder): 148 | """ 149 | Handle the case where multiple encodings have been applied. 150 | """ 151 | 152 | def __init__(self, children: typing.Sequence[ContentDecoder]) -> None: 153 | """ 154 | 'children' should be a sequence of decoders in the order in which 155 | each was applied. 156 | """ 157 | # Note that we reverse the order for decoding. 158 | self.children = list(reversed(children)) 159 | 160 | def decode(self, data: bytes) -> bytes: 161 | for child in self.children: 162 | data = child.decode(data) 163 | return data 164 | 165 | def flush(self) -> bytes: 166 | data = b"" 167 | for child in self.children: 168 | data = child.decode(data) + child.flush() 169 | return data 170 | 171 | 172 | SUPPORTED_DECODERS = { 173 | "identity": IdentityDecoder, 174 | "gzip": GZipDecoder, 175 | "deflate": DeflateDecoder, 176 | "br": BrotliDecoder, 177 | } 178 | 179 | 180 | class Decompressor(object): 181 | def __init__(self, algorithms: Union[str, List[str]]) -> None: 182 | if isinstance(algorithms, str): 183 | self.algorithms = [algorithms] 184 | else: 185 | self.algorithms = algorithms 186 | 187 | self.decoders = [] 188 | for algo in self.algorithms: 189 | algo = algo.strip().lower() 190 | if algo in SUPPORTED_DECODERS: 191 | self.decoders.append(SUPPORTED_DECODERS[algo]()) 192 | else: 193 | print(f"Unsupported compression algorithm: {algo}") 194 | 195 | self.decoder = MultiDecoder(self.decoders) 196 | 197 | def decompress(self, raw_chunk: bytes) -> bytes: 198 | return self.decoder.decode(raw_chunk) 199 | 200 | 201 | def decompress_data(raw_data: bytes, content_encoding: Optional[str]) -> bytes: 202 | # If result is compressed 203 | if content_encoding is not None: 204 | final_data = raw_data 205 | algorithms = content_encoding.split(",") 206 | for algo in algorithms: 207 | algo = algo.strip().lower() 208 | if algo == "gzip": 209 | try: 210 | final_data = zlib.decompress( 211 | raw_data, zlib.MAX_WBITS | 16 212 | ) # 解压缩 213 | except Exception as e: 214 | print(f"Error decompressing gzip data: {e}") 215 | elif algo == "compress": 216 | print(f"Unsupported decompression algorithm: {algo}") 217 | elif algo == "deflate": 218 | try: 219 | final_data = zlib.decompress(raw_data) 220 | except Exception as e: 221 | print(f"Error decompressing deflate data: {e}") 222 | elif algo == "br": 223 | try: 224 | import brotli 225 | 226 | final_data = brotli.decompress(raw_data) 227 | except Exception as e: 228 | print(f"Error decompressing Brotli data: {e}") 229 | elif algo == "zstd": 230 | try: 231 | import zstandard 232 | 233 | final_data = zstandard.ZstdDecompressor().decompress(raw_data) 234 | except 
Exception as e: 235 | print(f"Error decompressing Zstandard data: {e}") 236 | else: 237 | print(f"Unsupported compression algorithm: {algo}") 238 | return final_data 239 | else: 240 | return raw_data 241 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
Olah
2 | 3 |
4 | Self-hosted Lightweight Huggingface Mirror Service 5 | 6 | Olah is a self-hosted lightweight huggingface mirror service. `Olah` means `hello` in Hilichurlian. 7 | Olah implemented the `mirroring` feature for huggingface resources, rather than just a simple `reverse proxy`. 8 | Olah does not immediately mirror the entire huggingface website but mirrors the resources at the file block level when users download them (or we can say cache them). 9 | 10 | Other languages: [中文](README_zh.md) 11 | 12 | ## Advantages of Olah 13 | Olah has the capability to cache files in chunks while users download them. Upon subsequent downloads, the files can be directly retrieved from the cache, greatly enhancing download speeds and saving bandwidth. 14 | Additionally, Olah offers a range of cache control policies. Administrators can configure which repositories are accessible and which ones can be cached through a configuration file. 15 | 16 | ## Features 17 | * Huggingface Data Cache 18 | * Models mirror 19 | * Datasets mirror 20 | * Spaces mirror 21 | 22 | ## Install 23 | 24 | ### Method 1: With pip 25 | 26 | ```bash 27 | pip install olah 28 | ``` 29 | 30 | or: 31 | 32 | ```bash 33 | pip install git+https://github.com/vtuber-plan/olah.git 34 | ``` 35 | 36 | ### Method 2: From source 37 | 38 | 1. Clone this repository 39 | ```bash 40 | git clone https://github.com/vtuber-plan/olah.git 41 | cd olah 42 | ``` 43 | 44 | 2. Install the Package 45 | ```bash 46 | pip install --upgrade pip 47 | pip install -e . 48 | ``` 49 | 50 | ## Quick Start 51 | Run the command in the console: 52 | ```bash 53 | olah-cli 54 | ``` 55 | 56 | Then set the Environment Variable `HF_ENDPOINT` to the mirror site (Here is http://localhost:8090). 57 | 58 | Linux: 59 | ```bash 60 | export HF_ENDPOINT=http://localhost:8090 61 | ``` 62 | 63 | Windows Powershell: 64 | ```bash 65 | $env:HF_ENDPOINT = "http://localhost:8090" 66 | ``` 67 | 68 | Starting from now on, all download operations in the HuggingFace library will be proxied through this mirror site. 69 | ```bash 70 | pip install -U huggingface_hub 71 | ``` 72 | 73 | ```python 74 | from huggingface_hub import snapshot_download 75 | 76 | snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model', 77 | local_dir='./model_dir', resume_download=True, 78 | max_workers=8) 79 | ``` 80 | 81 | Or you can download models and datasets by using huggingface cli. 82 | 83 | Download GPT2: 84 | ```bash 85 | huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2 86 | ``` 87 | 88 | Download WikiText: 89 | ```bash 90 | huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext 91 | ``` 92 | 93 | You can check the path `./repos`, in which olah stores all cached datasets and models. 
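If you prefer to configure the endpoint inside Python instead of exporting a shell variable, the minimal sketch below (illustrative only; it assumes an Olah instance is already listening on http://localhost:8090) sets `HF_ENDPOINT` before importing `huggingface_hub`, since the endpoint is typically read when the library is imported:

```python
import os

# Point huggingface_hub at the mirror before the library is imported,
# because the endpoint value is read when the module loads.
os.environ["HF_ENDPOINT"] = "http://localhost:8090"

from huggingface_hub import snapshot_download

# Download through the mirror; cached blocks are reused on later downloads.
snapshot_download(repo_id='openai-community/gpt2', repo_type='model',
                  local_dir='./gpt2')
```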
94 | 95 | ## Start the server 96 | Run the command in the console: 97 | ```bash 98 | olah-cli 99 | ``` 100 | 101 | Or you can specify the host address and listening port: 102 | ```bash 103 | olah-cli --host localhost --port 8090 104 | ``` 105 | **Note: Please change --mirror-netloc and --mirror-lfs-netloc to the actual URLs of the mirror sites when modifying the host and port.** 106 | ```bash 107 | olah-cli --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090 108 | ``` 109 | 110 | The default mirror cache path is `./repos`, you can change it by `--repos-path` parameter: 111 | ```bash 112 | olah-cli --host localhost --port 8090 --repos-path ./hf_mirrors 113 | ``` 114 | 115 | **Note that the cached data between different versions cannot be migrated. Please delete the cache folder before upgrading to the latest version of Olah.** 116 | 117 | In deployment scenarios, there may be high concurrent downloads, leading to Timeout errors for new connections. 118 | You can set the `WEB_CONCURRENCY` variable for uvicorn to increase the number of workers, thereby enhancing concurrency in production environments. 119 | For example, on Linux: 120 | ```bash 121 | export WEB_CONCURRENCY=4 122 | ``` 123 | 124 | ## More Configurations 125 | 126 | Additional configurations can be controlled through a configuration file by passing the `configs.toml` file as a command parameter: 127 | ```bash 128 | olah-cli -c configs.toml 129 | ``` 130 | 131 | The complete content of the configuration file can be found at [assets/full_configs.toml](https://github.com/vtuber-plan/olah/blob/main/assets/full_configs.toml). 132 | 133 | ### Configuration Details 134 | The first section, `basic`, is used to set up basic configurations for the mirror site: 135 | ```toml 136 | [basic] 137 | host = "localhost" 138 | port = 8090 139 | ssl-key = "" 140 | ssl-cert = "" 141 | repos-path = "./repos" 142 | cache-size-limit = "" 143 | cache-clean-strategy = "LRU" 144 | hf-scheme = "https" 145 | hf-netloc = "huggingface.co" 146 | hf-lfs-netloc = "cdn-lfs.huggingface.co" 147 | mirror-scheme = "http" 148 | mirror-netloc = "localhost:8090" 149 | mirror-lfs-netloc = "localhost:8090" 150 | mirrors-path = ["./mirrors_dir"] 151 | ``` 152 | - `host`: Sets the host address that Olah listens to. 153 | - `port`: Sets the port that Olah listens to. 154 | - `ssl-key` and `ssl-cert`: When enabling HTTPS, specify the file paths for the key and certificate. 155 | - `repos-path`: Specifies the directory for storing cached data. 156 | - `cache-size-limit`: Specifies cache size limit (For example, 100G, 500GB, 2TB). Olah will scan the size of the cache folder every hour. If it exceeds the limit, olah will delete some cache files. 157 | - `cache-clean-strategy`: Specifies cache cleaning strategy (Available strategies: LRU, FIFO, LARGE_FIRST). 158 | - `hf-scheme`: Network protocol for the Hugging Face official site (usually no need to modify). 159 | - `hf-netloc`: Network location of the Hugging Face official site (usually no need to modify). 160 | - `hf-lfs-netloc`: Network location for Hugging Face official site's LFS files (usually no need to modify). 161 | - `mirror-scheme`: Network protocol for the Olah mirror site (should match the above settings; change to HTTPS if providing `ssl-key` and `ssl-cert`). 162 | - `mirror-netloc`: Network location of the Olah mirror site (should match `host` and `port` settings). 163 | - `mirror-lfs-netloc`: Network location for Olah mirror site's LFS (should match `host` and `port` settings). 
164 | - `mirrors-path`: Additional mirror file directories. If you have already cloned some Git repositories, you can place them in this directory for downloading. In this example, the directory is `./mirrors_dir`. To add a dataset like `Salesforce/wikitext`, you can place the Git repository in the directory `./mirrors_dir/datasets/Salesforce/wikitext`. Similarly, models can be placed under `./mirrors_dir/models/organization/repository`. 165 | 166 | The second section allows for accessibility restrictions: 167 | ```toml 168 | [accessibility] 169 | offline = false 170 | 171 | [[accessibility.proxy]] 172 | repo = "cais/mmlu" 173 | allow = true 174 | 175 | [[accessibility.proxy]] 176 | repo = "adept/fuyu-8b" 177 | allow = false 178 | 179 | [[accessibility.proxy]] 180 | repo = "mistralai/*" 181 | allow = true 182 | 183 | [[accessibility.proxy]] 184 | repo = "mistralai/Mistral.*" 185 | allow = false 186 | use_re = true 187 | 188 | [[accessibility.cache]] 189 | repo = "cais/mmlu" 190 | allow = true 191 | 192 | [[accessibility.cache]] 193 | repo = "adept/fuyu-8b" 194 | allow = false 195 | ``` 196 | - `offline`: Sets whether the Olah mirror site enters offline mode, no longer making requests to the Hugging Face official site for data updates. However, cached repositories can still be downloaded. 197 | - `proxy`: Determines if the repository can be accessed through a proxy. By default, all repositories are allowed. The `repo` field is used to match the repository name. Regular expressions and wildcards can be used by setting `use_re` to control whether to use regular expressions (default is to use wildcards). The `allow` field controls whether the repository is allowed to be proxied. 198 | - `cache`: Determines if the repository will be cached. By default, all repositories are allowed. The `repo` field is used to match the repository name. Regular expressions and wildcards can be used by setting `use_re` to control whether to use regular expressions (default is to use wildcards). The `allow` field controls whether the repository is allowed to be cached. 199 | 200 | ## Future Work 201 | 202 | * Administrator and user system 203 | * OOS backend support 204 | * Mirror Update Schedule Task 205 | 206 | ## License 207 | 208 | olah is released under the MIT License. 209 | 210 | 211 | ## See also 212 | 213 | - [olah-docs](https://github.com/vtuber-plan/olah/tree/main/docs) 214 | - [olah-source](https://github.com/vtuber-plan/olah) 215 | 216 | 217 | ## Star History 218 | 219 | [![Star History Chart](https://api.star-history.com/svg?repos=vtuber-plan/olah&type=Date)](https://star-history.com/#vtuber-plan/olah&Date) 220 | 221 | -------------------------------------------------------------------------------- /src/olah/utils/url_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
7 | 8 | import datetime 9 | import os 10 | import glob 11 | from typing import Dict, List, Literal, Optional, Tuple, Union 12 | import json 13 | from urllib.parse import ParseResult, urlencode, urljoin, urlparse, parse_qs, urlunparse 14 | import httpx 15 | from olah.configs import OlahConfig 16 | from olah.constants import WORKER_API_TIMEOUT 17 | 18 | 19 | def get_url_tail(parsed_url: Union[str, ParseResult]) -> str: 20 | """ 21 | Extracts the tail of a URL, including path, parameters, query, and fragment. 22 | 23 | Args: 24 | parsed_url (Union[str, ParseResult]): The parsed URL or a string URL. 25 | 26 | Returns: 27 | str: The tail of the URL, including path, parameters, query, and fragment. 28 | """ 29 | if isinstance(parsed_url, str): 30 | parsed_url = urlparse(parsed_url) 31 | url_tail = parsed_url.path 32 | if len(parsed_url.params) != 0: 33 | url_tail += f";{parsed_url.params}" 34 | if len(parsed_url.query) != 0: 35 | url_tail += f"?{parsed_url.query}" 36 | if len(parsed_url.fragment) != 0: 37 | url_tail += f"#{parsed_url.fragment}" 38 | return url_tail 39 | 40 | 41 | def parse_content_range(content_range: str) -> Tuple[str, Optional[int], Optional[int], Optional[int]]: 42 | """ 43 | Parses a Content-Range header string and extracts the unit, start position, end position, and resource size. 44 | 45 | Args: 46 | content_range (str): The Content-Range header string, e.g., "bytes 0-999/1000". 47 | 48 | Returns: 49 | Tuple[str, Optional[int], Optional[int], Optional[int]]: A tuple containing: 50 | - unit (str): The unit of the range, typically "bytes". 51 | - start_pos (Optional[int]): The starting position of the range. None if the range is "*". 52 | - end_pos (Optional[int]): The ending position of the range. None if the range is "*". 53 | - resource_size (Optional[int]): The total size of the resource. None if the size is unknown. 54 | 55 | Raises: 56 | Exception: If the range unit is invalid or the range format is incorrect. 57 | """ 58 | if content_range.startswith("bytes "): 59 | unit = "bytes" 60 | content_range_part = content_range[len("bytes "):] 61 | else: 62 | raise Exception("Invalid range unit") 63 | 64 | 65 | if "/" in content_range_part: 66 | data_range, resource_size = content_range_part.split("/", maxsplit=1) 67 | resource_size = int(resource_size) 68 | else: 69 | data_range = content_range_part 70 | resource_size = None 71 | 72 | if "-" in data_range: 73 | start_pos, end_pos = data_range.split("-") 74 | start_pos, end_pos = int(start_pos), int(end_pos) 75 | elif "*" == data_range.strip(): 76 | start_pos, end_pos = None, None 77 | else: 78 | raise Exception("Invalid range") 79 | return unit, start_pos, end_pos, resource_size 80 | 81 | 82 | def parse_range_params(range_header: str) -> Tuple[str, List[Tuple[Optional[int], Optional[int]]], Optional[int]]: 83 | """ 84 | Parses the HTTP Range request header and returns the unit and a list of ranges. 85 | 86 | Args: 87 | range_header (str): The HTTP Range request header string, e.g., "bytes=0-499" or "bytes=200-999, 2000-2499, 9500-". 88 | 89 | Returns: 90 | Tuple[str, List[Tuple[int, int]], Optional[int]]: A tuple containing the unit (e.g., "bytes") and a list of ranges. 91 | Each range is represented as a tuple of start and end positions. If the end position is not specified, 92 | it is set to None. For suffix-length ranges (e.g., "-500"), the start position is negative. 93 | 94 | Raises: 95 | ValueError: If the Range header is empty or has an invalid format. 
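    Examples (doctest-style, for illustration):
        >>> parse_range_params("bytes=0-499")
        ('bytes', [(0, 499)], None)
        >>> parse_range_params("bytes=-500")
        ('bytes', [], 500)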
96 | """ 97 | if not range_header: 98 | raise ValueError("Range header cannot be empty") 99 | 100 | # Split the unit and range specifiers 101 | parts = range_header.split('=') 102 | if len(parts) != 2: 103 | raise ValueError("Invalid Range header format") 104 | 105 | unit = parts[0].strip() # Get the unit, typically "bytes" 106 | range_specifiers = parts[1].strip() # Get the range part 107 | 108 | if range_specifiers.startswith("-") and range_specifiers[1:].isdigit(): 109 | return unit, [], int(range_specifiers[1:]) 110 | 111 | # Parse multiple ranges 112 | range_list = [] 113 | for range_spec in range_specifiers.split(','): 114 | range_spec = range_spec.strip() 115 | if '-' not in range_spec: 116 | raise ValueError("Invalid range specifier") 117 | 118 | start, end = range_spec.split('-') 119 | start = start.strip() 120 | end = end.strip() 121 | 122 | # Handle suffix-length ranges (e.g., "-500") 123 | if not start and end: 124 | range_list.append((None, int(end))) # Negative start indicates suffix-length 125 | continue 126 | 127 | # Handle open-ended ranges (e.g., "500-") 128 | if not end and start: 129 | range_list.append((int(start), None)) 130 | continue 131 | 132 | # Handle full ranges (e.g., "200-999") 133 | if start and end: 134 | range_list.append((int(start), int(end))) 135 | continue 136 | 137 | # If neither start nor end is provided, it's invalid 138 | raise ValueError("Invalid range specifier") 139 | 140 | return unit, range_list, None 141 | 142 | 143 | def get_all_ranges(file_size: int, unit: str, ranges: List[Tuple[Optional[int], Optional[int]]], suffix: Optional[int]) -> List[Tuple[int, int]]: 144 | all_ranges: List[Tuple[int, int]] = [] 145 | if suffix is not None: 146 | all_ranges.append((file_size - suffix, file_size)) 147 | else: 148 | for r in ranges: 149 | r_start = r[0] if r[0] is not None else 0 150 | r_end = r[1] if r[1] is not None else file_size - 1 151 | start_pos = max(0, r_start) 152 | end_pos = min(file_size - 1, r_end) 153 | if end_pos < start_pos: 154 | continue 155 | all_ranges.append((start_pos, end_pos + 1)) 156 | return all_ranges 157 | 158 | 159 | class RemoteInfo(object): 160 | def __init__(self, method: str, url: str, headers: Dict[str, str]) -> None: 161 | """ 162 | Represents information about a remote request. 163 | 164 | Args: 165 | method (str): The HTTP method of the request. 166 | url (str): The URL of the request. 167 | headers (Dict[str, str]): The headers of the request. 168 | """ 169 | self.method = method 170 | self.url = url 171 | self.headers = headers 172 | 173 | 174 | def check_url_has_param_name(url: str, param_name: str) -> bool: 175 | """ 176 | Checks if a URL contains a specific query parameter. 177 | 178 | Args: 179 | url (str): The URL to check. 180 | param_name (str): The name of the query parameter. 181 | 182 | Returns: 183 | bool: True if the URL contains the parameter, False otherwise. 184 | """ 185 | parsed_url = urlparse(url) 186 | query_params = parse_qs(parsed_url.query) 187 | return param_name in query_params 188 | 189 | 190 | def get_url_param_name(url: str, param_name: str) -> Optional[str]: 191 | """ 192 | Retrieves the value of a specific query parameter from a URL. 193 | 194 | Args: 195 | url (str): The URL to retrieve the parameter from. 196 | param_name (str): The name of the query parameter. 197 | 198 | Returns: 199 | Optional[str]: The value of the query parameter if found, None otherwise. 
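    Example (for illustration):
        >>> get_url_param_name("http://example.com/path?foo=bar", "foo")
        'bar'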
200 | """ 201 | parsed_url = urlparse(url) 202 | query_params = parse_qs(parsed_url.query) 203 | original_location = query_params.get(param_name) 204 | if original_location: 205 | return original_location[0] 206 | else: 207 | return None 208 | 209 | 210 | def add_query_param(url: str, param_name: str, param_value: str) -> str: 211 | """ 212 | Adds a query parameter to a URL. 213 | 214 | Args: 215 | url (str): The URL to add the parameter to. 216 | param_name (str): The name of the query parameter. 217 | param_value (str): The value of the query parameter. 218 | 219 | Returns: 220 | str: The modified URL with the added query parameter. 221 | """ 222 | parsed_url = urlparse(url) 223 | query_params = parse_qs(parsed_url.query) 224 | 225 | query_params[param_name] = [param_value] 226 | 227 | new_query = urlencode(query_params, doseq=True) 228 | new_url = urlunparse(parsed_url._replace(query=new_query)) 229 | 230 | return new_url 231 | 232 | 233 | def remove_query_param(url: str, param_name: str) -> str: 234 | """ 235 | Removes a query parameter from a URL. 236 | 237 | Args: 238 | url (str): The URL to remove the parameter from. 239 | param_name (str): The name of the query parameter. 240 | 241 | Returns: 242 | str: The modified URL with the parameter removed. 243 | """ 244 | parsed_url = urlparse(url) 245 | query_params = parse_qs(parsed_url.query) 246 | 247 | if param_name in query_params: 248 | del query_params[param_name] 249 | 250 | new_query = urlencode(query_params, doseq=True) 251 | new_url = urlunparse(parsed_url._replace(query=new_query)) 252 | 253 | return new_url 254 | 255 | 256 | def clean_path(path: str) -> str: 257 | while ".." in path: 258 | path = path.replace("..", "") 259 | path = path.replace("\\", "/") 260 | while "//" in path: 261 | path = path.replace("//", "/") 262 | return path -------------------------------------------------------------------------------- /src/olah/utils/repo_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import datetime 9 | import os 10 | import glob 11 | import tenacity 12 | from typing import Dict, Literal, Optional, Tuple, Union 13 | import json 14 | from urllib.parse import urljoin 15 | import httpx 16 | from olah.constants import WORKER_API_TIMEOUT 17 | from olah.utils.cache_utils import read_cache_request 18 | 19 | 20 | def get_org_repo(org: Optional[str], repo: str) -> str: 21 | """ 22 | Constructs the organization/repository name. 23 | 24 | Args: 25 | org: The organization name (optional). 26 | repo: The repository name. 27 | 28 | Returns: 29 | The organization/repository name as a string. 30 | 31 | """ 32 | if org is None: 33 | org_repo = repo 34 | else: 35 | org_repo = f"{org}/{repo}" 36 | return org_repo 37 | 38 | 39 | def parse_org_repo(org_repo: str) -> Tuple[Optional[str], Optional[str]]: 40 | """ 41 | Parses the organization/repository name. 42 | 43 | Args: 44 | org_repo: The organization/repository name. 45 | 46 | Returns: 47 | A tuple containing the organization name and repository name. 
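    Examples (for illustration):
        >>> parse_org_repo("Salesforce/wikitext")
        ('Salesforce', 'wikitext')
        >>> parse_org_repo("gpt2")
        (None, 'gpt2')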
48 | 49 | """ 50 | if "/" in org_repo and org_repo.count("/") != 1: 51 | return None, None 52 | if "/" in org_repo: 53 | org, repo = org_repo.split("/") 54 | else: 55 | org = None 56 | repo = org_repo 57 | return org, repo 58 | 59 | 60 | def get_meta_save_path( 61 | repos_path: str, repo_type: str, org: Optional[str], repo: str, commit: str 62 | ) -> str: 63 | """ 64 | Constructs the path to save the meta.json file. 65 | 66 | Args: 67 | repos_path: The base path where repositories are stored. 68 | repo_type: The type of repository. 69 | org: The organization name (optional). 70 | repo: The repository name. 71 | commit: The commit hash. 72 | 73 | Returns: 74 | The path to save the meta.json file as a string. 75 | 76 | """ 77 | org_repo = get_org_repo(org, repo) 78 | return os.path.join( 79 | repos_path, f"api/{repo_type}/{org_repo}/revision/{commit}/meta_get.json" 80 | ) 81 | 82 | 83 | def get_meta_save_dir( 84 | repos_path: str, repo_type: str, org: Optional[str], repo: str 85 | ) -> str: 86 | """ 87 | Constructs the directory path to save the meta.json file. 88 | 89 | Args: 90 | repos_path: The base path where repositories are stored. 91 | repo_type: The type of repository. 92 | org: The organization name (optional). 93 | repo: The repository name. 94 | 95 | Returns: 96 | The directory path to save the meta.json file as a string. 97 | 98 | """ 99 | org_repo = get_org_repo(org, repo) 100 | return os.path.join(repos_path, f"api/{repo_type}/{org_repo}/revision") 101 | 102 | 103 | def get_file_save_path( 104 | repos_path: str, 105 | repo_type: str, 106 | org: Optional[str], 107 | repo: str, 108 | commit: str, 109 | file_path: str, 110 | ) -> str: 111 | """ 112 | Constructs the path to save a file in the repository. 113 | 114 | Args: 115 | repos_path: The base path where repositories are stored. 116 | repo_type: The type of repository. 117 | org: The organization name (optional). 118 | repo: The repository name. 119 | commit: The commit hash. 120 | file_path: The path of the file within the repository. 121 | 122 | Returns: 123 | The path to save the file as a string. 124 | 125 | """ 126 | org_repo = get_org_repo(org, repo) 127 | return os.path.join( 128 | repos_path, f"heads/{repo_type}/{org_repo}/resolve_head/{commit}/{file_path}" 129 | ) 130 | 131 | 132 | async def get_newest_commit_hf_offline( 133 | app, 134 | repo_type: Optional[Literal["models", "datasets", "spaces"]], 135 | org: str, 136 | repo: str, 137 | ) -> Optional[str]: 138 | """ 139 | Retrieves the newest commit hash for a repository in offline mode. 140 | 141 | Args: 142 | app: The application object. 143 | repo_type: The type of repository. 144 | org: The organization name. 145 | repo: The repository name. 146 | 147 | Returns: 148 | The newest commit hash as a string. 
149 | 150 | """ 151 | repos_path = app.state.app_settings.config.repos_path 152 | save_dir = get_meta_save_dir(repos_path, repo_type, org, repo) 153 | files = glob.glob(os.path.join(save_dir, "*", "meta_head.json")) 154 | 155 | time_revisions = [] 156 | for file in files: 157 | with open(file, "r", encoding="utf-8") as f: 158 | obj = json.loads(f.read()) 159 | datetime_object = datetime.datetime.fromisoformat(obj["lastModified"]) 160 | time_revisions.append((datetime_object, obj["sha"])) 161 | 162 | time_revisions = sorted(time_revisions) 163 | if len(time_revisions) == 0: 164 | return None 165 | else: 166 | return time_revisions[-1][1] 167 | 168 | 169 | async def get_newest_commit_hf( 170 | app, 171 | repo_type: Optional[Literal["models", "datasets", "spaces"]], 172 | org: Optional[str], 173 | repo: str, 174 | authorization: Optional[str] = None, 175 | ) -> Optional[str]: 176 | """ 177 | Retrieves the newest commit hash for a repository. 178 | 179 | Args: 180 | app: The application object. 181 | repo_type: The type of repository. 182 | org: The organization name (optional). 183 | repo: The repository name. 184 | 185 | Returns: 186 | The newest commit hash as a string, or None if it cannot be obtained. 187 | 188 | """ 189 | org_repo = get_org_repo(org, repo) 190 | url = urljoin( 191 | app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}" 192 | ) 193 | if app.state.app_settings.config.offline: 194 | return await get_newest_commit_hf_offline(app, repo_type, org, repo) 195 | try: 196 | async with httpx.AsyncClient() as client: 197 | headers = {} 198 | if authorization is not None: 199 | headers["authorization"] = authorization 200 | response = await client.get(url, headers=headers, timeout=WORKER_API_TIMEOUT) 201 | if response.status_code != 200: 202 | return await get_newest_commit_hf_offline(app, repo_type, org, repo) 203 | obj = json.loads(response.text) 204 | return obj.get("sha", None) 205 | except httpx.TimeoutException as e: 206 | return await get_newest_commit_hf_offline(app, repo_type, org, repo) 207 | 208 | 209 | async def get_commit_hf_offline( 210 | app, 211 | repo_type: Optional[Literal["models", "datasets", "spaces"]], 212 | org: Optional[str], 213 | repo: str, 214 | commit: str, 215 | ) -> Optional[str]: 216 | """ 217 | Retrieves the commit SHA for a given repository and commit from the offline cache. 218 | 219 | This function is used when the application is in offline mode and the commit information is not available from the API. 220 | 221 | Args: 222 | app: The application instance. 223 | repo_type: Optional. The type of repository ("models", "datasets", or "spaces"). 224 | org: Optional. The organization name for the repository. 225 | repo: The name of the repository. 226 | commit: The commit identifier. 227 | 228 | Returns: 229 | The commit SHA as a string if available in the offline cache, or None if the information is not cached. 
230 | """ 231 | repos_path = app.state.app_settings.config.repos_path 232 | save_path = get_meta_save_path(repos_path, repo_type, org, repo, commit) 233 | if os.path.exists(save_path): 234 | request_cache = await read_cache_request(save_path) 235 | request_cache_json = json.loads(request_cache["content"]) 236 | return request_cache_json["sha"] 237 | else: 238 | return None 239 | 240 | 241 | async def get_commit_hf( 242 | app, 243 | repo_type: Optional[Literal["models", "datasets", "spaces"]], 244 | org: Optional[str], 245 | repo: str, 246 | commit: str, 247 | authorization: Optional[str] = None, 248 | ) -> Optional[str]: 249 | """ 250 | Retrieves the commit SHA for a given repository and commit from the Hugging Face API. 251 | 252 | Args: 253 | app: The application instance. 254 | repo_type: Optional. The type of repository ("models", "datasets", or "spaces"). 255 | org: Optional. The organization name for the repository. 256 | repo: The name of the repository. 257 | commit: The commit identifier. 258 | authorization: Optional. The authorization token for accessing the API. 259 | 260 | Returns: 261 | The commit SHA as a string, or None if the commit cannot be retrieved. 262 | 263 | Raises: 264 | This function does not raise any explicit exceptions but may propagate exceptions from underlying functions. 265 | """ 266 | org_repo = get_org_repo(org, repo) 267 | url = urljoin( 268 | app.state.app_settings.config.hf_url_base(), 269 | f"/api/{repo_type}/{org_repo}/revision/{commit}", 270 | ) 271 | if app.state.app_settings.config.offline: 272 | return await get_commit_hf_offline(app, repo_type, org, repo, commit) 273 | try: 274 | headers = {} 275 | if authorization is not None: 276 | headers["authorization"] = authorization 277 | async with httpx.AsyncClient() as client: 278 | response = await client.get( 279 | url, headers=headers, timeout=WORKER_API_TIMEOUT, follow_redirects=True 280 | ) 281 | if response.status_code not in [200, 307]: 282 | return await get_commit_hf_offline(app, repo_type, org, repo, commit) 283 | obj = json.loads(response.text) 284 | return obj.get("sha", None) 285 | except: 286 | return await get_commit_hf_offline(app, repo_type, org, repo, commit) 287 | 288 | 289 | @tenacity.retry(stop=tenacity.stop_after_attempt(3)) 290 | async def check_commit_hf( 291 | app, 292 | repo_type: Optional[Literal["models", "datasets", "spaces"]], 293 | org: Optional[str], 294 | repo: str, 295 | commit: Optional[str] = None, 296 | authorization: Optional[str] = None, 297 | ) -> bool: 298 | """ 299 | Checks the commit status of a repository in the Hugging Face ecosystem. 300 | 301 | Args: 302 | app: The application object. 303 | repo_type: The type of repository (models, datasets, or spaces). 304 | org: The organization name (optional). 305 | repo: The repository name. 306 | commit: The commit hash (optional). 307 | authorization: The authorization token (optional). 308 | 309 | Returns: 310 | A boolean indicating if the commit is valid (status code 200 or 307) or not. 
311 | 312 | """ 313 | org_repo = get_org_repo(org, repo) 314 | if commit is None: 315 | url = urljoin( 316 | app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}" 317 | ) 318 | else: 319 | url = urljoin( 320 | app.state.app_settings.config.hf_url_base(), 321 | f"/api/{repo_type}/{org_repo}/revision/{commit}", 322 | ) 323 | 324 | headers = {} 325 | if authorization is not None: 326 | headers["authorization"] = authorization 327 | async with httpx.AsyncClient() as client: 328 | response = await client.request(method="HEAD", url=url, headers=headers, timeout=WORKER_API_TIMEOUT) 329 | status_code = response.status_code 330 | return status_code in [200, 307] 331 | -------------------------------------------------------------------------------- /src/olah/mirror/repos.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | import hashlib 8 | import io 9 | import os 10 | import re 11 | from typing import Any, Dict, List, Union 12 | import gitdb 13 | from git import Commit, Optional, Repo, Tree 14 | from git.objects.base import IndexObjUnion 15 | from gitdb.base import OStream 16 | import yaml 17 | 18 | from olah.mirror.meta import RepoMeta 19 | 20 | 21 | class LocalMirrorRepo(object): 22 | def __init__(self, path: str, repo_type: str, org: str, repo: str) -> None: 23 | self._path = path 24 | self._repo_type = repo_type 25 | self._org = org 26 | self._repo = repo 27 | 28 | self._git_repo = Repo(self._path) 29 | 30 | def _sha256(self, text: Union[str, bytes]) -> str: 31 | if isinstance(text, bytes) or isinstance(text, bytearray): 32 | bin = text 33 | elif isinstance(text, str): 34 | bin = text.encode("utf-8") 35 | else: 36 | raise Exception("Invalid sha256 param type.") 37 | sha256_hash = hashlib.sha256() 38 | sha256_hash.update(bin) 39 | hashed_string = sha256_hash.hexdigest() 40 | return hashed_string 41 | 42 | def _match_card(self, readme: str) -> str: 43 | pattern = r"\s*---(.*?)---" 44 | 45 | match = re.match(pattern, readme, flags=re.S) 46 | 47 | if match: 48 | card_string = match.group(1) 49 | return card_string 50 | else: 51 | return "" 52 | 53 | def _remove_card(self, readme: str) -> str: 54 | pattern = r"\s*---(.*?)---" 55 | out = re.sub(pattern, "", readme, flags=re.S) 56 | return out 57 | 58 | def _get_readme(self, commit: Commit) -> str: 59 | if "README.md" not in commit.tree: 60 | return "" 61 | else: 62 | out: bytes = commit.tree["README.md"].data_stream.read() 63 | return out.decode() 64 | 65 | def _get_description(self, commit: Commit) -> str: 66 | readme = self._get_readme(commit) 67 | return self._remove_card(readme) 68 | 69 | def _get_tree_filepaths_recursive(self, tree: Tree, include_dir: bool = False) -> List[str]: 70 | out_paths = [] 71 | for entry in tree: 72 | if entry.type == "tree": 73 | out_paths.extend(self._get_tree_filepaths_recursive(entry)) 74 | if include_dir: 75 | out_paths.append(entry.path) 76 | else: 77 | out_paths.append(entry.path) 78 | return out_paths 79 | 80 | def _get_commit_filepaths_recursive(self, commit: Commit) -> List[str]: 81 | return self._get_tree_filepaths_recursive(commit.tree) 82 | 83 | def _get_path_info(self, entry: IndexObjUnion, expand: bool = False) -> Dict[str, Union[int, str]]: 84 | lfs = False 85 | if entry.type != "tree": 86 | t = "file" 87 | repr_size = entry.size 88 | if 
repr_size > 120 and repr_size < 150: 89 | # check lfs 90 | lfs_data = entry.data_stream.read().decode("utf-8") 91 | match_groups = re.match( 92 | r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n", 93 | lfs_data, 94 | ) 95 | if match_groups is not None: 96 | lfs = True 97 | sha256 = match_groups.group(1) 98 | repr_size = int(match_groups.group(2)) 99 | lfs_data = { 100 | "oid": sha256, 101 | "size": repr_size, 102 | "pointerSize": entry.size, 103 | } 104 | else: 105 | t = "directory" 106 | repr_size = entry.size 107 | 108 | if not lfs: 109 | item = { 110 | "type": t, 111 | "oid": entry.hexsha, 112 | "size": repr_size, 113 | "path": entry.path, 114 | "name": entry.name, 115 | } 116 | else: 117 | item = { 118 | "type": t, 119 | "oid": entry.hexsha, 120 | "size": repr_size, 121 | "path": entry.path, 122 | "name": entry.name, 123 | "lfs": lfs_data, 124 | } 125 | if expand: 126 | last_commit = next(self._git_repo.iter_commits(paths=entry.path, max_count=1)) 127 | item["lastCommit"] = { 128 | "id": last_commit.hexsha, 129 | "title": last_commit.message, 130 | "date": last_commit.committed_datetime.strftime( 131 | "%Y-%m-%dT%H:%M:%S.%fZ" 132 | ) 133 | } 134 | item["security"] = { 135 | "blobId": entry.hexsha, 136 | "name": entry.name, 137 | "safe": True, 138 | "indexed": False, 139 | "avScan": { 140 | "virusFound": False, 141 | "virusNames": None 142 | }, 143 | "pickleImportScan": None 144 | } 145 | return item 146 | 147 | def _get_tree_files( 148 | self, tree: Tree, recursive: bool = False, expand: bool = False 149 | ) -> List[Dict[str, Union[int, str]]]: 150 | entries = [] 151 | for entry in tree: 152 | entries.append(self._get_path_info(entry=entry, expand=expand)) 153 | 154 | if recursive: 155 | for entry in tree: 156 | if entry.type == "tree": 157 | entries.extend(self._get_tree_files(entry, recursive=recursive, expand=expand)) 158 | return entries 159 | 160 | def _get_commit_files(self, commit: Commit) -> List[Dict[str, Union[int, str]]]: 161 | return self._get_tree_files(commit.tree) 162 | 163 | def _get_earliest_commit(self) -> Commit: 164 | earliest_commit = None 165 | earliest_commit_date = None 166 | 167 | for commit in self._git_repo.iter_commits(): 168 | commit_date = commit.committed_datetime 169 | 170 | if earliest_commit_date is None or commit_date < earliest_commit_date: 171 | earliest_commit = commit 172 | earliest_commit_date = commit_date 173 | 174 | return earliest_commit 175 | 176 | def get_index_object_by_path( 177 | self, commit_hash: str, path: str 178 | ) -> Optional[IndexObjUnion]: 179 | try: 180 | commit = self._git_repo.commit(commit_hash) 181 | except gitdb.exc.BadName: 182 | return None 183 | path_part = path.split("/") 184 | path_part = [part for part in path_part if len(part.strip()) != 0] 185 | tree = commit.tree 186 | items = self._get_tree_files(tree=tree) 187 | if len(path_part) == 0: 188 | return None 189 | for i, part in enumerate(path_part): 190 | if i != len(path_part) - 1: 191 | if part not in [ 192 | item["name"] for item in items if item["type"] == "directory" 193 | ]: 194 | return None 195 | else: 196 | if part not in [ 197 | item["name"] for item in items 198 | ]: 199 | return None 200 | tree = tree[part] 201 | if tree.type == "tree": 202 | items = self._get_tree_files(tree=tree, recursive=False) 203 | return tree 204 | 205 | def get_pathinfos( 206 | self, commit_hash: str, paths: List[str] 207 | ) -> Optional[List[Dict[str, Any]]]: 208 | try: 209 | commit = self._git_repo.commit(commit_hash) 210 | except 
gitdb.exc.BadName: 211 | return None 212 | 213 | results = [] 214 | for path in paths: 215 | index_obj = self.get_index_object_by_path( 216 | commit_hash=commit_hash, path=path 217 | ) 218 | if index_obj is not None: 219 | results.append(self._get_path_info(index_obj)) 220 | 221 | for r in results: 222 | if "name" in r: 223 | r.pop("name") 224 | return results 225 | 226 | def get_tree( 227 | self, commit_hash: str, path: str, recursive: bool = False, expand: bool = False 228 | ) -> Optional[Dict[str, Any]]: 229 | try: 230 | commit = self._git_repo.commit(commit_hash) 231 | except gitdb.exc.BadName: 232 | return None 233 | 234 | index_obj = self.get_index_object_by_path(commit_hash=commit_hash, path=path) 235 | items = self._get_tree_files(tree=index_obj, recursive=recursive, expand=expand) 236 | for r in items: 237 | r.pop("name") 238 | return items 239 | 240 | def get_commits(self, commit_hash: str) -> Optional[Dict[str, Any]]: 241 | try: 242 | commit = self._git_repo.commit(commit_hash) 243 | except gitdb.exc.BadName: 244 | return None 245 | 246 | parent_commits = [commit] + [each_commit for each_commit in commit.iter_parents()] 247 | items = [] 248 | for each_commit in parent_commits: 249 | item = { 250 | "id": each_commit.hexsha, 251 | "title": each_commit.message, 252 | "message": "", 253 | "authors": [], 254 | "date": each_commit.committed_datetime.strftime("%Y-%m-%dT%H:%M:%S.%fZ") 255 | } 256 | item["authors"].append({ 257 | "name": each_commit.author.name, 258 | "avatar": None 259 | }) 260 | items.append(item) 261 | return items 262 | 263 | def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]: 264 | try: 265 | commit = self._git_repo.commit(commit_hash) 266 | except gitdb.exc.BadName: 267 | return None 268 | meta = RepoMeta() 269 | 270 | meta._id = self._sha256(f"{self._org}/{self._repo}/{commit.hexsha}") 271 | meta.id = f"{self._org}/{self._repo}" 272 | meta.author = self._org 273 | meta.sha = commit.hexsha 274 | meta.lastModified = self._git_repo.head.commit.committed_datetime.strftime( 275 | "%Y-%m-%dT%H:%M:%S.%fZ" 276 | ) 277 | meta.private = False 278 | meta.gated = False 279 | meta.disabled = False 280 | meta.tags = [] 281 | meta.description = self._get_description(commit) 282 | meta.paperswithcode_id = None 283 | meta.downloads = 0 284 | meta.likes = 0 285 | meta.cardData = yaml.load( 286 | self._match_card(self._get_readme(commit)), Loader=yaml.CLoader 287 | ) 288 | meta.siblings = [ 289 | {"rfilename": p} for p in self._get_commit_filepaths_recursive(commit) 290 | ] 291 | meta.createdAt = self._get_earliest_commit().committed_datetime.strftime( 292 | "%Y-%m-%dT%H:%M:%S.%fZ" 293 | ) 294 | return meta.to_dict() 295 | 296 | def _contain_path(self, path: str, tree: Tree) -> bool: 297 | norm_p = os.path.normpath(path).replace("\\", "/") 298 | parts = norm_p.split("/") 299 | for part in parts: 300 | if all([t.name != part for t in tree]): 301 | return False 302 | else: 303 | entry = tree[part] 304 | if entry.type == "tree": 305 | tree = entry 306 | else: 307 | tree = {} 308 | return True 309 | 310 | def get_file_head(self, commit_hash: str, path: str) -> Optional[Dict[str, Any]]: 311 | try: 312 | commit = self._git_repo.commit(commit_hash) 313 | except gitdb.exc.BadName: 314 | return None 315 | 316 | if not self._contain_path(path, commit.tree): 317 | return None 318 | else: 319 | header = {} 320 | header["content-length"] = str(commit.tree[path].data_stream.size) 321 | header["x-repo-commit"] = commit.hexsha 322 | header["etag"] = commit.tree[path].binsha.hex() 323 
| if (commit.tree[path].data_stream.size > 120) and (commit.tree[path].data_stream.size < 150): 324 | lfs_data = commit.tree[path].data_stream.read().decode("utf-8") 325 | match_groups = re.match( 326 | r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n", 327 | lfs_data, 328 | ) 329 | if match_groups is not None: 330 | oid_sha256 = match_groups.group(1) 331 | objects_dir = os.path.join(self._git_repo.working_dir, '.git', 'lfs', 'objects') 332 | oid_dir = os.path.join(objects_dir, oid_sha256[:2], oid_sha256[2:4], oid_sha256) 333 | header["content-length"] = str(os.path.getsize(oid_dir)) 334 | with open(oid_dir, mode='rb') as lfs_file: 335 | header["etag"] = self._sha256(lfs_file.read()) 336 | 337 | return header 338 | 339 | def get_file(self, commit_hash: str, path: str) -> Optional[OStream]: 340 | try: 341 | commit = self._git_repo.commit(commit_hash) 342 | except gitdb.exc.BadName: 343 | return None 344 | 345 | lfs = False 346 | oid_dir = "" 347 | if (commit.tree[path].size > 120) and (commit.tree[path].size < 150): 348 | lfs_data = commit.tree[path].data_stream.read().decode("utf-8") 349 | match_groups = re.match( 350 | r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n", 351 | lfs_data, 352 | ) 353 | if match_groups is not None: 354 | lfs = True 355 | oid_sha256 = match_groups.group(1) 356 | objects_dir = os.path.join(self._git_repo.working_dir, '.git', 'lfs', 'objects') 357 | oid_dir = os.path.join(objects_dir, oid_sha256[:2], oid_sha256[2:4], oid_sha256) 358 | 359 | def stream_wrapper(file_bytes: bytes): 360 | file_stream = io.BytesIO(file_bytes) 361 | while True: 362 | chunk = file_stream.read(4096) 363 | if len(chunk) == 0: 364 | break 365 | else: 366 | yield chunk 367 | 368 | if not self._contain_path(path, commit.tree): 369 | return None 370 | else: 371 | if lfs: 372 | with open(oid_dir, mode='rb') as lfs_file: 373 | return stream_wrapper(lfs_file.read()) 374 | else: 375 | return stream_wrapper(commit.tree[path].data_stream.read()) 376 | -------------------------------------------------------------------------------- /src/olah/cache/olah_cache.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import asyncio 9 | import lzma 10 | import mmap 11 | import os 12 | import string 13 | import struct 14 | import threading 15 | import gzip 16 | from typing import BinaryIO, Dict, List, Optional 17 | 18 | import aiofiles 19 | import fastapi 20 | import fastapi.concurrency 21 | import portalocker 22 | from .bitset import Bitset 23 | 24 | CURRENT_OLAH_CACHE_VERSION = 9 25 | # Due to the download chunk settings: https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/constants.py#L37 26 | DEFAULT_BLOCK_SIZE = 50 * 1024 * 1024 27 | MAX_BLOCK_NUM = 8192 28 | DEFAULT_COMPRESSION_ALGO = 1 29 | """ 30 | 0: no compression 31 | 1: gzip 32 | 2: lzma 33 | 3: blosc 34 | 4: zlib 35 | 5: zstd 36 | 6: ... 
37 | """ 38 | 39 | class OlahCacheHeader(object): 40 | MAGIC_NUMBER = "OLAH".encode("ascii") 41 | HEADER_FIX_SIZE = 36 42 | 43 | def __init__( 44 | self, 45 | version: int = CURRENT_OLAH_CACHE_VERSION, 46 | block_size: int = DEFAULT_BLOCK_SIZE, 47 | file_size: int = 0, 48 | compression_algo: int = DEFAULT_COMPRESSION_ALGO, 49 | ) -> None: 50 | self._version = version 51 | self._block_size = block_size 52 | self._file_size = file_size 53 | self._compression_algo = compression_algo 54 | 55 | self._block_number = (file_size + block_size - 1) // block_size 56 | 57 | @property 58 | def version(self) -> int: 59 | return self._version 60 | 61 | @property 62 | def block_size(self) -> int: 63 | return self._block_size 64 | 65 | @property 66 | def file_size(self) -> int: 67 | return self._file_size 68 | 69 | @property 70 | def block_number(self) -> int: 71 | return self._block_number 72 | 73 | @property 74 | def compression_algo(self) -> int: 75 | return self._compression_algo 76 | 77 | def get_header_size(self) -> int: 78 | return self.HEADER_FIX_SIZE 79 | 80 | def _valid_header(self) -> None: 81 | if self._file_size > MAX_BLOCK_NUM * self._block_size: 82 | raise Exception( 83 | f"The size of file {self._file_size} is out of the max capability of container ({MAX_BLOCK_NUM} * {self._block_size})." 84 | ) 85 | if self._version < CURRENT_OLAH_CACHE_VERSION: 86 | raise Exception( 87 | f"This Olah Cache file is created by older version Olah. Please remove cache files and retry." 88 | ) 89 | 90 | if self._version > CURRENT_OLAH_CACHE_VERSION: 91 | raise Exception( 92 | f"This Olah Cache file is created by newer version Olah. Please remove cache files and retry." 93 | ) 94 | 95 | @staticmethod 96 | def read(stream) -> "OlahCacheHeader": 97 | obj = OlahCacheHeader() 98 | try: 99 | magic = struct.unpack( 100 | "<4s", stream.read(4) 101 | ) 102 | except struct.error: 103 | raise Exception("File is not a Olah cache file.") 104 | if magic[0] != OlahCacheHeader.MAGIC_NUMBER: 105 | raise Exception("File is not a Olah cache file.") 106 | 107 | version, block_size, file_size, compression_algo = struct.unpack( 108 | " None: 134 | self.path: Optional[str] = path 135 | self.header: Optional[OlahCacheHeader] = None 136 | self.is_open: bool = False 137 | 138 | # Lock 139 | self._header_lock = threading.Lock() 140 | 141 | # Path 142 | self._meta_path = os.path.join(path, "meta.bin") 143 | self._data_path = os.path.join(path, "blocks/block_${block_index}.bin") 144 | 145 | self.open(path, block_size=block_size) 146 | 147 | @staticmethod 148 | def create(path: str, block_size: int = DEFAULT_BLOCK_SIZE): 149 | return OlahCache(path, block_size=block_size) 150 | 151 | def open(self, path: str, block_size: int = DEFAULT_BLOCK_SIZE): 152 | if self.is_open: 153 | raise Exception("This file has been open.") 154 | if self.path is None: 155 | raise Exception("The file path is None.") 156 | 157 | if os.path.exists(path): 158 | if not os.path.isdir(path): 159 | raise Exception("The cache path shall be a folder instead of a file.") 160 | with self._header_lock: 161 | with portalocker.Lock(self._meta_path, "rb", timeout=60, flags=portalocker.LOCK_SH) as f: 162 | f.seek(0) 163 | self.header = OlahCacheHeader.read(f) 164 | else: 165 | os.makedirs(self.path, exist_ok=True) 166 | os.makedirs(os.path.join(self.path, "blocks"), exist_ok=True) 167 | with self._header_lock: 168 | # Create new file 169 | with portalocker.Lock(self._meta_path, "wb", timeout=60, flags=portalocker.LOCK_EX) as f: 170 | f.seek(0) 171 | self.header = 
OlahCacheHeader( 172 | version=CURRENT_OLAH_CACHE_VERSION, 173 | block_size=block_size, 174 | file_size=0, 175 | ) 176 | self.header.write(f) 177 | 178 | self.is_open = True 179 | 180 | def close(self): 181 | if not self.is_open: 182 | raise Exception("This file has been closed.") 183 | 184 | self._flush_header() 185 | self.path = None 186 | self.header = None 187 | 188 | self.is_open = False 189 | 190 | def _flush_header(self): 191 | if self.header is None: 192 | raise Exception("The header of cache file is None") 193 | if self.path is None: 194 | raise Exception("The path of cache file is None") 195 | with self._header_lock: 196 | with portalocker.Lock(self._meta_path, "rb+", flags=portalocker.LOCK_EX) as f: 197 | f.seek(0) 198 | self.header.write(f) 199 | 200 | def _get_file_size(self) -> int: 201 | if self.header is None: 202 | raise Exception("The header of cache file is None") 203 | with self._header_lock: 204 | file_size = self.header.file_size 205 | return file_size 206 | 207 | def _get_block_number(self) -> int: 208 | if self.header is None: 209 | raise Exception("The header of cache file is None") 210 | with self._header_lock: 211 | block_number = self.header.block_number 212 | return block_number 213 | 214 | def _get_block_size(self) -> int: 215 | if self.header is None: 216 | raise Exception("The header of cache file is None") 217 | with self._header_lock: 218 | block_size = self.header.block_size 219 | return block_size 220 | 221 | def _get_header_size(self) -> int: 222 | if self.header is None: 223 | raise Exception("The header of cache file is None") 224 | with self._header_lock: 225 | header_size = self.header.get_header_size() 226 | return header_size 227 | 228 | def _resize_header(self, block_num: int, file_size: int): 229 | if self.header is None: 230 | raise Exception("The header of cache file is None") 231 | with self._header_lock: 232 | self.header._block_number = block_num 233 | self.header._file_size = file_size 234 | self.header._valid_header() 235 | 236 | def _pad_block(self, raw_block: bytes) -> bytes: 237 | if len(raw_block) < self._get_block_size(): 238 | block = raw_block + b"\x00" * (self._get_block_size() - len(raw_block)) 239 | else: 240 | block = raw_block 241 | return block 242 | 243 | def flush(self): 244 | if not self.is_open: 245 | raise Exception("This file has been closed.") 246 | self._flush_header() 247 | 248 | def has_block(self, block_index: int) -> bool: 249 | block_path = string.Template(self._data_path).substitute(block_index=f"{block_index:0>8}") 250 | return os.path.exists(block_path) 251 | 252 | async def read_block(self, block_index: int) -> Optional[bytes]: 253 | if not self.is_open: 254 | raise Exception("This file has been closed.") 255 | 256 | if self.path is None: 257 | raise Exception("The path of the cache file is None.") 258 | 259 | if block_index >= self._get_block_number(): 260 | raise Exception("Invalid block index.") 261 | 262 | if self.header is None: 263 | raise Exception("The header of cache file is None") 264 | 265 | if not self.has_block(block_index=block_index): 266 | return None 267 | 268 | block_path = string.Template(self._data_path).substitute(block_index=f"{block_index:0>8}") 269 | 270 | with portalocker.Lock(block_path, "rb", timeout=60, flags=portalocker.LOCK_SH) as fh: 271 | async with aiofiles.open(block_path, mode='rb') as f: 272 | raw_block = await f.read(self._get_block_size()) 273 | 274 | def decompression(block_data: bytes, compression_algo: int): 275 | # decompression 276 | if compression_algo == 0: 277 | 
return block_data 278 | elif compression_algo == 1: 279 | block_data = gzip.decompress(block_data) 280 | elif compression_algo == 2: 281 | lzma_dec = lzma.LZMADecompressor() 282 | block_data = lzma_dec.decompress(block_data) 283 | else: 284 | raise Exception("Unsupported compression algorithm.") 285 | return block_data 286 | 287 | raw_block = await fastapi.concurrency.run_in_threadpool( 288 | decompression, 289 | raw_block, 290 | self.header.compression_algo 291 | ) 292 | 293 | block = self._pad_block(raw_block) 294 | return block 295 | 296 | async def write_block(self, block_index: int, block_bytes: bytes) -> None: 297 | if not self.is_open: 298 | raise Exception("This file has been closed.") 299 | 300 | if self.path is None: 301 | raise Exception("The path of the cache file is None. ") 302 | 303 | if block_index >= self._get_block_number(): 304 | raise Exception("Invalid block index.") 305 | 306 | if self.header is None: 307 | raise Exception("The header of cache file is None") 308 | 309 | if len(block_bytes) != self._get_block_size(): 310 | raise Exception("Block size does not match the cache's block size.") 311 | 312 | # Truncation 313 | if (block_index + 1) * self._get_block_size() > self._get_file_size(): 314 | real_block_bytes = block_bytes[ 315 | : self._get_file_size() - block_index * self._get_block_size() 316 | ] 317 | else: 318 | real_block_bytes = block_bytes 319 | 320 | def compression(block_data: bytes, compression_algo: int): 321 | if compression_algo == 0: 322 | return block_data 323 | elif compression_algo == 1: 324 | block_data = gzip.compress(block_data, compresslevel=4) 325 | elif compression_algo == 2: 326 | lzma_enc = lzma.LZMACompressor() 327 | block_data = lzma_enc.compress(block_data) 328 | else: 329 | raise Exception("Unsupported compression algorithm.") 330 | return block_data 331 | 332 | # Run in the default thread pool executor 333 | real_block_bytes = await fastapi.concurrency.run_in_threadpool( 334 | compression, 335 | real_block_bytes, 336 | self.header.compression_algo 337 | ) 338 | 339 | block_path = string.Template(self._data_path).substitute(block_index=f"{block_index:0>8}") 340 | 341 | with portalocker.Lock(block_path, 'wb+', timeout=60, flags=portalocker.LOCK_EX) as fh: 342 | async with aiofiles.open(block_path, mode='wb+') as f: 343 | await f.write(real_block_bytes) 344 | 345 | self._flush_header() 346 | 347 | def _resize_file_size(self, file_size: int): 348 | """ 349 | Deprecation 350 | """ 351 | if not self.is_open: 352 | raise Exception("This file has been closed.") 353 | 354 | if self.path is None: 355 | raise Exception("The path of the cache file is None. ") 356 | 357 | if file_size == self._get_file_size(): 358 | return 359 | if file_size < self._get_file_size(): 360 | raise Exception( 361 | "Invalid resize file size. New file size must be greater than the current file size." 
362 | ) 363 | 364 | with open(self.path, "rb") as f: 365 | with mmap.mmap(f.fileno(), 0, mmap.MAP_SHARED, mmap.PROT_READ) as mm: 366 | mm.seek(0, os.SEEK_END) 367 | bin_size = mm.tell() 368 | 369 | # FIXME: limit the resize method, because it may influence the _block_mask 370 | new_bin_size = self._get_header_size() + file_size 371 | with open(self.path, "rb+") as f: 372 | with mmap.mmap(f.fileno(), 0, mmap.MAP_SHARED, mmap.PROT_WRITE) as mm: 373 | mm.seek(new_bin_size - 1) 374 | mm.write(b'\0') 375 | mm.truncate() 376 | 377 | # Extend file size (slow) 378 | # mm.seek(0, os.SEEK_END) 379 | # mm.write(b"\x00" * (new_bin_size - bin_size)) 380 | 381 | def resize(self, file_size: int): 382 | """ 383 | Deprecation 384 | """ 385 | if not self.is_open: 386 | raise Exception("This file has been closed.") 387 | bs = self._get_block_size() 388 | new_block_num = (file_size + bs - 1) // bs 389 | self._resize_header(new_block_num, file_size) 390 | self._flush_header() 391 | -------------------------------------------------------------------------------- /src/olah/proxy/files.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import hashlib 9 | import json 10 | import os 11 | from typing import Dict, List, Literal, Optional, Tuple 12 | from fastapi import Request 13 | import httpx 14 | from urllib.parse import urlparse, urljoin 15 | 16 | from olah.constants import ( 17 | CHUNK_SIZE, 18 | WORKER_API_TIMEOUT, 19 | HUGGINGFACE_HEADER_X_REPO_COMMIT, 20 | HUGGINGFACE_HEADER_X_LINKED_ETAG, 21 | HUGGINGFACE_HEADER_X_LINKED_SIZE, 22 | ORIGINAL_LOC, 23 | ) 24 | from olah.cache.olah_cache import OlahCache 25 | from olah.errors import error_entry_not_found, error_proxy_invalid_data, error_proxy_timeout 26 | from olah.proxy.pathsinfo import pathsinfo_generator 27 | from olah.utils.cache_utils import read_cache_request, write_cache_request 28 | from olah.utils.disk_utils import touch_file_access_time 29 | from olah.utils.url_utils import ( 30 | RemoteInfo, 31 | add_query_param, 32 | check_url_has_param_name, 33 | get_all_ranges, 34 | get_url_param_name, 35 | get_url_tail, 36 | parse_range_params, 37 | remove_query_param, 38 | ) 39 | from olah.utils.repo_utils import get_org_repo 40 | from olah.utils.rule_utils import check_cache_rules_hf 41 | from olah.utils.file_utils import make_dirs 42 | from olah.constants import CHUNK_SIZE, LFS_FILE_BLOCK, WORKER_API_TIMEOUT 43 | from olah.utils.zip_utils import Decompressor, decompress_data 44 | 45 | 46 | def get_block_info(pos: int, block_size: int, file_size: int) -> Tuple[int, int, int]: 47 | cur_block = pos // block_size 48 | block_start_pos = cur_block * block_size 49 | block_end_pos = min((cur_block + 1) * block_size, file_size) 50 | return cur_block, block_start_pos, block_end_pos 51 | 52 | 53 | def get_contiguous_ranges( 54 | cache_file: OlahCache, start_pos: int, end_pos: int 55 | ) -> List[Tuple[Tuple[int, int], bool]]: 56 | start_block = start_pos // cache_file._get_block_size() 57 | end_block = (end_pos - 1) // cache_file._get_block_size() 58 | 59 | range_start_pos = start_pos 60 | range_is_remote = not cache_file.has_block(start_block) 61 | cur_pos = start_pos 62 | # Get contiguous ranges: (range_start_pos, range_end_pos), is_remote 63 | ranges_and_cache_list: List[Tuple[Tuple[int, int], bool]] = [] 64 | for 
cur_block in range(start_block, end_block + 1): 65 | cur_block, block_start_pos, block_end_pos = get_block_info( 66 | cur_pos, cache_file._get_block_size(), cache_file._get_file_size() 67 | ) 68 | 69 | if cache_file.has_block(cur_block): 70 | cur_is_remote = False 71 | else: 72 | cur_is_remote = True 73 | if range_is_remote != cur_is_remote: 74 | if range_start_pos < cur_pos: 75 | ranges_and_cache_list.append( 76 | ((range_start_pos, cur_pos), range_is_remote) 77 | ) 78 | range_start_pos = cur_pos 79 | range_is_remote = cur_is_remote 80 | cur_pos = block_end_pos 81 | 82 | ranges_and_cache_list.append(((range_start_pos, end_pos), range_is_remote)) 83 | range_start_pos = end_pos 84 | return ranges_and_cache_list 85 | 86 | 87 | async def _get_file_range_from_cache( 88 | cache_file: OlahCache, start_pos: int, end_pos: int 89 | ): 90 | start_block = start_pos // cache_file._get_block_size() 91 | end_block = (end_pos - 1) // cache_file._get_block_size() 92 | cur_pos = start_pos 93 | for cur_block in range(start_block, end_block + 1): 94 | _, block_start_pos, block_end_pos = get_block_info( 95 | cur_pos, cache_file._get_block_size(), cache_file._get_file_size() 96 | ) 97 | if not cache_file.has_block(cur_block): 98 | raise Exception("Unknown exception: read block which has not been cached.") 99 | raw_block = await cache_file.read_block(cur_block) 100 | chunk = raw_block[ 101 | max(start_pos, block_start_pos) 102 | - block_start_pos : min(end_pos, block_end_pos) 103 | - block_start_pos 104 | ] 105 | yield chunk 106 | cur_pos += len(chunk) 107 | 108 | if cur_pos != end_pos: 109 | raise Exception(f"The cache range from {start_pos} to {end_pos} is incomplete.") 110 | 111 | 112 | async def _get_file_range_from_remote( 113 | client: httpx.AsyncClient, 114 | remote_info: RemoteInfo, 115 | cache_file: OlahCache, 116 | start_pos: int, 117 | end_pos: int, 118 | ): 119 | headers = {} 120 | if remote_info.headers.get("authorization", None) is not None: 121 | headers["authorization"] = remote_info.headers.get("authorization", None) 122 | headers["range"] = f"bytes={start_pos}-{end_pos - 1}" 123 | 124 | chunk_bytes = 0 125 | decompressor: Optional[Decompressor] = None 126 | async with client.stream( 127 | method=remote_info.method, 128 | url=remote_info.url, 129 | headers=headers, 130 | timeout=WORKER_API_TIMEOUT, 131 | follow_redirects=True, 132 | ) as response: 133 | status_code = response.status_code 134 | 135 | if status_code == 429: 136 | raise Exception("Too many requests in a given amount of time.") 137 | 138 | is_compressed = "content-encoding" in response.headers 139 | if is_compressed: 140 | decompressor = Decompressor(response.headers["content-encoding"].split(",")) 141 | 142 | async for raw_chunk in response.aiter_raw(): 143 | if not raw_chunk: 144 | continue 145 | if is_compressed and decompressor is not None: 146 | real_chunk = decompressor.decompress(raw_chunk) 147 | yield real_chunk 148 | chunk_bytes += len(real_chunk) 149 | else: 150 | yield raw_chunk 151 | chunk_bytes += len(raw_chunk) 152 | 153 | if is_compressed: 154 | response_content_length = chunk_bytes 155 | else: 156 | response_content_length = int(response.headers["content-length"]) 157 | 158 | # Post check 159 | if end_pos - start_pos != response_content_length: 160 | raise Exception( 161 | f"The content of the response is incomplete. File size: {cache_file._get_file_size()}. Start-end: {start_pos}-{end_pos}. Expected-{end_pos - start_pos}. 
Accepted-{response_content_length}" 162 | ) 163 | 164 | 165 | async def _file_chunk_get( 166 | app, 167 | save_path: str, 168 | head_path: str, 169 | client: httpx.AsyncClient, 170 | method: str, 171 | url: str, 172 | headers: Dict[str, str], 173 | allow_cache: bool, 174 | file_size: int, 175 | ): 176 | # Redirect Chunks 177 | if os.path.exists(save_path): 178 | cache_file = OlahCache(save_path) 179 | else: 180 | cache_file = OlahCache.create(save_path) 181 | cache_file.resize(file_size=file_size) 182 | 183 | # Refresh access time 184 | touch_file_access_time(save_path) 185 | 186 | try: 187 | unit, ranges, suffix = parse_range_params(headers.get("range", f"bytes={0}-{file_size-1}")) 188 | all_ranges = get_all_ranges(file_size, unit, ranges, suffix) 189 | 190 | for start_pos, end_pos in all_ranges: 191 | ranges_and_cache_list = get_contiguous_ranges(cache_file, start_pos, end_pos) 192 | # Stream ranges 193 | for (range_start_pos, range_end_pos), is_remote in ranges_and_cache_list: 194 | # range_start_pos is zero-index and range_end_pos is exclusive 195 | if is_remote: 196 | generator = _get_file_range_from_remote( 197 | client, 198 | RemoteInfo(method, url, headers), 199 | cache_file, 200 | range_start_pos, 201 | range_end_pos, 202 | ) 203 | else: 204 | generator = _get_file_range_from_cache( 205 | cache_file, 206 | range_start_pos, 207 | range_end_pos, 208 | ) 209 | 210 | cur_pos = range_start_pos 211 | stream_cache = bytearray() 212 | last_block, last_block_start_pos, last_block_end_pos = get_block_info( 213 | cur_pos, cache_file._get_block_size(), cache_file._get_file_size() 214 | ) 215 | async for chunk in generator: 216 | if len(chunk) != 0: 217 | yield bytes(chunk) 218 | stream_cache += chunk 219 | cur_pos += len(chunk) 220 | 221 | cur_block = cur_pos // cache_file._get_block_size() 222 | 223 | if cur_block == last_block: 224 | continue 225 | split_pos = last_block_end_pos - max( 226 | last_block_start_pos, range_start_pos 227 | ) 228 | raw_block = stream_cache[:split_pos] 229 | stream_cache = stream_cache[split_pos:] 230 | if len(raw_block) == cache_file._get_block_size(): 231 | if not cache_file.has_block(last_block) and allow_cache: 232 | await cache_file.write_block(last_block, raw_block) 233 | last_block, last_block_start_pos, last_block_end_pos = get_block_info( 234 | cur_pos, cache_file._get_block_size(), cache_file._get_file_size() 235 | ) 236 | 237 | raw_block = stream_cache 238 | if cur_block == cache_file._get_block_number() - 1: 239 | if ( 240 | len(raw_block) 241 | == cache_file._get_file_size() % cache_file._get_block_size() 242 | ): 243 | raw_block += b"\x00" * ( 244 | cache_file._get_block_size() - len(raw_block) 245 | ) 246 | last_block = cur_block 247 | if len(raw_block) == cache_file._get_block_size(): 248 | if not cache_file.has_block(last_block) and allow_cache: 249 | await cache_file.write_block(last_block, raw_block) 250 | 251 | if cur_pos != range_end_pos: 252 | if is_remote: 253 | raise Exception( 254 | f"The size of remote range ({range_end_pos - range_start_pos}) is different from sent size ({cur_pos - range_start_pos})." 255 | ) 256 | else: 257 | raise Exception( 258 | f"The size of cached range ({range_end_pos - range_start_pos}) is different from sent size ({cur_pos - range_start_pos})." 
259 | ) 260 | finally: 261 | cache_file.close() 262 | 263 | 264 | async def _file_chunk_head( 265 | app, 266 | save_path: str, 267 | head_path: str, 268 | client: httpx.AsyncClient, 269 | method: str, 270 | url: str, 271 | headers: Dict[str, str], 272 | allow_cache: bool, 273 | file_size: int, 274 | ): 275 | if not app.state.app_settings.config.offline: 276 | async with client.stream( 277 | method=method, 278 | url=url, 279 | headers=headers, 280 | timeout=WORKER_API_TIMEOUT, 281 | ) as response: 282 | async for raw_chunk in response.aiter_raw(): 283 | if not raw_chunk: 284 | continue 285 | yield raw_chunk 286 | else: 287 | yield b"" 288 | 289 | 290 | async def _resource_etag(hf_url: str, authorization: Optional[str]=None, offline: bool = False) -> Optional[str]: 291 | ret_etag = None 292 | sha256_hash = hashlib.sha256() 293 | sha256_hash.update(hf_url.encode("utf-8")) 294 | content_hash = sha256_hash.hexdigest() 295 | if offline: 296 | ret_etag = f'"{content_hash[:32]}-10"' 297 | else: 298 | etag_headers = {} 299 | if authorization is not None: 300 | etag_headers["authorization"] = authorization 301 | try: 302 | async with httpx.AsyncClient() as client: 303 | response = await client.request( 304 | method="head", 305 | url=hf_url, 306 | headers=etag_headers, 307 | timeout=WORKER_API_TIMEOUT, 308 | ) 309 | if "etag" in response.headers: 310 | ret_etag = response.headers["etag"] 311 | else: 312 | ret_etag = f'"{content_hash[:32]}-10"' 313 | except httpx.TimeoutException: 314 | ret_etag = None 315 | return ret_etag 316 | 317 | async def _file_realtime_stream( 318 | app, 319 | repo_type: Literal["models", "datasets", "spaces"], 320 | org: str, 321 | repo: str, 322 | file_path: str, 323 | save_path: str, 324 | head_path: str, 325 | url: str, 326 | request: Request, 327 | method="GET", 328 | allow_cache=True, 329 | commit: Optional[str] = None, 330 | ): 331 | if check_url_has_param_name(url, ORIGINAL_LOC): 332 | clean_url = remove_query_param(url, ORIGINAL_LOC) 333 | original_loc = get_url_param_name(url, ORIGINAL_LOC) 334 | 335 | hf_loc = urlparse(original_loc) 336 | if len(hf_loc.netloc) != 0: 337 | hf_url = urljoin( 338 | f"{hf_loc.scheme}://{hf_loc.netloc}", get_url_tail(clean_url) 339 | ) 340 | else: 341 | hf_url = urljoin( 342 | app.state.app_settings.config.hf_lfs_url_base(), get_url_tail(clean_url) 343 | ) 344 | else: 345 | if urlparse(url).netloc in [ 346 | app.state.app_settings.config.hf_netloc, 347 | app.state.app_settings.config.hf_lfs_netloc, 348 | ]: 349 | hf_url = url 350 | else: 351 | hf_url = urljoin( 352 | app.state.app_settings.config.hf_lfs_url_base(), get_url_tail(url) 353 | ) 354 | 355 | request_headers = {k: v for k, v in request.headers.items()} 356 | if "host" in request_headers: 357 | request_headers["host"] = urlparse(hf_url).netloc 358 | 359 | generator = pathsinfo_generator( 360 | app, 361 | repo_type, 362 | org, 363 | repo, 364 | commit, 365 | [file_path], 366 | override_cache=False, 367 | method="post", 368 | authorization=request.headers.get("authorization", None), 369 | ) 370 | status_code = await generator.__anext__() 371 | headers = await generator.__anext__() 372 | content = await generator.__anext__() 373 | try: 374 | pathsinfo = json.loads(content) 375 | except json.JSONDecodeError: 376 | response = error_proxy_invalid_data() 377 | yield response.status_code 378 | yield response.headers 379 | yield response.body 380 | return 381 | 382 | if len(pathsinfo) == 0: 383 | response = error_entry_not_found() 384 | yield response.status_code 385 | yield 
response.headers 386 | yield response.body 387 | return 388 | 389 | if len(pathsinfo) != 1: 390 | response = error_proxy_timeout() 391 | yield response.status_code 392 | yield response.headers 393 | yield response.body 394 | return 395 | 396 | pathinfo = pathsinfo[0] 397 | if "size" not in pathinfo: 398 | response = error_proxy_timeout() 399 | yield response.status_code 400 | yield response.headers 401 | yield response.body 402 | return 403 | file_size = pathinfo["size"] 404 | 405 | response_headers = {} 406 | # Create content-length 407 | unit, ranges, suffix = parse_range_params(request_headers.get("range", f"bytes={0}-{file_size-1}")) 408 | all_ranges = get_all_ranges(file_size, unit, ranges, suffix) 409 | 410 | response_headers["content-length"] = str(sum(r[1] - r[0] for r in all_ranges)) 411 | if suffix is not None: 412 | response_headers["content-range"] = f"bytes -{suffix}/{file_size}" 413 | else: 414 | response_headers["content-range"] = f"bytes {','.join(f'{r[0]}-{r[1]-1}' for r in all_ranges)}/{file_size}" 415 | # Commit info 416 | if commit is not None: 417 | response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit 418 | # Create fake headers when offline mode 419 | etag = await _resource_etag( 420 | hf_url=hf_url, 421 | authorization=request.headers.get("authorization", None), 422 | offline=app.state.app_settings.config.offline, 423 | ) 424 | response_headers["etag"] = etag 425 | 426 | if etag is None: 427 | error_response = error_proxy_timeout() 428 | yield error_response.status_code 429 | yield error_response.headers 430 | yield error_response.body 431 | return 432 | else: 433 | yield 200 434 | yield response_headers 435 | 436 | async with httpx.AsyncClient() as client: 437 | if method.lower() == "get": 438 | async for each_chunk in _file_chunk_get( 439 | app=app, 440 | save_path=save_path, 441 | head_path=head_path, 442 | client=client, 443 | method=method, 444 | url=hf_url, 445 | headers=request_headers, 446 | allow_cache=allow_cache, 447 | file_size=file_size, 448 | ): 449 | yield each_chunk 450 | elif method.lower() == "head": 451 | async for each_chunk in _file_chunk_head( 452 | app=app, 453 | save_path=save_path, 454 | head_path=head_path, 455 | client=client, 456 | method=method, 457 | url=hf_url, 458 | headers=request_headers, 459 | allow_cache=allow_cache, 460 | file_size=0, 461 | ): 462 | yield each_chunk 463 | else: 464 | raise Exception(f"Unsupported method: {method}") 465 | 466 | 467 | async def file_get_generator( 468 | app, 469 | repo_type: Literal["models", "datasets", "spaces"], 470 | org: str, 471 | repo: str, 472 | commit: str, 473 | file_path: str, 474 | method: Literal["HEAD", "GET"], 475 | request: Request, 476 | ): 477 | org_repo = get_org_repo(org, repo) 478 | # save 479 | repos_path = app.state.app_settings.config.repos_path 480 | head_path = os.path.join( 481 | repos_path, f"heads/{repo_type}/{org_repo}/resolve/{commit}/{file_path}" 482 | ) 483 | save_path = os.path.join( 484 | repos_path, f"files/{repo_type}/{org_repo}/resolve/{commit}/{file_path}" 485 | ) 486 | make_dirs(head_path) 487 | make_dirs(save_path) 488 | 489 | # use_cache = os.path.exists(head_path) and os.path.exists(save_path) 490 | allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) 491 | 492 | # proxy 493 | if repo_type == "models": 494 | url = urljoin( 495 | app.state.app_settings.config.hf_url_base(), 496 | f"/{org_repo}/resolve/{commit}/{file_path}", 497 | ) 498 | else: 499 | url = urljoin( 500 | app.state.app_settings.config.hf_url_base(), 501 | 
f"/{repo_type}/{org_repo}/resolve/{commit}/{file_path}", 502 | ) 503 | return _file_realtime_stream( 504 | app=app, 505 | repo_type=repo_type, 506 | org=org, 507 | repo=repo, 508 | file_path=file_path, 509 | save_path=save_path, 510 | head_path=head_path, 511 | url=url, 512 | request=request, 513 | method=method, 514 | allow_cache=allow_cache, 515 | commit=commit, 516 | ) 517 | 518 | 519 | async def cdn_file_get_generator( 520 | app, 521 | repo_type: Literal["models", "datasets", "spaces"], 522 | org: str, 523 | repo: str, 524 | file_hash: str, 525 | method: Literal["HEAD", "GET"], 526 | request: Request, 527 | ): 528 | headers = {k: v for k, v in request.headers.items()} 529 | headers.pop("host") 530 | 531 | org_repo = get_org_repo(org, repo) 532 | # save 533 | repos_path = app.state.app_settings.config.repos_path 534 | head_path = os.path.join( 535 | repos_path, f"heads/{repo_type}/{org_repo}/cdn/{file_hash}" 536 | ) 537 | save_path = os.path.join( 538 | repos_path, f"files/{repo_type}/{org_repo}/cdn/{file_hash}" 539 | ) 540 | make_dirs(head_path) 541 | make_dirs(save_path) 542 | 543 | # use_cache = os.path.exists(head_path) and os.path.exists(save_path) 544 | allow_cache = await check_cache_rules_hf(app, repo_type, org, repo) 545 | 546 | # proxy 547 | # request_url = urlparse(str(request.url)) 548 | # if request_url.netloc == app.state.app_settings.config.hf_lfs_netloc: 549 | # redirected_url = urljoin(app.state.app_settings.config.mirror_lfs_url_base(), get_url_tail(request_url)) 550 | # else: 551 | # redirected_url = urljoin(app.state.app_settings.config.mirror_url_base(), get_url_tail(request_url)) 552 | 553 | return _file_realtime_stream( 554 | app=app, 555 | save_path=save_path, 556 | head_path=head_path, 557 | url=str(request.url), 558 | request=request, 559 | method=method, 560 | allow_cache=allow_cache, 561 | ) 562 | --------------------------------------------------------------------------------