├── tests
│   ├── .gitignore
│   └── simple_test.py
├── src
│   └── olah
│       ├── __init__.py
│       ├── auth
│       │   └── __init__.py
│       ├── cache
│       │   ├── __init__.py
│       │   ├── stat.py
│       │   ├── bitset.py
│       │   └── olah_cache.py
│       ├── mirror
│       │   ├── __init__.py
│       │   ├── meta.py
│       │   └── repos.py
│       ├── proxy
│       │   ├── __init__.py
│       │   ├── lfs.py
│       │   ├── commits.py
│       │   ├── meta.py
│       │   ├── tree.py
│       │   ├── pathsinfo.py
│       │   └── files.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── file_utils.py
│       │   ├── olah_utils.py
│       │   ├── rule_utils.py
│       │   ├── cache_utils.py
│       │   ├── disk_utils.py
│       │   ├── logging.py
│       │   ├── zip_utils.py
│       │   ├── url_utils.py
│       │   └── repo_utils.py
│       ├── database
│       │   ├── __init__.py
│       │   └── models.py
│       ├── constants.py
│       ├── static
│       │   ├── repos.html
│       │   └── index.html
│       ├── errors.py
│       └── configs.py
├── docker
│   ├── up
│   │   ├── repos
│   │   │   └── .gitignore
│   │   ├── mirrors
│   │   │   └── .gitignore
│   │   ├── docker-compose.yml
│   │   └── configs.toml
│   ├── build@source
│   │   └── dockerfile
│   └── build@pypi
│       └── dockerfile
├── environment.yml
├── .dockerignore
├── docs
│   ├── zh
│   │   ├── main.md
│   │   └── quickstart.md
│   └── en
│       ├── main.md
│       └── quickstart.md
├── requirements.txt
├── .vscode
│   ├── launch.json
│   └── tasks.json
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── dev.yml
│       ├── release.yml
│       ├── docker-image-tag-commit.yml
│       └── docker-image-tag-version.yml
├── CONTRIBUTING.md
├── LICENSE
├── assets
│   └── full_configs.toml
├── pyproject.toml
├── .gitignore
├── README_zh.md
└── README.md

--------------------------------------------------------------------------------
/tests/.gitignore:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/src/olah/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/src/olah/auth/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/src/olah/cache/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/src/olah/mirror/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/src/olah/proxy/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/src/olah/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/src/olah/database/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/docker/up/repos/.gitignore:
--------------------------------------------------------------------------------
 1 | *
 2 | !.gitignore
--------------------------------------------------------------------------------
/docker/up/mirrors/.gitignore:
--------------------------------------------------------------------------------
 1 | *
 2 | !.gitignore
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
 1 | name: olah
 2 | channels:
 3 |   - conda-forge
 4 | dependencies:
 5 |   - python=3.12
 6 | 
 7 |   - pip:
 8 |     - -e .
 9 | 
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
 1 | .git/
 2 | .github/
 3 | .vscode/
 4 | assets/
 5 | benchmark/
 6 | cache/
 7 | docker/
 8 | docs/
 9 | dist/
10 | logs/
11 | mirrors_dir/
12 | olah.egg-info/
13 | playground/
14 | scripts/
15 | tests/
16 | repos/
17 | .dockerignore
18 | .gitignore
19 | environment.yml
--------------------------------------------------------------------------------
/docs/zh/main.md:
--------------------------------------------------------------------------------
 1 | 
 5 | 自托管的轻量级HuggingFace镜像服务
 6 | 
 7 | Olah是开源的自托管轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。
 8 | Olah真正地实现了huggingface资源的`镜像`功能,而不仅仅是一个简单的`反向代理`。
 9 | Olah并不会立刻对huggingface全站进行镜像,而是在用户下载的同时在文件块级别对资源进行镜像(或者我们可以说是缓存)。
10 | 
--------------------------------------------------------------------------------
/docker/build@source/dockerfile:
--------------------------------------------------------------------------------
 1 | FROM python:3.12
 2 | 
 3 | WORKDIR /app
 4 | 
 5 | RUN pip3 install --upgrade pip
 6 | 
 7 | COPY . /app
 8 | RUN pip3 install --no-cache-dir -e .
 9 | 
10 | EXPOSE 8090
11 | 
12 | VOLUME /data/repos
13 | VOLUME /data/mirrors
14 | 
15 | ENTRYPOINT [ "olah-cli" ]
16 | 
17 | CMD ["--repos-path", "/data/repos"]
18 | 
--------------------------------------------------------------------------------
/docker/up/docker-compose.yml:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | services:
 4 |   olah:
 5 |     image: xiahan2019/olah:latest
 6 |     container_name: olah
 7 |     command: -c /app/configs.toml
 8 |     ports:
 9 |       - 8090:8090
10 |     volumes:
11 |       - ./configs.toml:/app/configs.toml
12 |       - ./mirrors:/data/mirrors
13 |       - ./repos:/data/repos
14 | 
--------------------------------------------------------------------------------
/docker/build@pypi/dockerfile:
--------------------------------------------------------------------------------
 1 | ARG OLAH_SOURCE=olah==0.3.3
 2 | 
 3 | FROM python:3.12
 4 | 
 5 | ARG OLAH_SOURCE
 6 | 
 7 | WORKDIR /app
 8 | 
 9 | RUN pip3 install --upgrade pip
10 | 
11 | RUN pip install --no-cache-dir ${OLAH_SOURCE}
12 | 
13 | EXPOSE 8090
14 | 
15 | VOLUME /data/repos
16 | VOLUME /data/mirrors
17 | 
18 | ENTRYPOINT [ "olah-cli" ]
19 | 
20 | CMD ["--repos-path", "/data/repos"]
21 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | fastapi==0.115.2
 2 | fastapi-utils==0.7.0
 3 | GitPython==3.1.43
 4 | httpx==0.27.0
 5 | pydantic==2.8.2
 6 | pydantic-settings==2.4.0
 7 | toml==0.10.2
 8 | huggingface_hub==0.26.0
 9 | pytest==8.3.3
10 | cachetools==5.4.0
11 | PyYAML==6.0.1
12 | tenacity==8.5.0
13 | peewee==3.17.6
14 | typing_inspect==0.9.0
15 | jinja2==3.1.4
16 | python-multipart==0.0.9
17 | portalocker==3.1.1
18 | aiofiles==24.1.0
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "version": "0.2.0",
 3 |     "configurations": [
 4 |         {
 5 |             "name": "debugpy: server",
 6 |             "type": "debugpy",
 7 |             "request": "launch",
 8 |             "program": "${workspaceFolder}/src/olah/server.py",
 9 |             "console": "integratedTerminal",
10 |             "args": ["-c", "${workspaceFolder}/assets/full_configs.toml"],
11 |             "justMyCode": false
12 |         }
13 |     ]
14 | }
--------------------------------------------------------------------------------
/src/olah/utils/file_utils.py:
--------------------------------------------------------------------------------
 1 | # coding=utf-8
 2 | # Copyright 2024 XiaHan
 3 | # 
 4 | # Use of this source code is governed by an MIT-style
 5 | # license that can be found in the LICENSE file or at
 6 | # https://opensource.org/licenses/MIT.
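# Usage sketch (illustrative, not part of the module): make_dirs() below treats
# an existing directory as the target itself, and any other path as a file path,
# in which case only the parent directory is created:
#
#     make_dirs("./repos/lfs/heads/a1")  # creates ./repos/lfs/heads if "a1" does not exist yet
#     make_dirs("./repos")               # existing directory: nothing to create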
7 | 8 | import os 9 | 10 | 11 | def make_dirs(path: str): 12 | if os.path.isdir(path): 13 | save_dir = path 14 | else: 15 | save_dir = os.path.dirname(path) 16 | if not os.path.exists(save_dir): 17 | os.makedirs(save_dir, exist_ok=True) 18 | -------------------------------------------------------------------------------- /src/olah/utils/olah_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | import platform 9 | import os 10 | 11 | 12 | def get_olah_path() -> str: 13 | if platform.system() == "Windows": 14 | olah_path = os.path.expanduser("~\\.olah") 15 | else: 16 | olah_path = os.path.expanduser("~/.olah") 17 | return olah_path 18 | -------------------------------------------------------------------------------- /docs/en/main.md: -------------------------------------------------------------------------------- 1 |
 4 | Self-hosted Lightweight HuggingFace Mirror Service
 5 | 
 6 | Olah is a self-hosted lightweight HuggingFace mirror service. `Olah` means `hello` in Hilichurlian.
 7 | Olah truly implements `mirroring` of HuggingFace resources, rather than acting as just a simple `reverse proxy`.
 8 | Olah does not mirror the entire HuggingFace site up front; instead, it mirrors (or, one could say, caches) resources at the file-block level as users download them.
 9 | 
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
 1 | # To get started with Dependabot version updates, you'll need to specify which
 2 | # package ecosystems to update and where the package manifests are located.
 3 | # Please see the documentation for all configuration options:
 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
 5 | 
 6 | version: 2
 7 | updates:
 8 |   - package-ecosystem: "pip" # See documentation for possible values
 9 |     directory: "/" # Location of package manifests
10 |     schedule:
11 |       interval: "weekly"
--------------------------------------------------------------------------------
/src/olah/constants.py:
--------------------------------------------------------------------------------
 1 | # coding=utf-8
 2 | # Copyright 2024 XiaHan
 3 | # 
 4 | # Use of this source code is governed by an MIT-style
 5 | # license that can be found in the LICENSE file or at
 6 | # https://opensource.org/licenses/MIT.
 7 | 
 8 | import os
 9 | 
10 | 
11 | WORKER_API_TIMEOUT = 15
12 | CHUNK_SIZE = 4096
13 | LFS_FILE_BLOCK = 64 * 1024 * 1024
14 | 
15 | DEFAULT_LOGGER_DIR = "./logs"
16 | OLAH_CODE_DIR = os.path.dirname(os.path.abspath(__file__))
17 | 
18 | ORIGINAL_LOC = "oriloc"
19 | 
20 | # Re-exported here so other modules can import these constants from olah.constants.
21 | from huggingface_hub.constants import (
22 |     REPO_TYPES_MAPPING,
23 |     HUGGINGFACE_CO_URL_TEMPLATE,
24 |     HUGGINGFACE_HEADER_X_REPO_COMMIT,
25 |     HUGGINGFACE_HEADER_X_LINKED_ETAG,
26 |     HUGGINGFACE_HEADER_X_LINKED_SIZE,
27 | )
--------------------------------------------------------------------------------
/tests/simple_test.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import subprocess
 3 | import time
 4 | 
 5 | import httpx
 6 | from huggingface_hub import snapshot_download
 7 | 
 8 | ENDPOINT = 'http://localhost:8090'
 9 | 
10 | 
11 | def _wait_for_server(timeout: float = 30.0) -> None:
12 |     # Poll the mirror until it accepts connections, instead of racing the
13 |     # freshly spawned server process.
14 |     start = time.time()
15 |     while time.time() - start < timeout:
16 |         try:
17 |             httpx.get(ENDPOINT, timeout=1.0)
18 |             return
19 |         except httpx.HTTPError:
20 |             time.sleep(0.5)
21 |     raise RuntimeError('olah server did not start in time')
22 | 
23 | 
24 | def test_dataset():
25 |     process = subprocess.Popen(['python', '-m', 'olah.server'])
26 |     try:
27 |         _wait_for_server()
28 |         os.environ['HF_ENDPOINT'] = ENDPOINT
29 |         snapshot_download(repo_id='Nerfgun3/bad_prompt', repo_type='dataset',
30 |                           local_dir='./dataset_dir', max_workers=8)
31 |     finally:
32 |         process.terminate()
33 | 
34 | 
35 | def test_model():
36 |     process = subprocess.Popen(['python', '-m', 'olah.server'])
37 |     try:
38 |         _wait_for_server()
39 |         os.environ['HF_ENDPOINT'] = ENDPOINT
40 |         snapshot_download(repo_id='prajjwal1/bert-tiny', repo_type='model',
41 |                           local_dir='./model_dir', max_workers=8)
42 |     finally:
43 |         process.terminate()
--------------------------------------------------------------------------------
/docs/zh/quickstart.md:
--------------------------------------------------------------------------------
 1 | ## 快速开始
 2 | 在控制台运行以下命令:
 3 | ```bash
 4 | python -m olah.server
 5 | ```
 6 | 
 7 | 然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090/)。
 8 | 
 9 | Linux:
10 | ```bash
11 | export HF_ENDPOINT=http://localhost:8090
12 | ```
13 | 
14 | Windows Powershell:
15 | ```bash
16 | $env:HF_ENDPOINT = "http://localhost:8090"
17 | ```
18 | 
19 | 从现在开始,HuggingFace库中的所有下载操作都将通过此镜像站点代理进行。
20 | ```bash
21 | pip install -U huggingface_hub
22 | ```
23 | 
24 | ```python
25 | from huggingface_hub import snapshot_download
26 | 
27 | snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model',
28 |                   local_dir='./model_dir', resume_download=True,
29 |                   max_workers=8)
30 | 
31 | ```
32 | 
33 | 或者你也可以使用huggingface cli直接下载模型和数据集。
34 | 
35 | 下载GPT2:
36 | ```bash
37 | huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2
38 | ```
39 | 
40 | 下载WikiText:
41 | ```bash
42 | huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext
43 | ```
44 | 
45 | 您可以查看路径`./repos`,其中存储了所有数据集和模型的缓存。
46 | 
--------------------------------------------------------------------------------
/src/olah/utils/rule_utils.py:
--------------------------------------------------------------------------------
 1 | # coding=utf-8
 2 | # Copyright 2024 XiaHan
 3 | # 
 4 | # Use of this source code is governed by an MIT-style
 5 | # license that can be found in the LICENSE file or at
 6 | # https://opensource.org/licenses/MIT.
 7 | 
 8 | 
 9 | from typing import Dict, Literal, Optional, Tuple, Union
10 | 
11 | from fastapi import FastAPI
12 | from olah.configs import OlahConfig
13 | from .repo_utils import get_org_repo
14 | 
15 | 
16 | async def check_proxy_rules_hf(
17 |     app: FastAPI,
18 |     repo_type: Optional[Literal["models", "datasets", "spaces"]],
19 |     org: Optional[str],
20 |     repo: str,
21 | ) -> bool:
22 |     config: OlahConfig = app.state.app_settings.config
23 |     org_repo = get_org_repo(org, repo)
24 |     return config.proxy.allow(org_repo)
25 | 
26 | 
27 | async def check_cache_rules_hf(
28 |     app: FastAPI,
29 |     repo_type: Optional[Literal["models", "datasets", "spaces"]],
30 |     org: Optional[str],
31 |     repo: str,
32 | ) -> bool:
33 |     config: OlahConfig = app.state.app_settings.config
34 |     org_repo = get_org_repo(org, repo)
35 |     return config.cache.allow(org_repo)
36 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # How can I contribute to Olah?
 2 | 
 3 | Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
 4 | 
 5 | It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
 6 | 
 7 | However you choose to contribute, please be mindful and respect our code of conduct.
8 | 9 | ## Ways to contribute 10 | 11 | There are lots of ways you can contribute to Olah: 12 | * Submitting issues on Github to report bugs or make feature requests 13 | * Fixing outstanding issues with the existing code 14 | * Implementing new features 15 | * Contributing to the examples or to the documentation 16 | 17 | *All are equally valuable to the community.* 18 | 19 | #### This guide was heavily inspired by the awesome [transformers guide to contributing](https://github.com/huggingface/transformers/blob/master/CONTRIBUTING.md) 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Vtuber Plan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/olah/database/models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
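# Query sketch (illustrative, not part of the module; it uses the peewee models
# defined below): the FileLevelLRU table can drive least-recently-used eviction
# by scanning rows oldest-first:
#
#     stale = (FileLevelLRU
#              .select()
#              .order_by(FileLevelLRU.datetime.asc())
#              .limit(10))
#     for row in stale:
#         print(row.org, row.repo, row.path)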
7 | 8 | import os 9 | from peewee import * 10 | import datetime 11 | 12 | from olah.utils.olah_utils import get_olah_path 13 | 14 | db_path = os.path.join(get_olah_path(), "database.db") 15 | db = SqliteDatabase(db_path) 16 | 17 | class BaseModel(Model): 18 | class Meta: 19 | database = db 20 | 21 | class Token(BaseModel): 22 | token = CharField(unique=True) 23 | first_dt = DateTimeField() 24 | last_dt = DateTimeField() 25 | 26 | class DownloadLogs(BaseModel): 27 | id = CharField(unique=True) 28 | org = CharField() 29 | repo = CharField() 30 | path = CharField() 31 | range_start = BigIntegerField() 32 | range_end = BigIntegerField() 33 | datetime = DateTimeField() 34 | token = CharField() 35 | 36 | class FileLevelLRU(BaseModel): 37 | org = CharField() 38 | repo = CharField() 39 | path = CharField() 40 | datetime = DateTimeField(default=datetime.datetime.now) 41 | 42 | db.connect() 43 | db.create_tables([ 44 | Token, 45 | DownloadLogs, 46 | FileLevelLRU, 47 | ]) 48 | -------------------------------------------------------------------------------- /assets/full_configs.toml: -------------------------------------------------------------------------------- 1 | [basic] 2 | host = "localhost" 3 | port = 8090 4 | ssl-key = "" 5 | ssl-cert = "" 6 | repos-path = "./repos" 7 | cache-size-limit = "" 8 | cache-clean-strategy = "LRU" 9 | hf-scheme = "https" 10 | hf-netloc = "huggingface.co" 11 | hf-lfs-netloc = "cdn-lfs.huggingface.co" 12 | mirror-scheme = "http" 13 | mirror-netloc = "localhost:8090" 14 | mirror-lfs-netloc = "localhost:8090" 15 | mirrors-path = ["./mirrors_dir"] 16 | 17 | [accessibility] 18 | offline = false 19 | 20 | # allow other or will be in whitelist mode. 21 | [[accessibility.proxy]] 22 | repo = "*" 23 | allow = true 24 | 25 | [[accessibility.proxy]] 26 | repo = "*/*" 27 | allow = true 28 | 29 | [[accessibility.proxy]] 30 | repo = "cais/mmlu" 31 | allow = true 32 | 33 | [[accessibility.proxy]] 34 | repo = "adept/fuyu-8b" 35 | allow = false 36 | 37 | [[accessibility.proxy]] 38 | repo = "mistralai/*" 39 | allow = true 40 | 41 | [[accessibility.proxy]] 42 | repo = "mistralai/Mistral.*" 43 | allow = false 44 | use_re = true 45 | 46 | # allow other or will be in whitelist mode. 47 | [[accessibility.cache]] 48 | repo = "*" 49 | allow = true 50 | 51 | [[accessibility.cache]] 52 | repo = "*/*" 53 | allow = true 54 | 55 | [[accessibility.cache]] 56 | repo = "cais/mmlu" 57 | allow = true 58 | 59 | [[accessibility.cache]] 60 | repo = "adept/fuyu-8b" 61 | allow = false 62 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "olah" 7 | version = "0.4.1" 8 | description = "Self-hosted lightweight huggingface mirror." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: MIT License", 14 | ] 15 | dependencies = [ 16 | "fastapi", "fastapi-utils", "httpx", "numpy", "pydantic<=2.8.2", "pydantic-settings<=2.4.0", "requests", "toml", 17 | "rich>=10.0.0", "shortuuid", "uvicorn", "tenacity>=8.2.2", "pytz", "cachetools", "GitPython", 18 | "PyYAML", "typing_inspect>=0.9.0", "huggingface_hub", "jinja2", "python-multipart", "portalocker", 19 | "aiofiles", "brotli" 20 | ] 21 | 22 | [project.optional-dependencies] 23 | dev = ["black==24.10.0", "pylint==3.3.1", "pytest==8.3.3"] 24 | 25 | [project.urls] 26 | "Homepage" = "https://github.com/vtuber-plan/olah" 27 | "Bug Tracker" = "https://github.com/vtuber-plan/olah/issues" 28 | 29 | [project.scripts] 30 | olah-cli = "olah.server:cli" 31 | 32 | [tool.setuptools.packages.find] 33 | where = ["src"] 34 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 35 | 36 | [tool.wheel] 37 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 38 | -------------------------------------------------------------------------------- /.github/workflows/dev.yml: -------------------------------------------------------------------------------- 1 | name: Olah GitHub Actions for Development 2 | run-name: Olah GitHub Actions for Development 3 | on: 4 | push: 5 | branches: [ "dev" ] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: ["3.9", "3.10", "3.11", "3.12"] 13 | 14 | steps: 15 | - name: Check out repository code 16 | uses: actions/checkout@v4 17 | - name: Set up Apache Arrow 18 | run: | 19 | sudo apt update 20 | sudo apt install -y -V ca-certificates lsb-release wget 21 | wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb 22 | sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb 23 | sudo apt update 24 | sudo apt install -y -V libarrow-dev libarrow-glib-dev libparquet-dev libparquet-glib-dev 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - name: Install Olah 30 | run: | 31 | cd ${{ github.workspace }} 32 | pip install --upgrade pip 33 | pip install -e . 34 | pip install -r requirements.txt 35 | 36 | - name: Test Olah 37 | run: | 38 | cd ${{ github.workspace }} 39 | python -m pytest tests 40 | -------------------------------------------------------------------------------- /src/olah/mirror/meta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
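# Usage sketch (illustrative values): RepoMeta mirrors the JSON shape of the
# HuggingFace repo-info endpoint; populate the fields and serialize with to_dict():
#
#     meta = RepoMeta()
#     meta.id = "prajjwal1/bert-tiny"
#     meta.sha = "abc1234"      # hypothetical commit hash
#     payload = meta.to_dict()  # plain dict, ready for json.dumps()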
7 | 8 | 9 | from typing import Any, Dict 10 | 11 | 12 | class RepoMeta(object): 13 | def __init__(self) -> None: 14 | self._id = None 15 | self.id = None 16 | self.author = None 17 | self.sha = None 18 | self.lastModified = None 19 | self.private = False 20 | self.gated = False 21 | self.disabled = False 22 | self.tags = [] 23 | self.description = "" 24 | self.paperswithcode_id = None 25 | self.downloads = 0 26 | self.likes = 0 27 | self.cardData = None 28 | self.siblings = None 29 | self.createdAt = None 30 | 31 | def to_dict(self) -> Dict[str, Any]: 32 | return { 33 | "_id": self._id, 34 | "id": self.id, 35 | "author": self.author, 36 | "sha": self.sha, 37 | "lastModified": self.lastModified, 38 | "private": self.private, 39 | "gated": self.gated, 40 | "disabled": self.disabled, 41 | "tags": self.tags, 42 | "description": self.description, 43 | "paperswithcode_id": self.paperswithcode_id, 44 | "downloads": self.downloads, 45 | "likes": self.likes, 46 | "cardData": self.cardData, 47 | "siblings": self.siblings, 48 | "createdAt": self.createdAt, 49 | } 50 | -------------------------------------------------------------------------------- /docker/up/configs.toml: -------------------------------------------------------------------------------- 1 | [basic] 2 | host = "0.0.0.0" 3 | port = 8090 4 | ssl-key = "" 5 | ssl-cert = "" 6 | repos-path = "/data/repos" 7 | cache-size-limit = "" 8 | cache-clean-strategy = "LRU" 9 | hf-scheme = "https" 10 | hf-netloc = "huggingface.co" 11 | hf-lfs-netloc = "cdn-lfs.huggingface.co" 12 | mirror-scheme = "http" 13 | mirror-netloc = "localhost:8090" 14 | mirror-lfs-netloc = "localhost:8090" 15 | mirrors-path = ["/data/mirrors"] 16 | 17 | 18 | [accessibility] 19 | offline = false 20 | 21 | 22 | [[accessibility.proxy]] 23 | repo = "*" 24 | allow = true 25 | 26 | [[accessibility.proxy]] 27 | repo = "*/*" 28 | allow = true 29 | 30 | [[accessibility.proxy]] 31 | repo = "vikp/surya_det3" 32 | allow = false 33 | 34 | # [[accessibility.proxy]] 35 | # repo = "vikp/surya_layout3" 36 | # allow = true 37 | 38 | # [[accessibility.proxy]] 39 | # repo = "vikp/surya_order" 40 | # allow = true 41 | 42 | # [[accessibility.proxy]] 43 | # repo = "vikp/surya_rec2" 44 | # allow = true 45 | 46 | # [[accessibility.proxy]] 47 | # repo = "vikp/surya_tablerec" 48 | # allow = true 49 | 50 | 51 | [[accessibility.cache]] 52 | repo = "*" 53 | allow = true 54 | 55 | [[accessibility.cache]] 56 | repo = "*/*" 57 | allow = true 58 | 59 | [[accessibility.cache]] 60 | repo = "vikp/surya_det3" 61 | allow = false 62 | 63 | # [[accessibility.cache]] 64 | # repo = "vikp/surya_layout3" 65 | # allow = true 66 | 67 | # [[accessibility.cache]] 68 | # repo = "vikp/surya_order" 69 | # allow = true 70 | 71 | # [[accessibility.cache]] 72 | # repo = "vikp/surya_rec2" 73 | # allow = true 74 | 75 | # [[accessibility.cache]] 76 | # repo = "vikp/surya_tablerec" 77 | # allow = true 78 | -------------------------------------------------------------------------------- /src/olah/utils/cache_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
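# Round-trip sketch (illustrative; both helpers below are coroutines, so an
# event loop is required):
#
#     import asyncio
#
#     asyncio.run(write_cache_request(
#         "./meta_cache.json", 200, {"content-type": "application/json"}, b"{}"))
#     cached = asyncio.run(read_cache_request("./meta_cache.json"))
#     assert cached["status_code"] == 200 and cached["content"] == b"{}"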
 7 | 
 8 | 
 9 | import json
10 | from typing import Dict, Mapping, Union
11 | 
12 | 
13 | async def write_cache_request(
14 |     save_path: str,
15 |     status_code: int,
16 |     headers: Union[Dict[str, str], Mapping],
17 |     content: bytes,
18 | ) -> None:
19 |     """
20 |     Write the request's status code, headers, and content to a cache file.
21 | 
22 |     Args:
23 |         save_path (str): The path to the cache file.
24 |         status_code (int): The status code of the request.
25 |         headers (Dict[str, str]): The dictionary of response headers.
26 |         content (bytes): The content of the request.
27 | 
28 |     Returns:
29 |         None
30 |     """
31 |     # Mappings such as httpx.Headers are converted to a plain dict with
32 |     # lowercase keys; plain dicts are stored as-is.
33 |     if not isinstance(headers, dict):
34 |         headers = {k.lower(): v for k, v in headers.items()}
35 |     rq = {
36 |         "status_code": status_code,
37 |         "headers": headers,
38 |         "content": content.hex(),
39 |     }
40 |     with open(save_path, "w", encoding="utf-8") as f:
41 |         f.write(json.dumps(rq, ensure_ascii=False))
42 | 
43 | 
44 | async def read_cache_request(save_path: str) -> Dict[str, str]:
45 |     """
46 |     Read the request's status code, headers, and content from a cache file.
47 | 
48 |     Args:
49 |         save_path (str): The path to the cache file.
50 | 
51 |     Returns:
52 |         Dict[str, str]: A dictionary containing the status code, headers, and content of the request.
53 |     """
54 |     with open(save_path, "r", encoding="utf-8") as f:
55 |         rq = json.loads(f.read())
56 | 
57 |     rq["content"] = bytes.fromhex(rq["content"])
58 |     return rq
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
 1 | name: Olah GitHub Actions to release
 2 | run-name: Olah GitHub Actions to release
 3 | on:
 4 |   push:
 5 |     tags:
 6 |       - "v*"
 7 | 
 8 | jobs:
 9 |   build:
10 |     runs-on: ubuntu-latest
11 |     strategy:
12 |       matrix:
13 |         python-version: ["3.12"]
14 | 
15 |     steps:
16 |     - name: Check out repository code
17 |       uses: actions/checkout@v4
18 |     - name: Set up Apache Arrow
19 |       run: |
20 |         sudo apt update
21 |         sudo apt install -y -V ca-certificates lsb-release wget
22 |         wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
23 |         sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
24 |         sudo apt update
25 |         sudo apt install -y -V libarrow-dev libarrow-glib-dev libparquet-dev libparquet-glib-dev
26 |     - name: Set up Python ${{ matrix.python-version }}
27 |       uses: actions/setup-python@v5
28 |       with:
29 |         python-version: ${{ matrix.python-version }}
30 |     - name: Install Olah
31 |       run: |
32 |         cd ${{ github.workspace }}
33 |         pip install --upgrade pip
34 |         pip install -e .
35 | pip install -r requirements.txt 36 | 37 | - name: Test Olah 38 | run: | 39 | cd ${{ github.workspace }} 40 | python -m pytest tests 41 | 42 | - name: Build Olah 43 | run: | 44 | cd ${{ github.workspace }} 45 | pip install build 46 | python -m build 47 | 48 | - name: Release 49 | uses: "marvinpinto/action-automatic-releases@latest" 50 | with: 51 | repo_token: "${{ secrets.GITHUB_TOKEN }}" 52 | prerelease: true 53 | files: | 54 | dist/*.tar.gz 55 | dist/*.whl -------------------------------------------------------------------------------- /docs/en/quickstart.md: -------------------------------------------------------------------------------- 1 | 2 | ## Quick Start 3 | Run the command in the console: 4 | ```bash 5 | python -m olah.server 6 | ``` 7 | 8 | Then set the Environment Variable `HF_ENDPOINT` to the mirror site (Here is http://localhost:8090). 9 | 10 | Linux: 11 | ```bash 12 | export HF_ENDPOINT=http://localhost:8090 13 | ``` 14 | 15 | Windows Powershell: 16 | ```bash 17 | $env:HF_ENDPOINT = "http://localhost:8090" 18 | ``` 19 | 20 | Starting from now on, all download operations in the HuggingFace library will be proxied through this mirror site. 21 | ```bash 22 | pip install -U huggingface_hub 23 | ``` 24 | 25 | ```python 26 | from huggingface_hub import snapshot_download 27 | 28 | snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model', 29 | local_dir='./model_dir', resume_download=True, 30 | max_workers=8) 31 | ``` 32 | 33 | Or you can download models and datasets by using huggingface cli. 34 | 35 | Download GPT2: 36 | ```bash 37 | huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2 38 | ``` 39 | 40 | Download WikiText: 41 | ```bash 42 | huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext 43 | ``` 44 | 45 | You can check the path `./repos`, in which olah stores all cached datasets and models. 46 | 47 | ## Start the server 48 | Run the command in the console: 49 | ```bash 50 | python -m olah.server 51 | ``` 52 | 53 | Or you can specify the host address and listening port: 54 | ```bash 55 | python -m olah.server --host localhost --port 8090 56 | ``` 57 | **Note: Please change --mirror-netloc and --mirror-lfs-netloc to the actual URLs of the mirror sites when modifying the host and port.** 58 | ```bash 59 | python -m olah.server --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090 60 | ``` 61 | 62 | The default mirror cache path is `./repos`, you can change it by `--repos-path` parameter: 63 | ```bash 64 | python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors 65 | ``` 66 | 67 | **Note that the cached data between different versions cannot be migrated. Please delete the cache folder before upgrading to the latest version of Olah.** 68 | -------------------------------------------------------------------------------- /src/olah/proxy/lfs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
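# Cache layout sketch (illustrative; the request path shape is hypothetical,
# and dir1/dir2 are simply the first two path segments the server routes pass in):
#
#     request .../<dir1>/<dir2>/<hash_repo>/<hash_file>
#         head cache -> <repos_path>/lfs/heads/<dir1>/<dir2>/<hash_repo>/<hash_file>
#         data cache -> <repos_path>/lfs/files/<dir1>/<dir2>/<hash_repo>/<hash_file>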
7 | 8 | import os 9 | from typing import Literal 10 | from fastapi import FastAPI, Header, Request 11 | 12 | from olah.proxy.files import _file_realtime_stream 13 | from olah.utils.file_utils import make_dirs 14 | 15 | 16 | async def lfs_head_generator( 17 | app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request 18 | ): 19 | # save 20 | repos_path = app.state.app_settings.config.repos_path 21 | head_path = os.path.join( 22 | repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}" 23 | ) 24 | save_path = os.path.join( 25 | repos_path, f"lfs/files/{dir1}/{dir2}/{hash_repo}/{hash_file}" 26 | ) 27 | make_dirs(head_path) 28 | make_dirs(save_path) 29 | 30 | # use_cache = os.path.exists(head_path) and os.path.exists(save_path) 31 | allow_cache = True 32 | 33 | # proxy 34 | return _file_realtime_stream( 35 | app=app, 36 | save_path=save_path, 37 | head_path=head_path, 38 | url=str(request.url), 39 | request=request, 40 | method="HEAD", 41 | allow_cache=allow_cache, 42 | commit=None, 43 | ) 44 | 45 | 46 | async def lfs_get_generator( 47 | app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request 48 | ): 49 | # save 50 | repos_path = app.state.app_settings.config.repos_path 51 | head_path = os.path.join( 52 | repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}" 53 | ) 54 | save_path = os.path.join( 55 | repos_path, f"lfs/files/{dir1}/{dir2}/{hash_repo}/{hash_file}" 56 | ) 57 | make_dirs(head_path) 58 | make_dirs(save_path) 59 | 60 | # use_cache = os.path.exists(head_path) and os.path.exists(save_path) 61 | allow_cache = True 62 | 63 | # proxy 64 | return _file_realtime_stream( 65 | app=app, 66 | save_path=save_path, 67 | head_path=head_path, 68 | url=str(request.url), 69 | request=request, 70 | method="GET", 71 | allow_cache=allow_cache, 72 | commit=None, 73 | ) 74 | -------------------------------------------------------------------------------- /src/olah/cache/stat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
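# Example invocation (illustrative file path):
#
#     python -m olah.cache.stat -f ./repos/lfs/files/a1/b2/<hash_repo>/<hash_file> -e exported.bin
#
# This prints the cache header fields plus a 0/1 bitmap of cached blocks, and
# writes the reassembled payload to exported.bin only when every block is present.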
 7 | 
 8 | import argparse
 9 | import os
10 | import sys
11 | from olah.cache.olah_cache import OlahCache
12 | 
13 | def get_size_human(size: int) -> str:
14 |     if size > 1024 * 1024 * 1024:
15 |         return f"{size / (1024 * 1024 * 1024):.4f}GB"
16 |     elif size > 1024 * 1024:
17 |         return f"{size / (1024 * 1024):.4f}MB"
18 |     elif size > 1024:
19 |         return f"{size / 1024:.4f}KB"
20 |     else:
21 |         return f"{size}B"
22 | 
23 | def insert_newlines(input_str, every=10):
24 |     return '\n'.join(input_str[i:i+every] for i in range(0, len(input_str), every))
25 | 
26 | if __name__ == "__main__":
27 |     parser = argparse.ArgumentParser(description="Olah Cache Visualization Tool.")
28 |     parser.add_argument("--file", "-f", type=str, required=True, help="The path of the Olah cache file")
29 |     parser.add_argument("--export", "-e", type=str, default="", help="Export the cached file if all blocks are cached")
30 |     args = parser.parse_args()
31 |     print(args)
32 | 
33 |     with open(args.file, "rb") as f:
34 |         f.seek(0, os.SEEK_END)
35 |         bin_size = f.tell()
36 | 
37 |     try:
38 |         cache = OlahCache(args.file)
39 |     except Exception as e:
40 |         print(e)
41 |         sys.exit(1)
42 |     print(f"File: {args.file}")
43 |     print(f"Olah Cache Version: {cache.header.version}")
44 |     print(f"File Size: {get_size_human(cache.header.file_size)}")
45 |     print(f"Cache Total Size: {get_size_human(bin_size)}")
46 |     print(f"Block Size: {cache.header.block_size}")
47 |     print(f"Block Number: {cache.header.block_number}")
48 |     print("Cache Status:")
49 |     cache_status = str(cache.header.block_mask)[:cache.header.block_number]
50 |     print(insert_newlines(cache_status, every=50))
51 | 
52 |     if args.export != "":
53 |         if all(c == "1" for c in cache_status):
54 |             with open(args.file, "rb") as f:
55 |                 f.seek(cache._get_header_size(), os.SEEK_SET)
56 |                 with open(args.export, "wb") as fout:
57 |                     fout.write(f.read())
58 |         else:
59 |             print("Some blocks are not cached, so the export is skipped.")
--------------------------------------------------------------------------------
/src/olah/static/repos.html:
--------------------------------------------------------------------------------
  1 | 
--------------------------------------------------------------------------------
/src/olah/static/index.html:
--------------------------------------------------------------------------------
 75 | Set the Environment Variable HF_ENDPOINT to the mirror site (Here is
 76 | {{scheme}}://{{netloc}}).
 78 | Linux:
 79 | export HF_ENDPOINT={{scheme}}://{{netloc}}
 80 | 
 81 | Windows Powershell:
 82 | $env:HF_ENDPOINT = "{{scheme}}://{{netloc}}"
 83 | 
 84 | Starting from now on, all download operations in the HuggingFace library will be proxied through this
 85 | mirror site.
 86 | 
 87 | 
 88 | from huggingface_hub import snapshot_download
 89 | 
 90 | snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model',
 91 |                   local_dir='./model_dir', resume_download=True,
 92 |                   max_workers=8)
 93 | 
 94 | 
 95 | Or you can download models and datasets by using huggingface cli.
 96 | 
 97 | pip install -U huggingface_hub
 98 | 
 99 | Download GPT2:
100 | huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2
101 | 
102 | Download WikiText:
104 | 5 | 自托管的轻量级HuggingFace镜像服务 6 | 7 | Olah是开源的自托管轻量级HuggingFace镜像服务。`Olah`来源于丘丘人语,在丘丘人语中意味着`你好`。 8 | Olah真正地实现了huggingface资源的`镜像`功能,而不仅仅是一个简单的`反向代理`。 9 | Olah并不会立刻对huggingface全站进行镜像,而是在用户下载的同时在文件块级别对资源进行镜像(或者我们可以说是缓存)。 10 | 11 | ## Olah的优势 12 | Olah能够在用户下载的同时分块缓存文件。当第二次下载时,直接从缓存中读取,极大地提升了下载速度并节约了流量。 13 | 同时Olah提供了丰富的缓存控制策略,管理员可以通过配置文件设置哪些仓库可以访问,哪些仓库可以缓存。 14 | 15 | ## 特性 16 | * 数据缓存,减少下载流量 17 | * 模型镜像 18 | * 数据集镜像 19 | * 空间镜像 20 | 21 | ## 安装 22 | 23 | ### 方法1:使用pip 24 | 25 | ```bash 26 | pip install olah 27 | ``` 28 | 29 | 或者: 30 | 31 | ```bash 32 | pip install git+https://github.com/vtuber-plan/olah.git 33 | ``` 34 | 35 | ### 方法2:从源代码安装 36 | 37 | 1. 克隆这个仓库 38 | ```bash 39 | git clone https://github.com/vtuber-plan/olah.git 40 | cd olah 41 | ``` 42 | 43 | 2. 安装包 44 | ```bash 45 | pip install --upgrade pip 46 | pip install -e . 47 | ``` 48 | 49 | ## 快速开始 50 | 在控制台运行以下命令: 51 | ```bash 52 | python -m olah.server 53 | ``` 54 | 55 | 然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090/)。 56 | 57 | Linux: 58 | ```bash 59 | export HF_ENDPOINT=http://localhost:8090 60 | ``` 61 | 62 | Windows Powershell: 63 | ```bash 64 | $env:HF_ENDPOINT = "http://localhost:8090" 65 | ``` 66 | 67 | 从现在开始,HuggingFace库中的所有下载操作都将通过此镜像站点代理进行。 68 | ```bash 69 | pip install -U huggingface_hub 70 | ``` 71 | 72 | ```python 73 | from huggingface_hub import snapshot_download 74 | 75 | snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model', 76 | local_dir='./model_dir', resume_download=True, 77 | max_workers=8) 78 | 79 | ``` 80 | 81 | 或者你也可以使用huggingface cli直接下载模型和数据集. 82 | 83 | 下载GPT2: 84 | ```bash 85 | huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2 86 | ``` 87 | 88 | 下载WikiText: 89 | ```bash 90 | huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext 91 | ``` 92 | 93 | 您可以查看路径`./repos`,其中存储了所有数据集和模型的缓存。 94 | 95 | ## 启动服务器 96 | 在控制台运行以下命令: 97 | ```bash 98 | olah-cli 99 | ``` 100 | 101 | 或者您可以指定主机地址和监听端口: 102 | ```bash 103 | olah-cli --host localhost --port 8090 104 | ``` 105 | **注意:请记得在修改主机和端口时将`--mirror-netloc`和`--mirror-lfs-netloc`更改为镜像站点的实际URL。** 106 | 107 | ```bash 108 | olah-cli --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090 109 | ``` 110 | 111 | 默认的镜像缓存路径是`./repos`,您可以通过`--repos-path`参数进行更改: 112 | ```bash 113 | olah-cli --host localhost --port 8090 --repos-path ./hf_mirrors 114 | ``` 115 | 116 | **注意,不同版本之间的缓存数据不能迁移,请删除缓存文件夹后再进行olah的升级** 117 | 118 | 在实际部署中可能出现下载并发量很大,导致新的连接出现Timeout错误。 119 | 可以设置uvicorn的WEB_CONCURRENCY变量以增加worker数量以提升产品场景的并发量。 120 | 例如Linux下: 121 | ```bash 122 | export WEB_CONCURRENCY=4 123 | ``` 124 | 125 | ## 更多配置 126 | 127 | 更多配置可以通过配置文件进行控制,通过命令参数传入`configs.toml`以设置配置文件路径: 128 | ```bash 129 | olah-cli -c configs.toml 130 | ``` 131 | 132 | 完整的配置文件内容见[assets/full_configs.toml](https://github.com/vtuber-plan/olah/blob/main/assets/full_configs.toml) 133 | 134 | ### 配置详解 135 | 第一部分basic字段用于对镜像站进行基本设置 136 | ```toml 137 | [basic] 138 | host = "localhost" 139 | port = 8090 140 | ssl-key = "" 141 | ssl-cert = "" 142 | repos-path = "./repos" 143 | cache-size-limit = "" 144 | cache-clean-strategy = "LRU" 145 | hf-scheme = "https" 146 | hf-netloc = "huggingface.co" 147 | hf-lfs-netloc = "cdn-lfs.huggingface.co" 148 | mirror-scheme = "http" 149 | mirror-netloc = "localhost:8090" 150 | mirror-lfs-netloc = "localhost:8090" 151 | mirrors-path = ["./mirrors_dir"] 152 | ``` 153 | 154 | - host: 设置olah监听的host地址 155 | - port: 设置olah监听的端口 156 | - ssl-key和ssl-cert: 当需要开启HTTPS时传入key和cert的文件路径 157 | - 
repos-path: 用于保存缓存数据的目录 158 | - cache-size-limit: 指定缓存大小限制(例如,100G,500GB,2TB)。Olah会每小时扫描缓存文件夹的大小。如果超出限制,Olah会删除一些缓存文件 159 | - cache-clean-strategy: 指定缓存清理策略(可用策略:LRU,FIFO,LARGE_FIRST) 160 | - hf-scheme: huggingface官方站点的网络协议(一般不需要改动) 161 | - hf-netloc: huggingface官方站点的网络位置(一般不需要改动) 162 | - hf-lfs-netloc: huggingface官方站点LFS文件的网络位置(一般不需要改动) 163 | - mirror-scheme: Olah镜像站的网络协议(应当和上面的设置一致,当提供ssl-key和ssl-cert时,应改为https) 164 | - mirror-netloc: Olah镜像站的网络位置(应与host和port设置一致) 165 | - mirror-lfs-netloc: Olah镜像站LFS的网络位置(应与host和port设置一致) 166 | - mirrors-path: 额外的镜像文件目录。当你已经clone了一些git仓库时可以放入该目录下以供下载。此处例子目录为`./mirrors_dir`, 若要添加数据集`Salesforce/wikitext`,可将git仓库放置于`./mirrors_dir/datasets/Salesforce/wikitext`目录。同理,模型放置于`./mirrors_dir/models/organization/repository`下。 167 | 168 | 169 | 第二部分可以对可访问性进行限制 170 | ```toml 171 | 172 | [accessibility] 173 | offline = false 174 | 175 | [[accessibility.proxy]] 176 | repo = "cais/mmlu" 177 | allow = true 178 | 179 | [[accessibility.proxy]] 180 | repo = "adept/fuyu-8b" 181 | allow = false 182 | 183 | [[accessibility.proxy]] 184 | repo = "mistralai/*" 185 | allow = true 186 | 187 | [[accessibility.proxy]] 188 | repo = "mistralai/Mistral.*" 189 | allow = false 190 | use_re = true 191 | 192 | [[accessibility.cache]] 193 | repo = "cais/mmlu" 194 | allow = true 195 | 196 | [[accessibility.cache]] 197 | repo = "adept/fuyu-8b" 198 | allow = false 199 | ``` 200 | - offline: 设置Olah镜像站是否进入离线模式,不再向huggingface官方站点发出请求以进行数据更新,但已经缓存的仓库仍可以下载 201 | - proxy: 用于设置该仓库是否可以被代理,默认全部允许,`repo`用于匹配仓库名字; 可使用正则表达式和通配符两种模式,`use_re`用于控制是否使用正则表达式,默认使用通配符; `allow`控制该规则的属性是允许代理还是不允许代理。 202 | - cache: 用于设置该仓库是否会被缓存,默认全部允许,`repo`用于匹配仓库名字; 可使用正则表达式和通配符两种模式,`use_re`用于控制是否使用正则表达式,默认使用通配符; `allow`控制该规则的属性是允许代理还是不允许缓存。 203 | 204 | ## 许可证 205 | 206 | olah采用MIT许可证发布。 207 | 208 | ## 另请参阅 209 | 210 | - [olah-docs](https://github.com/vtuber-plan/olah/tree/main/docs) 211 | - [olah-source](https://github.com/vtuber-plan/olah) 212 | 213 | ## Star历史 214 | 215 | [![Star历史图表]()](https://star-history.com/#vtuber-plan/olah&Date) -------------------------------------------------------------------------------- /src/olah/utils/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 
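# Usage sketch (illustrative): build_logger() attaches an hourly rotating file
# handler under ./logs (DEFAULT_LOGGER_DIR) and also redirects sys.stdout and
# sys.stderr into loggers, so bare print() calls land in the same log file:
#
#     logger = build_logger("olah", "olah.log")
#     logger.info("mirror started")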
7 | 8 | from asyncio import AbstractEventLoop 9 | import json 10 | import logging 11 | import logging.handlers 12 | import os 13 | import platform 14 | import re 15 | import sys 16 | from typing import AsyncGenerator, Generator 17 | import warnings 18 | from olah.constants import DEFAULT_LOGGER_DIR 19 | 20 | handler = None 21 | 22 | 23 | # Define a custom formatter without color codes 24 | class NoColorFormatter(logging.Formatter): 25 | color_pattern = re.compile(r"\x1b[^m]*m") # Regex pattern to match color codes 26 | 27 | def format(self, record): 28 | message = super().format(record) 29 | # Remove color codes from the log message 30 | message = self.color_pattern.sub("", message) 31 | return message 32 | 33 | 34 | def build_logger( 35 | logger_name, logger_filename, logger_dir=DEFAULT_LOGGER_DIR 36 | ) -> logging.Logger: 37 | global handler 38 | 39 | formatter = logging.Formatter( 40 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 41 | datefmt="%Y-%m-%d %H:%M:%S", 42 | ) 43 | 44 | nocolor_formatter = NoColorFormatter( 45 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 46 | datefmt="%Y-%m-%d %H:%M:%S", 47 | ) 48 | 49 | # Set the format of root handlers 50 | if logging.getLogger().handlers is None or len(logging.getLogger().handlers) == 0: 51 | if sys.version_info[1] >= 9: 52 | # This is for windows 53 | logging.basicConfig(level=logging.INFO, encoding="utf-8") 54 | else: 55 | if platform.system() == "Windows": 56 | warnings.warn( 57 | "If you are running on Windows, " 58 | "we recommend you use Python >= 3.9 for UTF-8 encoding." 59 | ) 60 | logging.basicConfig(level=logging.INFO) 61 | logging.getLogger().handlers[0].setFormatter(formatter) 62 | 63 | # Redirect stdout and stderr to loggers 64 | stdout_logger = logging.getLogger("stdout") 65 | stdout_logger.setLevel(logging.DEBUG) 66 | sl = StreamToLogger(stdout_logger, logging.INFO) 67 | sys.stdout = sl 68 | 69 | stderr_logger = logging.getLogger("stderr") 70 | stderr_logger.setLevel(logging.ERROR) 71 | sl = StreamToLogger(stderr_logger, logging.ERROR) 72 | sys.stderr = sl 73 | 74 | # Get logger 75 | logger = logging.getLogger(logger_name) 76 | logger.setLevel(logging.DEBUG) 77 | 78 | # Add a file handler for all loggers 79 | if handler is None: 80 | os.makedirs(logger_dir, exist_ok=True) 81 | filename = os.path.join(logger_dir, logger_filename) 82 | handler = logging.handlers.TimedRotatingFileHandler( 83 | filename, when="H", utc=True, encoding="utf-8" 84 | ) 85 | handler.setFormatter(nocolor_formatter) 86 | handler.namer = lambda name: name.replace(".log", "") + ".log" 87 | 88 | for name, item in logging.root.manager.loggerDict.items(): 89 | if isinstance(item, logging.Logger): 90 | item.addHandler(handler) 91 | 92 | return logger 93 | 94 | 95 | class StreamToLogger(object): 96 | """ 97 | Fake file-like stream object that redirects writes to a logger instance. 98 | """ 99 | 100 | def __init__(self, logger, log_level=logging.INFO): 101 | self.terminal = sys.stdout 102 | self.logger = logger 103 | self.log_level = log_level 104 | self.linebuf = "" 105 | 106 | def __getattr__(self, attr): 107 | try: 108 | attr_value = getattr(self.terminal, attr) 109 | except: 110 | return None 111 | return attr_value 112 | 113 | def write(self, buf): 114 | temp_linebuf = self.linebuf + buf 115 | self.linebuf = "" 116 | for line in temp_linebuf.splitlines(True): 117 | # From the io.TextIOWrapper docs: 118 | # On output, if newline is None, any '\n' characters written 119 | # are translated to the system default line separator. 
120 | # By default sys.stdout.write() expects '\n' newlines and then 121 | # translates them so this is still cross platform. 122 | if line[-1] == "\n": 123 | encoded_message = line.encode("utf-8", "ignore").decode("utf-8") 124 | self.logger.log(self.log_level, encoded_message.rstrip()) 125 | else: 126 | self.linebuf += line 127 | 128 | def flush(self): 129 | if self.linebuf != "": 130 | encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") 131 | self.logger.log(self.log_level, encoded_message.rstrip()) 132 | self.linebuf = "" 133 | 134 | 135 | def iter_over_async( 136 | async_gen: AsyncGenerator, event_loop: AbstractEventLoop 137 | ) -> Generator: 138 | """ 139 | Convert async generator to sync generator 140 | 141 | :param async_gen: the AsyncGenerator to convert 142 | :param event_loop: the event loop to run on 143 | :returns: Sync generator 144 | """ 145 | ait = async_gen.__aiter__() 146 | 147 | async def get_next(): 148 | try: 149 | obj = await ait.__anext__() 150 | return False, obj 151 | except StopAsyncIteration: 152 | return True, None 153 | 154 | while True: 155 | done, obj = event_loop.run_until_complete(get_next()) 156 | if done: 157 | break 158 | yield obj 159 | -------------------------------------------------------------------------------- /src/olah/configs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 XiaHan 3 | # 4 | # Use of this source code is governed by an MIT-style 5 | # license that can be found in the LICENSE file or at 6 | # https://opensource.org/licenses/MIT. 7 | 8 | from typing import Any, Dict, List, Literal, Optional, Union 9 | import toml 10 | import re 11 | import fnmatch 12 | 13 | from olah.utils.disk_utils import convert_to_bytes 14 | 15 | DEFAULT_PROXY_RULES = [ 16 | {"repo": "*", "allow": True, "use_re": False}, 17 | {"repo": "*/*", "allow": True, "use_re": False}, 18 | ] 19 | 20 | DEFAULT_CACHE_RULES = [ 21 | {"repo": "*", "allow": True, "use_re": False}, 22 | {"repo": "*/*", "allow": True, "use_re": False}, 23 | ] 24 | 25 | 26 | class OlahRule(object): 27 | def __init__(self, repo: str = "", type: str = "*", allow: bool = False, use_re: bool = False) -> None: 28 | self.repo = repo 29 | self.type = type 30 | self.allow = allow 31 | self.use_re = use_re 32 | 33 | @staticmethod 34 | def from_dict(data: Dict[str, Any]) -> "OlahRule": 35 | out = OlahRule() 36 | if "repo" in data: 37 | out.repo = data["repo"] 38 | if "allow" in data: 39 | out.allow = data["allow"] 40 | if "use_re" in data: 41 | out.use_re = data["use_re"] 42 | return out 43 | 44 | def match(self, repo_name: str) -> bool: 45 | if self.use_re: 46 | return self.match_re(repo_name) 47 | else: 48 | return self.match_fn(repo_name) 49 | 50 | def match_fn(self, repo_name: str) -> bool: 51 | return fnmatch.fnmatch(repo_name, self.repo) 52 | 53 | def match_re(self, repo_name: str) -> bool: 54 | return re.match(self.repo, repo_name) is not None 55 | 56 | 57 | class OlahRuleList(object): 58 | def __init__(self) -> None: 59 | self.rules: List[OlahRule] = [] 60 | 61 | @staticmethod 62 | def from_list(data: List[Dict[str, Any]]) -> "OlahRuleList": 63 | out = OlahRuleList() 64 | for item in data: 65 | out.rules.append(OlahRule.from_dict(item)) 66 | return out 67 | 68 | def clear(self): 69 | self.rules.clear() 70 | 71 | def allow(self, repo_name: str) -> bool: 72 | allow = False 73 | for rule in self.rules: 74 | if rule.match(repo_name): 75 | allow = rule.allow 76 | return allow 77 | 78 | 79 | class 
OlahConfig(object): 80 | def __init__(self, path: Optional[str] = None) -> None: 81 | 82 | # basic 83 | self.host: Union[List[str], str] = "localhost" 84 | self.port = 8090 85 | self.ssl_key = None 86 | self.ssl_cert = None 87 | self.repos_path = "./repos" 88 | self.cache_size_limit: Optional[int] = None 89 | self.cache_clean_strategy: Literal["LRU", "FIFO", "LARGE_FIRST"] = "LRU" 90 | 91 | self.hf_scheme: str = "https" 92 | self.hf_netloc: str = "huggingface.co" 93 | self.hf_lfs_netloc: str = "cdn-lfs.huggingface.co" 94 | 95 | self.mirror_scheme: str = "http" if self.ssl_key is None else "https" 96 | self.mirror_netloc: str = ( 97 | f"{self.host if self._is_specific_addr(self.host) else 'localhost'}:{self.port}" 98 | ) 99 | self.mirror_lfs_netloc: str = ( 100 | f"{self.host if self._is_specific_addr(self.host) else 'localhost'}:{self.port}" 101 | ) 102 | 103 | self.mirrors_path: List[str] = [] 104 | 105 | # accessibility 106 | self.offline = False 107 | self.proxy = OlahRuleList.from_list(DEFAULT_PROXY_RULES) 108 | self.cache = OlahRuleList.from_list(DEFAULT_CACHE_RULES) 109 | 110 | if path is not None: 111 | self.read_toml(path) 112 | 113 | def _is_specific_addr(self, host: Union[List[str], str]) -> bool: 114 | if isinstance(host, str): 115 | return host not in ['0.0.0.0', '::'] 116 | else: 117 | return False 118 | 119 | def hf_url_base(self) -> str: 120 | return f"{self.hf_scheme}://{self.hf_netloc}" 121 | 122 | def hf_lfs_url_base(self) -> str: 123 | return f"{self.hf_scheme}://{self.hf_lfs_netloc}" 124 | 125 | def mirror_url_base(self) -> str: 126 | return f"{self.mirror_scheme}://{self.mirror_netloc}" 127 | 128 | def mirror_lfs_url_base(self) -> str: 129 | return f"{self.mirror_scheme}://{self.mirror_lfs_netloc}" 130 | 131 | def empty_str(self, s: str) -> Optional[str]: 132 | if s == "": 133 | return None 134 | else: 135 | return s 136 | 137 | def read_toml(self, path: str) -> None: 138 | config = toml.load(path) 139 | 140 | if "basic" in config: 141 | basic = config["basic"] 142 | self.host = basic.get("host", self.host) 143 | self.port = basic.get("port", self.port) 144 | self.ssl_key = self.empty_str(basic.get("ssl-key", self.ssl_key)) 145 | self.ssl_cert = self.empty_str(basic.get("ssl-cert", self.ssl_cert)) 146 | self.repos_path = basic.get("repos-path", self.repos_path) 147 | self.cache_size_limit = convert_to_bytes(basic.get("cache-size-limit", self.cache_size_limit)) 148 | self.cache_clean_strategy = basic.get("cache-clean-strategy", self.cache_clean_strategy) 149 | 150 | self.hf_scheme = basic.get("hf-scheme", self.hf_scheme) 151 | self.hf_netloc = basic.get("hf-netloc", self.hf_netloc) 152 | self.hf_lfs_netloc = basic.get("hf-lfs-netloc", self.hf_lfs_netloc) 153 | 154 | self.mirror_scheme = basic.get("mirror-scheme", self.mirror_scheme) 155 | self.mirror_netloc = basic.get("mirror-netloc", self.mirror_netloc) 156 | self.mirror_lfs_netloc = basic.get( 157 | "mirror-lfs-netloc", self.mirror_lfs_netloc 158 | ) 159 | 160 | self.mirrors_path = basic.get("mirrors-path", self.mirrors_path) 161 | 162 | if "accessibility" in config: 163 | accessibility = config["accessibility"] 164 | self.offline = accessibility.get("offline", self.offline) 165 | self.proxy = OlahRuleList.from_list(accessibility.get("proxy", DEFAULT_PROXY_RULES)) 166 | self.cache = OlahRuleList.from_list(accessibility.get("cache", DEFAULT_CACHE_RULES)) 167 | -------------------------------------------------------------------------------- /src/olah/utils/zip_utils.py: 
--------------------------------------------------------------------------------
1 | """
2 | Handlers for Content-Encoding.
3 | 
4 | See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
5 | """
6 | 
7 | import typing
8 | from typing import List, Optional, Union
9 | import zlib
10 | import httpx
11 | 
12 | try: import brotli
13 | except ImportError: brotli = None  # brotli support is optional
14 | 
15 | 
16 | class DecodingError(httpx.RequestError):
17 |     """
18 |     Decoding of the response failed, due to a malformed encoding.
19 |     """
20 | 
21 | 
22 | class ContentDecoder:
23 |     def decode(self, data: bytes) -> bytes:
24 |         raise NotImplementedError()  # pragma: no cover
25 | 
26 |     def flush(self) -> bytes:
27 |         raise NotImplementedError()  # pragma: no cover
28 | 
29 | 
30 | class IdentityDecoder(ContentDecoder):
31 |     """
32 |     Handle unencoded data.
33 |     """
34 | 
35 |     def decode(self, data: bytes) -> bytes:
36 |         return data
37 | 
38 |     def flush(self) -> bytes:
39 |         return b""
40 | 
41 | 
42 | class DeflateDecoder(ContentDecoder):
43 |     """
44 |     Handle 'deflate' decoding.
45 | 
46 |     See: https://stackoverflow.com/questions/1838699
47 |     """
48 | 
49 |     def __init__(self) -> None:
50 |         self.first_attempt = True
51 |         self.decompressor = zlib.decompressobj()
52 | 
53 |     def decode(self, data: bytes) -> bytes:
54 |         was_first_attempt = self.first_attempt
55 |         self.first_attempt = False
56 |         try:
57 |             return self.decompressor.decompress(data)
58 |         except zlib.error as exc:
59 |             if was_first_attempt:
60 |                 self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
61 |                 return self.decode(data)
62 |             raise DecodingError(str(exc)) from exc
63 | 
64 |     def flush(self) -> bytes:
65 |         try:
66 |             return self.decompressor.flush()
67 |         except zlib.error as exc:  # pragma: no cover
68 |             raise DecodingError(str(exc)) from exc
69 | 
70 | 
71 | class GZipDecoder(ContentDecoder):
72 |     """
73 |     Handle 'gzip' decoding.
74 | 
75 |     See: https://stackoverflow.com/questions/1838699
76 |     """
77 | 
78 |     def __init__(self) -> None:
79 |         self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16)
80 | 
81 |     def decode(self, data: bytes) -> bytes:
82 |         try:
83 |             return self.decompressor.decompress(data)
84 |         except zlib.error as exc:
85 |             raise DecodingError(str(exc)) from exc
86 | 
87 |     def flush(self) -> bytes:
88 |         try:
89 |             return self.decompressor.flush()
90 |         except zlib.error as exc:  # pragma: no cover
91 |             raise DecodingError(str(exc)) from exc
92 | 
93 | 
94 | class BrotliDecoder(ContentDecoder):
95 |     """
96 |     Handle 'brotli' decoding.
97 | 
98 |     Requires `pip install brotlipy`. See: https://brotlipy.readthedocs.io/
99 |     or `pip install brotli`. See https://github.com/google/brotli
100 |     Supports both 'brotlipy' and 'Brotli' packages since they share an import
101 |     name. The top branches are for 'brotlipy' and bottom branches for 'Brotli'.
102 |     """
103 | 
104 |     def __init__(self) -> None:
105 |         if brotli is None:  # pragma: no cover
106 |             raise ImportError(
107 |                 "Using 'BrotliDecoder', but neither of the 'brotlicffi' or 'brotli' "
108 |                 "packages have been installed. "
109 |                 "Make sure to install httpx using `pip install httpx[brotli]`."
110 |             ) from None
111 | 
112 |         self.decompressor = brotli.Decompressor()
113 |         self.seen_data = False
114 |         self._decompress: typing.Callable[[bytes], bytes]
115 |         if hasattr(self.decompressor, "decompress"):
116 |             # The 'brotlicffi' package.
117 |             self._decompress = self.decompressor.decompress  # pragma: no cover
118 |         else:
119 |             # The 'brotli' package.
120 |             self._decompress = self.decompressor.process  # pragma: no cover
121 | 
122 |     def decode(self, data: bytes) -> bytes:
123 |         if not data:
124 |             return b""
125 |         self.seen_data = True
126 |         try:
127 |             return self._decompress(data)
128 |         except brotli.error as exc:
129 |             raise DecodingError(str(exc)) from exc
130 | 
131 |     def flush(self) -> bytes:
132 |         if not self.seen_data:
133 |             return b""
134 |         try:
135 |             if hasattr(self.decompressor, "finish"):
136 |                 # Only available in the 'brotlicffi' package.
137 | 
138 |                 # As the decompressor decompresses eagerly, this
139 |                 # will never actually emit any data. However, it will potentially throw
140 |                 # errors if a truncated or damaged data stream has been used.
141 |                 self.decompressor.finish()  # pragma: no cover
142 |             return b""
143 |         except brotli.error as exc:  # pragma: no cover
144 |             raise DecodingError(str(exc)) from exc
145 | 
146 | 
147 | class MultiDecoder(ContentDecoder):
148 |     """
149 |     Handle the case where multiple encodings have been applied.
150 |     """
151 | 
152 |     def __init__(self, children: typing.Sequence[ContentDecoder]) -> None:
153 |         """
154 |         'children' should be a sequence of decoders in the order in which
155 |         each was applied.
156 |         """
157 |         # Note that we reverse the order for decoding.
158 |         self.children = list(reversed(children))
159 | 
160 |     def decode(self, data: bytes) -> bytes:
161 |         for child in self.children:
162 |             data = child.decode(data)
163 |         return data
164 | 
165 |     def flush(self) -> bytes:
166 |         data = b""
167 |         for child in self.children:
168 |             data = child.decode(data) + child.flush()
169 |         return data
170 | 
171 | 
172 | SUPPORTED_DECODERS = {
173 |     "identity": IdentityDecoder,
174 |     "gzip": GZipDecoder,
175 |     "deflate": DeflateDecoder,
176 |     "br": BrotliDecoder,
177 | }
178 | 
179 | 
180 | class Decompressor(object):
181 |     def __init__(self, algorithms: Union[str, List[str]]) -> None:
182 |         if isinstance(algorithms, str):
183 |             self.algorithms = [algorithms]
184 |         else:
185 |             self.algorithms = algorithms
186 | 
187 |         self.decoders = []
188 |         for algo in self.algorithms:
189 |             algo = algo.strip().lower()
190 |             if algo in SUPPORTED_DECODERS:
191 |                 self.decoders.append(SUPPORTED_DECODERS[algo]())
192 |             else:
193 |                 print(f"Unsupported compression algorithm: {algo}")
194 | 
195 |         self.decoder = MultiDecoder(self.decoders)
196 | 
197 |     def decompress(self, raw_chunk: bytes) -> bytes:
198 |         return self.decoder.decode(raw_chunk)
199 | 
200 | 
201 | def decompress_data(raw_data: bytes, content_encoding: Optional[str]) -> bytes:
202 |     # If the result is compressed
203 |     if content_encoding is not None:
204 |         final_data = raw_data
205 |         algorithms = content_encoding.split(",")
206 |         for algo in reversed(algorithms):  # undo encodings in reverse order of application
207 |             algo = algo.strip().lower()
208 |             if algo == "gzip":
209 |                 try:
210 |                     final_data = zlib.decompress(
211 |                         final_data, zlib.MAX_WBITS | 16
212 |                     )  # decompress gzip data
213 |                 except Exception as e:
214 |                     print(f"Error decompressing gzip data: {e}")
215 |             elif algo == "compress":
216 |                 print(f"Unsupported decompression algorithm: {algo}")
217 |             elif algo == "deflate":
218 |                 try:
219 |                     final_data = zlib.decompress(final_data)
220 |                 except Exception as e:
221 |                     print(f"Error decompressing deflate data: {e}")
222 |             elif algo == "br":
223 |                 try:
224 |                     import brotli
225 | 
226 |                     final_data = brotli.decompress(final_data)
227 |                 except Exception as e:
228 |                     print(f"Error decompressing Brotli data: {e}")
229 |             elif algo == "zstd":
230 |                 try:
231 |                     import zstandard
232 | 
233 |                     final_data = zstandard.ZstdDecompressor().decompress(final_data)
234 |                 except Exception as e:
235 |                     print(f"Error decompressing Zstandard data: {e}")
236 |             else:
237 |                 print(f"Unsupported compression algorithm: {algo}")
238 |         return final_data
239 |     else:
240 |         return raw_data
241 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
4 | Self-hosted Lightweight Huggingface Mirror Service
5 |
6 | Olah is a self-hosted lightweight huggingface mirror service. `Olah` means `hello` in Hilichurlian.
7 | Olah implements `mirroring` of huggingface resources, rather than acting as a simple `reverse proxy`.
8 | Olah does not mirror the entire huggingface website up front; instead, it mirrors resources at the file-block level as users download them (in other words, it caches them on demand).
9 |
10 | Other languages: [中文](README_zh.md)
11 |
12 | ## Advantages of Olah
13 | Olah caches files chunk by chunk as users download them. Subsequent downloads are served directly from the cache, greatly improving download speed and saving bandwidth.
14 | Additionally, Olah offers a range of cache-control policies. Through a configuration file, administrators can configure which repositories are accessible and which are cached.
15 |
16 | ## Features
17 | * Huggingface Data Cache
18 | * Model mirroring
19 | * Dataset mirroring
20 | * Space mirroring
21 |
22 | ## Install
23 |
24 | ### Method 1: With pip
25 |
26 | ```bash
27 | pip install olah
28 | ```
29 |
30 | or:
31 |
32 | ```bash
33 | pip install git+https://github.com/vtuber-plan/olah.git
34 | ```
35 |
36 | ### Method 2: From source
37 |
38 | 1. Clone this repository
39 | ```bash
40 | git clone https://github.com/vtuber-plan/olah.git
41 | cd olah
42 | ```
43 |
44 | 2. Install the Package
45 | ```bash
46 | pip install --upgrade pip
47 | pip install -e .
48 | ```
49 |
50 | ## Quick Start
51 | Run the command in the console:
52 | ```bash
53 | olah-cli
54 | ```
55 |
56 | Then set the environment variable `HF_ENDPOINT` to the mirror site (here, http://localhost:8090).
57 |
58 | Linux:
59 | ```bash
60 | export HF_ENDPOINT=http://localhost:8090
61 | ```
62 |
63 | Windows PowerShell:
64 | ```powershell
65 | $env:HF_ENDPOINT = "http://localhost:8090"
66 | ```
67 |
68 | From now on, all download operations performed by the huggingface_hub library will be proxied through this mirror site.
69 | ```bash
70 | pip install -U huggingface_hub
71 | ```
72 |
73 | ```python
74 | from huggingface_hub import snapshot_download
75 |
76 | snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model',
77 | local_dir='./model_dir', resume_download=True,
78 | max_workers=8)
79 | ```
80 |
81 | Or you can download models and datasets using the huggingface-cli.
82 |
83 | Download GPT2:
84 | ```bash
85 | huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2
86 | ```
87 |
88 | Download WikiText:
89 | ```bash
90 | huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext
91 | ```
92 |
93 | You can check the path `./repos`, in which olah stores all cached datasets and models.
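
For illustration, the layout under `./repos` follows the path helpers in `src/olah/utils/repo_utils.py` shown later in this repository (the exact layout may change between Olah versions):

```text
repos/
├── api/datasets/Salesforce/wikitext/revision/<commit>/meta_get.json    # cached API metadata
└── heads/datasets/Salesforce/wikitext/resolve_head/<commit>/<file>     # cached file data
```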
94 |
95 | ## Start the server
96 | Run the command in the console:
97 | ```bash
98 | olah-cli
99 | ```
100 |
101 | Or you can specify the host address and listening port:
102 | ```bash
103 | olah-cli --host localhost --port 8090
104 | ```
105 | **Note: when changing the host and port, please also update `--mirror-netloc` and `--mirror-lfs-netloc` to the actual addresses of the mirror site.**
106 | ```bash
107 | olah-cli --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090
108 | ```
109 |
110 | The default mirror cache path is `./repos`; you can change it with the `--repos-path` parameter:
111 | ```bash
112 | olah-cli --host localhost --port 8090 --repos-path ./hf_mirrors
113 | ```
114 |
115 | **Note that cached data cannot be migrated between different versions. Please delete the cache folder before upgrading to the latest version of Olah.**
116 |
117 | In deployment scenarios there may be many concurrent downloads, which can lead to timeout errors for new connections.
118 | You can set the `WEB_CONCURRENCY` environment variable for uvicorn to increase the number of worker processes, thereby improving concurrency in production environments.
119 | For example, on Linux:
120 | ```bash
121 | export WEB_CONCURRENCY=4
122 | ```
123 |
124 | ## More Configurations
125 |
126 | Additional settings can be controlled through a configuration file by passing a `configs.toml` file as a command parameter:
127 | ```bash
128 | olah-cli -c configs.toml
129 | ```
130 |
131 | The complete content of the configuration file can be found at [assets/full_configs.toml](https://github.com/vtuber-plan/olah/blob/main/assets/full_configs.toml).
132 |
133 | ### Configuration Details
134 | The first section, `basic`, is used to set up basic configurations for the mirror site:
135 | ```toml
136 | [basic]
137 | host = "localhost"
138 | port = 8090
139 | ssl-key = ""
140 | ssl-cert = ""
141 | repos-path = "./repos"
142 | cache-size-limit = ""
143 | cache-clean-strategy = "LRU"
144 | hf-scheme = "https"
145 | hf-netloc = "huggingface.co"
146 | hf-lfs-netloc = "cdn-lfs.huggingface.co"
147 | mirror-scheme = "http"
148 | mirror-netloc = "localhost:8090"
149 | mirror-lfs-netloc = "localhost:8090"
150 | mirrors-path = ["./mirrors_dir"]
151 | ```
152 | - `host`: Sets the host address that Olah listens to.
153 | - `port`: Sets the port that Olah listens to.
154 | - `ssl-key` and `ssl-cert`: When enabling HTTPS, specify the file paths for the key and certificate.
155 | - `repos-path`: Specifies the directory for storing cached data.
156 | - `cache-size-limit`: Specifies the cache size limit (for example, 100G, 500GB, 2TB). Olah scans the size of the cache folder every hour; if it exceeds the limit, Olah deletes some cached files according to the cleaning strategy (see the parsing sketch after this list).
157 | - `cache-clean-strategy`: Specifies the cache cleaning strategy (available strategies: LRU, FIFO, LARGE_FIRST).
158 | - `hf-scheme`: Network protocol for the Hugging Face official site (usually no need to modify).
159 | - `hf-netloc`: Network location of the Hugging Face official site (usually no need to modify).
160 | - `hf-lfs-netloc`: Network location for Hugging Face official site's LFS files (usually no need to modify).
161 | - `mirror-scheme`: Network protocol for the Olah mirror site (should match the above settings; change to HTTPS if providing `ssl-key` and `ssl-cert`).
162 | - `mirror-netloc`: Network location of the Olah mirror site (should match `host` and `port` settings).
163 | - `mirror-lfs-netloc`: Network location for Olah mirror site's LFS (should match `host` and `port` settings).
164 | - `mirrors-path`: Additional mirror file directories. If you have already cloned some Git repositories, you can place them in this directory to make them downloadable. In this example, the directory is `./mirrors_dir`. To add a dataset such as `Salesforce/wikitext`, place the Git repository at `./mirrors_dir/datasets/Salesforce/wikitext`, as shown in the example below. Similarly, models go under `./mirrors_dir/models/organization/repository`.
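
For example, to serve an existing clone of the `Salesforce/wikitext` dataset through the mirror, place it at the path described above (a minimal sketch; adjust the source URL to wherever your clone comes from):

```bash
mkdir -p ./mirrors_dir/datasets/Salesforce
git clone https://huggingface.co/datasets/Salesforce/wikitext ./mirrors_dir/datasets/Salesforce/wikitext
```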
165 |
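To make the `cache-size-limit` setting above concrete, here is a minimal, hypothetical sketch of how such size strings can be parsed into bytes. The function name and accepted formats are assumptions for illustration; Olah's actual parser may differ:

```python
import re

# Hypothetical size-string parser for values like "100G", "500GB", "2TB".
UNITS = {"": 1, "K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}

def parse_size_limit(limit: str) -> int:
    match = re.fullmatch(r"(\d+)\s*([KMGT]?)B?", limit.strip(), flags=re.IGNORECASE)
    if match is None:
        raise ValueError(f"Invalid cache-size-limit: {limit!r}")
    number, unit = match.groups()
    return int(number) * UNITS[unit.upper()]

assert parse_size_limit("100G") == 100 * 1024**3
assert parse_size_limit("2TB") == 2 * 1024**4
```
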
166 | The second section allows for accessibility restrictions:
167 | ```toml
168 | [accessibility]
169 | offline = false
170 |
171 | [[accessibility.proxy]]
172 | repo = "cais/mmlu"
173 | allow = true
174 |
175 | [[accessibility.proxy]]
176 | repo = "adept/fuyu-8b"
177 | allow = false
178 |
179 | [[accessibility.proxy]]
180 | repo = "mistralai/*"
181 | allow = true
182 |
183 | [[accessibility.proxy]]
184 | repo = "mistralai/Mistral.*"
185 | allow = false
186 | use_re = true
187 |
188 | [[accessibility.cache]]
189 | repo = "cais/mmlu"
190 | allow = true
191 |
192 | [[accessibility.cache]]
193 | repo = "adept/fuyu-8b"
194 | allow = false
195 | ```
196 | - `offline`: Sets whether the Olah mirror site enters offline mode, no longer making requests to the Hugging Face official site for data updates. However, cached repositories can still be downloaded.
197 | - `proxy`: Determines whether a repository can be accessed through the proxy. By default, all repositories are allowed. The `repo` field matches the repository name; it is treated as a wildcard pattern by default, or as a regular expression when `use_re` is set to true. The `allow` field controls whether the matched repository may be proxied.
198 | - `cache`: Determines whether a repository will be cached. By default, all repositories are allowed. The `repo`, `use_re`, and `allow` fields work the same way as for `proxy` rules, as illustrated in the sketch below.
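
As a sketch of how these rules behave (illustrative only; the function below is an assumption, and Olah's actual `OlahRuleList` implementation may differ), wildcard patterns can be checked with `fnmatch` and regular expressions with `re`:

```python
import fnmatch
import re

def rule_matches(repo: str, pattern: str, use_re: bool = False) -> bool:
    # use_re=True: treat the pattern as a regular expression.
    if use_re:
        return re.match(pattern, repo) is not None
    # Default: treat the pattern as a wildcard, as in the examples above.
    return fnmatch.fnmatch(repo, pattern)

assert rule_matches("mistralai/Mixtral-8x7B-v0.1", "mistralai/*")
assert rule_matches("mistralai/Mistral-7B-v0.1", "mistralai/Mistral.*", use_re=True)
```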
199 |
200 | ## Future Work
201 |
202 | * Administrator and user system
203 | * OSS backend support
204 | * Scheduled mirror update tasks
205 |
206 | ## License
207 |
208 | Olah is released under the MIT License.
209 |
210 |
211 | ## See also
212 |
213 | - [olah-docs](https://github.com/vtuber-plan/olah/tree/main/docs)
214 | - [olah-source](https://github.com/vtuber-plan/olah)
215 |
216 |
217 | ## Star History
218 |
219 | [](https://star-history.com/#vtuber-plan/olah&Date)
220 |
221 |
--------------------------------------------------------------------------------
/src/olah/utils/url_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 XiaHan
3 | #
4 | # Use of this source code is governed by an MIT-style
5 | # license that can be found in the LICENSE file or at
6 | # https://opensource.org/licenses/MIT.
7 |
8 | import datetime
9 | import os
10 | import glob
11 | from typing import Dict, List, Literal, Optional, Tuple, Union
12 | import json
13 | from urllib.parse import ParseResult, urlencode, urljoin, urlparse, parse_qs, urlunparse
14 | import httpx
15 | from olah.configs import OlahConfig
16 | from olah.constants import WORKER_API_TIMEOUT
17 |
18 |
19 | def get_url_tail(parsed_url: Union[str, ParseResult]) -> str:
20 | """
21 | Extracts the tail of a URL, including path, parameters, query, and fragment.
22 |
23 | Args:
24 | parsed_url (Union[str, ParseResult]): The parsed URL or a string URL.
25 |
26 | Returns:
27 | str: The tail of the URL, including path, parameters, query, and fragment.
28 | """
29 | if isinstance(parsed_url, str):
30 | parsed_url = urlparse(parsed_url)
31 | url_tail = parsed_url.path
32 | if len(parsed_url.params) != 0:
33 | url_tail += f";{parsed_url.params}"
34 | if len(parsed_url.query) != 0:
35 | url_tail += f"?{parsed_url.query}"
36 | if len(parsed_url.fragment) != 0:
37 | url_tail += f"#{parsed_url.fragment}"
38 | return url_tail
39 |
40 |
41 | def parse_content_range(content_range: str) -> Tuple[str, Optional[int], Optional[int], Optional[int]]:
42 | """
43 | Parses a Content-Range header string and extracts the unit, start position, end position, and resource size.
44 |
45 | Args:
46 | content_range (str): The Content-Range header string, e.g., "bytes 0-999/1000".
47 |
48 | Returns:
49 | Tuple[str, Optional[int], Optional[int], Optional[int]]: A tuple containing:
50 | - unit (str): The unit of the range, typically "bytes".
51 | - start_pos (Optional[int]): The starting position of the range. None if the range is "*".
52 | - end_pos (Optional[int]): The ending position of the range. None if the range is "*".
53 | - resource_size (Optional[int]): The total size of the resource. None if the size is unknown.
54 |
55 | Raises:
56 | Exception: If the range unit is invalid or the range format is incorrect.
57 | """
58 | if content_range.startswith("bytes "):
59 | unit = "bytes"
60 | content_range_part = content_range[len("bytes "):]
61 | else:
62 | raise Exception("Invalid range unit")
63 |
64 |
65 | if "/" in content_range_part:
66 | data_range, resource_size = content_range_part.split("/", maxsplit=1)
67 | resource_size = int(resource_size)
68 | else:
69 | data_range = content_range_part
70 | resource_size = None
71 |
72 | if "-" in data_range:
73 | start_pos, end_pos = data_range.split("-")
74 | start_pos, end_pos = int(start_pos), int(end_pos)
75 | elif "*" == data_range.strip():
76 | start_pos, end_pos = None, None
77 | else:
78 | raise Exception("Invalid range")
79 | return unit, start_pos, end_pos, resource_size
80 |
81 |
82 | def parse_range_params(range_header: str) -> Tuple[str, List[Tuple[Optional[int], Optional[int]]], Optional[int]]:
83 | """
84 | Parses the HTTP Range request header and returns the unit and a list of ranges.
85 |
86 | Args:
87 | range_header (str): The HTTP Range request header string, e.g., "bytes=0-499" or "bytes=200-999, 2000-2499, 9500-".
88 |
89 | Returns:
90 | Tuple[str, List[Tuple[int, int]], Optional[int]]: A tuple containing the unit (e.g., "bytes") and a list of ranges.
91 | Each range is represented as a tuple of start and end positions; an unspecified position is None.
92 | For a suffix-length header (e.g., "bytes=-500"), the suffix length is returned as the third element.
93 |
94 | Raises:
95 | ValueError: If the Range header is empty or has an invalid format.
96 | """
97 | if not range_header:
98 | raise ValueError("Range header cannot be empty")
99 |
100 | # Split the unit and range specifiers
101 | parts = range_header.split('=')
102 | if len(parts) != 2:
103 | raise ValueError("Invalid Range header format")
104 |
105 | unit = parts[0].strip() # Get the unit, typically "bytes"
106 | range_specifiers = parts[1].strip() # Get the range part
107 |
108 | if range_specifiers.startswith("-") and range_specifiers[1:].isdigit():
109 | return unit, [], int(range_specifiers[1:])
110 |
111 | # Parse multiple ranges
112 | range_list = []
113 | for range_spec in range_specifiers.split(','):
114 | range_spec = range_spec.strip()
115 | if '-' not in range_spec:
116 | raise ValueError("Invalid range specifier")
117 |
118 | start, end = range_spec.split('-')
119 | start = start.strip()
120 | end = end.strip()
121 |
122 | # Handle suffix-length ranges (e.g., "-500")
123 | if not start and end:
124 | range_list.append((None, int(end)))  # A None start indicates a suffix-length range
125 | continue
126 |
127 | # Handle open-ended ranges (e.g., "500-")
128 | if not end and start:
129 | range_list.append((int(start), None))
130 | continue
131 |
132 | # Handle full ranges (e.g., "200-999")
133 | if start and end:
134 | range_list.append((int(start), int(end)))
135 | continue
136 |
137 | # If neither start nor end is provided, it's invalid
138 | raise ValueError("Invalid range specifier")
139 |
140 | return unit, range_list, None
141 |
142 |
143 | def get_all_ranges(file_size: int, unit: str, ranges: List[Tuple[Optional[int], Optional[int]]], suffix: Optional[int]) -> List[Tuple[int, int]]:
144 | all_ranges: List[Tuple[int, int]] = []
145 | if suffix is not None:
146 | all_ranges.append((max(0, file_size - suffix), file_size))  # suffix range: the last "suffix" bytes, clamped to the file start
147 | else:
148 | for r in ranges:
149 | r_start = r[0] if r[0] is not None else 0
150 | r_end = r[1] if r[1] is not None else file_size - 1
151 | start_pos = max(0, r_start)
152 | end_pos = min(file_size - 1, r_end)
153 | if end_pos < start_pos:
154 | continue
155 | all_ranges.append((start_pos, end_pos + 1))
156 | return all_ranges
157 |
158 |
159 | class RemoteInfo(object):
160 | def __init__(self, method: str, url: str, headers: Dict[str, str]) -> None:
161 | """
162 | Represents information about a remote request.
163 |
164 | Args:
165 | method (str): The HTTP method of the request.
166 | url (str): The URL of the request.
167 | headers (Dict[str, str]): The headers of the request.
168 | """
169 | self.method = method
170 | self.url = url
171 | self.headers = headers
172 |
173 |
174 | def check_url_has_param_name(url: str, param_name: str) -> bool:
175 | """
176 | Checks if a URL contains a specific query parameter.
177 |
178 | Args:
179 | url (str): The URL to check.
180 | param_name (str): The name of the query parameter.
181 |
182 | Returns:
183 | bool: True if the URL contains the parameter, False otherwise.
184 | """
185 | parsed_url = urlparse(url)
186 | query_params = parse_qs(parsed_url.query)
187 | return param_name in query_params
188 |
189 |
190 | def get_url_param_name(url: str, param_name: str) -> Optional[str]:
191 | """
192 | Retrieves the value of a specific query parameter from a URL.
193 |
194 | Args:
195 | url (str): The URL to retrieve the parameter from.
196 | param_name (str): The name of the query parameter.
197 |
198 | Returns:
199 | Optional[str]: The value of the query parameter if found, None otherwise.
200 | """
201 | parsed_url = urlparse(url)
202 | query_params = parse_qs(parsed_url.query)
203 | original_location = query_params.get(param_name)
204 | if original_location:
205 | return original_location[0]
206 | else:
207 | return None
208 |
209 |
210 | def add_query_param(url: str, param_name: str, param_value: str) -> str:
211 | """
212 | Adds a query parameter to a URL.
213 |
214 | Args:
215 | url (str): The URL to add the parameter to.
216 | param_name (str): The name of the query parameter.
217 | param_value (str): The value of the query parameter.
218 |
219 | Returns:
220 | str: The modified URL with the added query parameter.
221 | """
222 | parsed_url = urlparse(url)
223 | query_params = parse_qs(parsed_url.query)
224 |
225 | query_params[param_name] = [param_value]
226 |
227 | new_query = urlencode(query_params, doseq=True)
228 | new_url = urlunparse(parsed_url._replace(query=new_query))
229 |
230 | return new_url
231 |
232 |
233 | def remove_query_param(url: str, param_name: str) -> str:
234 | """
235 | Removes a query parameter from a URL.
236 |
237 | Args:
238 | url (str): The URL to remove the parameter from.
239 | param_name (str): The name of the query parameter.
240 |
241 | Returns:
242 | str: The modified URL with the parameter removed.
243 | """
244 | parsed_url = urlparse(url)
245 | query_params = parse_qs(parsed_url.query)
246 |
247 | if param_name in query_params:
248 | del query_params[param_name]
249 |
250 | new_query = urlencode(query_params, doseq=True)
251 | new_url = urlunparse(parsed_url._replace(query=new_query))
252 |
253 | return new_url
254 |
255 |
256 | def clean_path(path: str) -> str:
257 | while ".." in path:
258 | path = path.replace("..", "")
259 | path = path.replace("\\", "/")
260 | while "//" in path:
261 | path = path.replace("//", "/")
262 | return path
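263 | 
264 | 
265 | if __name__ == "__main__":
266 |     # Illustrative sanity checks added for this listing; not part of the original module.
267 |     unit, ranges, suffix = parse_range_params("bytes=200-999, 2000-2499, 9500-")
268 |     assert unit == "bytes" and ranges == [(200, 999), (2000, 2499), (9500, None)] and suffix is None
269 |     # get_all_ranges clamps to the file size and returns end-exclusive ranges.
270 |     assert get_all_ranges(10000, unit, ranges, suffix) == [(200, 1000), (2000, 2500), (9500, 10000)]
271 |     assert parse_content_range("bytes 0-999/1000") == ("bytes", 0, 999, 1000)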
--------------------------------------------------------------------------------
/src/olah/utils/repo_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 XiaHan
3 | #
4 | # Use of this source code is governed by an MIT-style
5 | # license that can be found in the LICENSE file or at
6 | # https://opensource.org/licenses/MIT.
7 |
8 | import datetime
9 | import os
10 | import glob
11 | import tenacity
12 | from typing import Dict, Literal, Optional, Tuple, Union
13 | import json
14 | from urllib.parse import urljoin
15 | import httpx
16 | from olah.constants import WORKER_API_TIMEOUT
17 | from olah.utils.cache_utils import read_cache_request
18 |
19 |
20 | def get_org_repo(org: Optional[str], repo: str) -> str:
21 | """
22 | Constructs the organization/repository name.
23 |
24 | Args:
25 | org: The organization name (optional).
26 | repo: The repository name.
27 |
28 | Returns:
29 | The organization/repository name as a string.
30 |
31 | """
32 | if org is None:
33 | org_repo = repo
34 | else:
35 | org_repo = f"{org}/{repo}"
36 | return org_repo
37 |
38 |
39 | def parse_org_repo(org_repo: str) -> Tuple[Optional[str], Optional[str]]:
40 | """
41 | Parses the organization/repository name.
42 |
43 | Args:
44 | org_repo: The organization/repository name.
45 |
46 | Returns:
47 | A tuple containing the organization name and repository name.
48 |
49 | """
50 | if "/" in org_repo and org_repo.count("/") != 1:
51 | return None, None
52 | if "/" in org_repo:
53 | org, repo = org_repo.split("/")
54 | else:
55 | org = None
56 | repo = org_repo
57 | return org, repo
58 |
59 |
60 | def get_meta_save_path(
61 | repos_path: str, repo_type: str, org: Optional[str], repo: str, commit: str
62 | ) -> str:
63 | """
64 | Constructs the path to save the meta.json file.
65 |
66 | Args:
67 | repos_path: The base path where repositories are stored.
68 | repo_type: The type of repository.
69 | org: The organization name (optional).
70 | repo: The repository name.
71 | commit: The commit hash.
72 |
73 | Returns:
74 | The path to save the meta.json file as a string.
75 |
76 | """
77 | org_repo = get_org_repo(org, repo)
78 | return os.path.join(
79 | repos_path, f"api/{repo_type}/{org_repo}/revision/{commit}/meta_get.json"
80 | )
81 |
82 |
83 | def get_meta_save_dir(
84 | repos_path: str, repo_type: str, org: Optional[str], repo: str
85 | ) -> str:
86 | """
87 | Constructs the directory path to save the meta.json file.
88 |
89 | Args:
90 | repos_path: The base path where repositories are stored.
91 | repo_type: The type of repository.
92 | org: The organization name (optional).
93 | repo: The repository name.
94 |
95 | Returns:
96 | The directory path to save the meta.json file as a string.
97 |
98 | """
99 | org_repo = get_org_repo(org, repo)
100 | return os.path.join(repos_path, f"api/{repo_type}/{org_repo}/revision")
101 |
102 |
103 | def get_file_save_path(
104 | repos_path: str,
105 | repo_type: str,
106 | org: Optional[str],
107 | repo: str,
108 | commit: str,
109 | file_path: str,
110 | ) -> str:
111 | """
112 | Constructs the path to save a file in the repository.
113 |
114 | Args:
115 | repos_path: The base path where repositories are stored.
116 | repo_type: The type of repository.
117 | org: The organization name (optional).
118 | repo: The repository name.
119 | commit: The commit hash.
120 | file_path: The path of the file within the repository.
121 |
122 | Returns:
123 | The path to save the file as a string.
124 |
125 | """
126 | org_repo = get_org_repo(org, repo)
127 | return os.path.join(
128 | repos_path, f"heads/{repo_type}/{org_repo}/resolve_head/{commit}/{file_path}"
129 | )
130 |
131 |
132 | async def get_newest_commit_hf_offline(
133 | app,
134 | repo_type: Optional[Literal["models", "datasets", "spaces"]],
135 | org: str,
136 | repo: str,
137 | ) -> Optional[str]:
138 | """
139 | Retrieves the newest commit hash for a repository in offline mode.
140 |
141 | Args:
142 | app: The application object.
143 | repo_type: The type of repository.
144 | org: The organization name.
145 | repo: The repository name.
146 |
147 | Returns:
148 | The newest commit hash as a string.
149 |
150 | """
151 | repos_path = app.state.app_settings.config.repos_path
152 | save_dir = get_meta_save_dir(repos_path, repo_type, org, repo)
153 | files = glob.glob(os.path.join(save_dir, "*", "meta_head.json"))
154 |
155 | time_revisions = []
156 | for file in files:
157 | with open(file, "r", encoding="utf-8") as f:
158 | obj = json.loads(f.read())
159 | datetime_object = datetime.datetime.fromisoformat(obj["lastModified"])
160 | time_revisions.append((datetime_object, obj["sha"]))
161 |
162 | time_revisions = sorted(time_revisions)
163 | if len(time_revisions) == 0:
164 | return None
165 | else:
166 | return time_revisions[-1][1]
167 |
168 |
169 | async def get_newest_commit_hf(
170 | app,
171 | repo_type: Optional[Literal["models", "datasets", "spaces"]],
172 | org: Optional[str],
173 | repo: str,
174 | authorization: Optional[str] = None,
175 | ) -> Optional[str]:
176 | """
177 | Retrieves the newest commit hash for a repository.
178 |
179 | Args:
180 | app: The application object.
181 | repo_type: The type of repository.
182 | org: The organization name (optional).
183 | repo: The repository name.
184 |
185 | Returns:
186 | The newest commit hash as a string, or None if it cannot be obtained.
187 |
188 | """
189 | org_repo = get_org_repo(org, repo)
190 | url = urljoin(
191 | app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}"
192 | )
193 | if app.state.app_settings.config.offline:
194 | return await get_newest_commit_hf_offline(app, repo_type, org, repo)
195 | try:
196 | async with httpx.AsyncClient() as client:
197 | headers = {}
198 | if authorization is not None:
199 | headers["authorization"] = authorization
200 | response = await client.get(url, headers=headers, timeout=WORKER_API_TIMEOUT)
201 | if response.status_code != 200:
202 | return await get_newest_commit_hf_offline(app, repo_type, org, repo)
203 | obj = json.loads(response.text)
204 | return obj.get("sha", None)
205 | except httpx.TimeoutException as e:
206 | return await get_newest_commit_hf_offline(app, repo_type, org, repo)
207 |
208 |
209 | async def get_commit_hf_offline(
210 | app,
211 | repo_type: Optional[Literal["models", "datasets", "spaces"]],
212 | org: Optional[str],
213 | repo: str,
214 | commit: str,
215 | ) -> Optional[str]:
216 | """
217 | Retrieves the commit SHA for a given repository and commit from the offline cache.
218 |
219 | This function is used when the application is in offline mode and the commit information is not available from the API.
220 |
221 | Args:
222 | app: The application instance.
223 | repo_type: Optional. The type of repository ("models", "datasets", or "spaces").
224 | org: Optional. The organization name for the repository.
225 | repo: The name of the repository.
226 | commit: The commit identifier.
227 |
228 | Returns:
229 | The commit SHA as a string if available in the offline cache, or None if the information is not cached.
230 | """
231 | repos_path = app.state.app_settings.config.repos_path
232 | save_path = get_meta_save_path(repos_path, repo_type, org, repo, commit)
233 | if os.path.exists(save_path):
234 | request_cache = await read_cache_request(save_path)
235 | request_cache_json = json.loads(request_cache["content"])
236 | return request_cache_json["sha"]
237 | else:
238 | return None
239 |
240 |
241 | async def get_commit_hf(
242 | app,
243 | repo_type: Optional[Literal["models", "datasets", "spaces"]],
244 | org: Optional[str],
245 | repo: str,
246 | commit: str,
247 | authorization: Optional[str] = None,
248 | ) -> Optional[str]:
249 | """
250 | Retrieves the commit SHA for a given repository and commit from the Hugging Face API.
251 |
252 | Args:
253 | app: The application instance.
254 | repo_type: Optional. The type of repository ("models", "datasets", or "spaces").
255 | org: Optional. The organization name for the repository.
256 | repo: The name of the repository.
257 | commit: The commit identifier.
258 | authorization: Optional. The authorization token for accessing the API.
259 |
260 | Returns:
261 | The commit SHA as a string, or None if the commit cannot be retrieved.
262 |
263 | Raises:
264 | This function does not raise any explicit exceptions but may propagate exceptions from underlying functions.
265 | """
266 | org_repo = get_org_repo(org, repo)
267 | url = urljoin(
268 | app.state.app_settings.config.hf_url_base(),
269 | f"/api/{repo_type}/{org_repo}/revision/{commit}",
270 | )
271 | if app.state.app_settings.config.offline:
272 | return await get_commit_hf_offline(app, repo_type, org, repo, commit)
273 | try:
274 | headers = {}
275 | if authorization is not None:
276 | headers["authorization"] = authorization
277 | async with httpx.AsyncClient() as client:
278 | response = await client.get(
279 | url, headers=headers, timeout=WORKER_API_TIMEOUT, follow_redirects=True
280 | )
281 | if response.status_code not in [200, 307]:
282 | return await get_commit_hf_offline(app, repo_type, org, repo, commit)
283 | obj = json.loads(response.text)
284 | return obj.get("sha", None)
285 | except Exception:  # fall back to the offline cache on any request failure
286 | return await get_commit_hf_offline(app, repo_type, org, repo, commit)
287 |
288 |
289 | @tenacity.retry(stop=tenacity.stop_after_attempt(3))
290 | async def check_commit_hf(
291 | app,
292 | repo_type: Optional[Literal["models", "datasets", "spaces"]],
293 | org: Optional[str],
294 | repo: str,
295 | commit: Optional[str] = None,
296 | authorization: Optional[str] = None,
297 | ) -> bool:
298 | """
299 | Checks the commit status of a repository in the Hugging Face ecosystem.
300 |
301 | Args:
302 | app: The application object.
303 | repo_type: The type of repository (models, datasets, or spaces).
304 | org: The organization name (optional).
305 | repo: The repository name.
306 | commit: The commit hash (optional).
307 | authorization: The authorization token (optional).
308 |
309 | Returns:
310 | A boolean indicating if the commit is valid (status code 200 or 307) or not.
311 |
312 | """
313 | org_repo = get_org_repo(org, repo)
314 | if commit is None:
315 | url = urljoin(
316 | app.state.app_settings.config.hf_url_base(), f"/api/{repo_type}/{org_repo}"
317 | )
318 | else:
319 | url = urljoin(
320 | app.state.app_settings.config.hf_url_base(),
321 | f"/api/{repo_type}/{org_repo}/revision/{commit}",
322 | )
323 |
324 | headers = {}
325 | if authorization is not None:
326 | headers["authorization"] = authorization
327 | async with httpx.AsyncClient() as client:
328 | response = await client.request(method="HEAD", url=url, headers=headers, timeout=WORKER_API_TIMEOUT)
329 | status_code = response.status_code
330 | return status_code in [200, 307]
331 |
--------------------------------------------------------------------------------
/src/olah/mirror/repos.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 XiaHan
3 | #
4 | # Use of this source code is governed by an MIT-style
5 | # license that can be found in the LICENSE file or at
6 | # https://opensource.org/licenses/MIT.
7 | import hashlib
8 | import io
9 | import os
10 | import re
11 | from typing import Any, Dict, List, Union
12 | import gitdb
13 | from git import Commit, Optional, Repo, Tree
14 | from git.objects.base import IndexObjUnion
15 | from gitdb.base import OStream
16 | import yaml
17 |
18 | from olah.mirror.meta import RepoMeta
19 |
20 |
21 | class LocalMirrorRepo(object):
22 | def __init__(self, path: str, repo_type: str, org: str, repo: str) -> None:
23 | self._path = path
24 | self._repo_type = repo_type
25 | self._org = org
26 | self._repo = repo
27 |
28 | self._git_repo = Repo(self._path)
29 |
30 | def _sha256(self, text: Union[str, bytes]) -> str:
31 | if isinstance(text, bytes) or isinstance(text, bytearray):
32 | bin = text
33 | elif isinstance(text, str):
34 | bin = text.encode("utf-8")
35 | else:
36 | raise Exception("Invalid sha256 param type.")
37 | sha256_hash = hashlib.sha256()
38 | sha256_hash.update(bin)
39 | hashed_string = sha256_hash.hexdigest()
40 | return hashed_string
41 |
42 | def _match_card(self, readme: str) -> str:
43 | pattern = r"\s*---(.*?)---"
44 |
45 | match = re.match(pattern, readme, flags=re.S)
46 |
47 | if match:
48 | card_string = match.group(1)
49 | return card_string
50 | else:
51 | return ""
52 |
53 | def _remove_card(self, readme: str) -> str:
54 | pattern = r"\s*---(.*?)---"
55 | out = re.sub(pattern, "", readme, flags=re.S)
56 | return out
57 |
58 | def _get_readme(self, commit: Commit) -> str:
59 | if "README.md" not in commit.tree:
60 | return ""
61 | else:
62 | out: bytes = commit.tree["README.md"].data_stream.read()
63 | return out.decode()
64 |
65 | def _get_description(self, commit: Commit) -> str:
66 | readme = self._get_readme(commit)
67 | return self._remove_card(readme)
68 |
69 | def _get_tree_filepaths_recursive(self, tree: Tree, include_dir: bool = False) -> List[str]:
70 | out_paths = []
71 | for entry in tree:
72 | if entry.type == "tree":
73 | out_paths.extend(self._get_tree_filepaths_recursive(entry, include_dir))  # propagate include_dir into subtrees
74 | if include_dir:
75 | out_paths.append(entry.path)
76 | else:
77 | out_paths.append(entry.path)
78 | return out_paths
79 |
80 | def _get_commit_filepaths_recursive(self, commit: Commit) -> List[str]:
81 | return self._get_tree_filepaths_recursive(commit.tree)
82 |
83 | def _get_path_info(self, entry: IndexObjUnion, expand: bool = False) -> Dict[str, Union[int, str]]:
84 | lfs = False
85 | if entry.type != "tree":
86 | t = "file"
87 | repr_size = entry.size
88 | if repr_size > 120 and repr_size < 150:
89 | # check lfs
90 | lfs_data = entry.data_stream.read().decode("utf-8")
91 | match_groups = re.match(
92 | r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n",
93 | lfs_data,
94 | )
95 | if match_groups is not None:
96 | lfs = True
97 | sha256 = match_groups.group(1)
98 | repr_size = int(match_groups.group(2))
99 | lfs_data = {
100 | "oid": sha256,
101 | "size": repr_size,
102 | "pointerSize": entry.size,
103 | }
104 | else:
105 | t = "directory"
106 | repr_size = entry.size
107 |
108 | if not lfs:
109 | item = {
110 | "type": t,
111 | "oid": entry.hexsha,
112 | "size": repr_size,
113 | "path": entry.path,
114 | "name": entry.name,
115 | }
116 | else:
117 | item = {
118 | "type": t,
119 | "oid": entry.hexsha,
120 | "size": repr_size,
121 | "path": entry.path,
122 | "name": entry.name,
123 | "lfs": lfs_data,
124 | }
125 | if expand:
126 | last_commit = next(self._git_repo.iter_commits(paths=entry.path, max_count=1))
127 | item["lastCommit"] = {
128 | "id": last_commit.hexsha,
129 | "title": last_commit.message,
130 | "date": last_commit.committed_datetime.strftime(
131 | "%Y-%m-%dT%H:%M:%S.%fZ"
132 | )
133 | }
134 | item["security"] = {
135 | "blobId": entry.hexsha,
136 | "name": entry.name,
137 | "safe": True,
138 | "indexed": False,
139 | "avScan": {
140 | "virusFound": False,
141 | "virusNames": None
142 | },
143 | "pickleImportScan": None
144 | }
145 | return item
146 |
147 | def _get_tree_files(
148 | self, tree: Tree, recursive: bool = False, expand: bool = False
149 | ) -> List[Dict[str, Union[int, str]]]:
150 | entries = []
151 | for entry in tree:
152 | entries.append(self._get_path_info(entry=entry, expand=expand))
153 |
154 | if recursive:
155 | for entry in tree:
156 | if entry.type == "tree":
157 | entries.extend(self._get_tree_files(entry, recursive=recursive, expand=expand))
158 | return entries
159 |
160 | def _get_commit_files(self, commit: Commit) -> List[Dict[str, Union[int, str]]]:
161 | return self._get_tree_files(commit.tree)
162 |
163 | def _get_earliest_commit(self) -> Commit:
164 | earliest_commit = None
165 | earliest_commit_date = None
166 |
167 | for commit in self._git_repo.iter_commits():
168 | commit_date = commit.committed_datetime
169 |
170 | if earliest_commit_date is None or commit_date < earliest_commit_date:
171 | earliest_commit = commit
172 | earliest_commit_date = commit_date
173 |
174 | return earliest_commit
175 |
176 | def get_index_object_by_path(
177 | self, commit_hash: str, path: str
178 | ) -> Optional[IndexObjUnion]:
179 | try:
180 | commit = self._git_repo.commit(commit_hash)
181 | except gitdb.exc.BadName:
182 | return None
183 | path_part = path.split("/")
184 | path_part = [part for part in path_part if len(part.strip()) != 0]
185 | tree = commit.tree
186 | items = self._get_tree_files(tree=tree)
187 | if len(path_part) == 0:
188 | return None
189 | for i, part in enumerate(path_part):
190 | if i != len(path_part) - 1:
191 | if part not in [
192 | item["name"] for item in items if item["type"] == "directory"
193 | ]:
194 | return None
195 | else:
196 | if part not in [
197 | item["name"] for item in items
198 | ]:
199 | return None
200 | tree = tree[part]
201 | if tree.type == "tree":
202 | items = self._get_tree_files(tree=tree, recursive=False)
203 | return tree
204 |
205 | def get_pathinfos(
206 | self, commit_hash: str, paths: List[str]
207 | ) -> Optional[List[Dict[str, Any]]]:
208 | try:
209 | commit = self._git_repo.commit(commit_hash)
210 | except gitdb.exc.BadName:
211 | return None
212 |
213 | results = []
214 | for path in paths:
215 | index_obj = self.get_index_object_by_path(
216 | commit_hash=commit_hash, path=path
217 | )
218 | if index_obj is not None:
219 | results.append(self._get_path_info(index_obj))
220 |
221 | for r in results:
222 | if "name" in r:
223 | r.pop("name")
224 | return results
225 |
226 | def get_tree(
227 | self, commit_hash: str, path: str, recursive: bool = False, expand: bool = False
228 | ) -> Optional[Dict[str, Any]]:
229 | try:
230 | commit = self._git_repo.commit(commit_hash)
231 | except gitdb.exc.BadName:
232 | return None
233 |
234 | index_obj = self.get_index_object_by_path(commit_hash=commit_hash, path=path)
235 | items = self._get_tree_files(tree=index_obj, recursive=recursive, expand=expand) if index_obj is not None else None
236 | for r in items or []:
237 | r.pop("name", None)
238 | return items
239 |
240 | def get_commits(self, commit_hash: str) -> Optional[Dict[str, Any]]:
241 | try:
242 | commit = self._git_repo.commit(commit_hash)
243 | except gitdb.exc.BadName:
244 | return None
245 |
246 | parent_commits = [commit] + [each_commit for each_commit in commit.iter_parents()]
247 | items = []
248 | for each_commit in parent_commits:
249 | item = {
250 | "id": each_commit.hexsha,
251 | "title": each_commit.message,
252 | "message": "",
253 | "authors": [],
254 | "date": each_commit.committed_datetime.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
255 | }
256 | item["authors"].append({
257 | "name": each_commit.author.name,
258 | "avatar": None
259 | })
260 | items.append(item)
261 | return items
262 |
263 | def get_meta(self, commit_hash: str) -> Optional[Dict[str, Any]]:
264 | try:
265 | commit = self._git_repo.commit(commit_hash)
266 | except gitdb.exc.BadName:
267 | return None
268 | meta = RepoMeta()
269 |
270 | meta._id = self._sha256(f"{self._org}/{self._repo}/{commit.hexsha}")
271 | meta.id = f"{self._org}/{self._repo}"
272 | meta.author = self._org
273 | meta.sha = commit.hexsha
274 | meta.lastModified = self._git_repo.head.commit.committed_datetime.strftime(
275 | "%Y-%m-%dT%H:%M:%S.%fZ"
276 | )
277 | meta.private = False
278 | meta.gated = False
279 | meta.disabled = False
280 | meta.tags = []
281 | meta.description = self._get_description(commit)
282 | meta.paperswithcode_id = None
283 | meta.downloads = 0
284 | meta.likes = 0
285 | meta.cardData = yaml.load(
286 | self._match_card(self._get_readme(commit)), Loader=yaml.CLoader
287 | )
288 | meta.siblings = [
289 | {"rfilename": p} for p in self._get_commit_filepaths_recursive(commit)
290 | ]
291 | meta.createdAt = self._get_earliest_commit().committed_datetime.strftime(
292 | "%Y-%m-%dT%H:%M:%S.%fZ"
293 | )
294 | return meta.to_dict()
295 |
296 | def _contain_path(self, path: str, tree: Tree) -> bool:
297 | norm_p = os.path.normpath(path).replace("\\", "/")
298 | parts = norm_p.split("/")
299 | for part in parts:
300 | if all([t.name != part for t in tree]):
301 | return False
302 | else:
303 | entry = tree[part]
304 | if entry.type == "tree":
305 | tree = entry
306 | else:
307 | tree = {}
308 | return True
309 |
310 | def get_file_head(self, commit_hash: str, path: str) -> Optional[Dict[str, Any]]:
311 | try:
312 | commit = self._git_repo.commit(commit_hash)
313 | except gitdb.exc.BadName:
314 | return None
315 |
316 | if not self._contain_path(path, commit.tree):
317 | return None
318 | else:
319 | header = {}
320 | header["content-length"] = str(commit.tree[path].data_stream.size)
321 | header["x-repo-commit"] = commit.hexsha
322 | header["etag"] = commit.tree[path].binsha.hex()
323 | if (commit.tree[path].data_stream.size > 120) and (commit.tree[path].data_stream.size < 150):  # size range typical of a Git LFS pointer file
324 | lfs_data = commit.tree[path].data_stream.read().decode("utf-8")
325 | match_groups = re.match(
326 | r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n",
327 | lfs_data,
328 | )
329 | if match_groups is not None:
330 | oid_sha256 = match_groups.group(1)
331 | objects_dir = os.path.join(self._git_repo.working_dir, '.git', 'lfs', 'objects')
332 | oid_dir = os.path.join(objects_dir, oid_sha256[:2], oid_sha256[2:4], oid_sha256)
333 | header["content-length"] = str(os.path.getsize(oid_dir))
334 | with open(oid_dir, mode='rb') as lfs_file:
335 | header["etag"] = self._sha256(lfs_file.read())
336 |
337 | return header
338 |
339 | def get_file(self, commit_hash: str, path: str) -> Optional[OStream]:
340 | try:
341 | commit = self._git_repo.commit(commit_hash)
342 | except gitdb.exc.BadName:
343 | return None
344 |
345 | lfs = False
346 | oid_dir = ""
347 | if (commit.tree[path].size > 120) and (commit.tree[path].size < 150):  # size range typical of a Git LFS pointer file
348 | lfs_data = commit.tree[path].data_stream.read().decode("utf-8")
349 | match_groups = re.match(
350 | r"version https://git-lfs\.github\.com/spec/v[0-9]\noid sha256:([0-9a-z]{64})\nsize ([0-9]+?)\n",
351 | lfs_data,
352 | )
353 | if match_groups is not None:
354 | lfs = True
355 | oid_sha256 = match_groups.group(1)
356 | objects_dir = os.path.join(self._git_repo.working_dir, '.git', 'lfs', 'objects')
357 | oid_dir = os.path.join(objects_dir, oid_sha256[:2], oid_sha256[2:4], oid_sha256)
358 |
359 | def stream_wrapper(file_bytes: bytes):
360 | file_stream = io.BytesIO(file_bytes)
361 | while True:
362 | chunk = file_stream.read(4096)
363 | if len(chunk) == 0:
364 | break
365 | else:
366 | yield chunk
367 |
368 | if not self._contain_path(path, commit.tree):
369 | return None
370 | else:
371 | if lfs:
372 | with open(oid_dir, mode='rb') as lfs_file:
373 | return stream_wrapper(lfs_file.read())
374 | else:
375 | return stream_wrapper(commit.tree[path].data_stream.read())
376 |
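377 | 
378 | 
379 | if __name__ == "__main__":
380 |     # Illustrative usage added for this listing (an assumption, not part of the
381 |     # original module): open a local clone and print its meta SHA, if the path exists.
382 |     import sys
383 | 
384 |     repo_path = sys.argv[1] if len(sys.argv) > 1 else "./mirrors_dir/datasets/Salesforce/wikitext"
385 |     if os.path.exists(repo_path):
386 |         mirror = LocalMirrorRepo(repo_path, "datasets", "Salesforce", "wikitext")
387 |         meta = mirror.get_meta(mirror._git_repo.head.commit.hexsha)
388 |         print(None if meta is None else meta.get("sha"))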
--------------------------------------------------------------------------------
/src/olah/cache/olah_cache.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 XiaHan
3 | #
4 | # Use of this source code is governed by an MIT-style
5 | # license that can be found in the LICENSE file or at
6 | # https://opensource.org/licenses/MIT.
7 |
8 | import asyncio
9 | import lzma
10 | import mmap
11 | import os
12 | import string
13 | import struct
14 | import threading
15 | import gzip
16 | from typing import BinaryIO, Dict, List, Optional
17 |
18 | import aiofiles
19 | import fastapi
20 | import fastapi.concurrency
21 | import portalocker
22 | from .bitset import Bitset
23 |
24 | CURRENT_OLAH_CACHE_VERSION = 9
25 | # Due to the download chunk settings: https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/constants.py#L37
26 | DEFAULT_BLOCK_SIZE = 50 * 1024 * 1024
27 | MAX_BLOCK_NUM = 8192
28 | DEFAULT_COMPRESSION_ALGO = 1
29 | """
30 | 0: no compression
31 | 1: gzip
32 | 2: lzma
33 | 3: blosc
34 | 4: zlib
35 | 5: zstd
36 | 6: ...
37 | """
38 |
39 | class OlahCacheHeader(object):
40 | MAGIC_NUMBER = "OLAH".encode("ascii")
41 | HEADER_FIX_SIZE = 36
42 |
43 | def __init__(
44 | self,
45 | version: int = CURRENT_OLAH_CACHE_VERSION,
46 | block_size: int = DEFAULT_BLOCK_SIZE,
47 | file_size: int = 0,
48 | compression_algo: int = DEFAULT_COMPRESSION_ALGO,
49 | ) -> None:
50 | self._version = version
51 | self._block_size = block_size
52 | self._file_size = file_size
53 | self._compression_algo = compression_algo
54 |
55 | self._block_number = (file_size + block_size - 1) // block_size
56 |
57 | @property
58 | def version(self) -> int:
59 | return self._version
60 |
61 | @property
62 | def block_size(self) -> int:
63 | return self._block_size
64 |
65 | @property
66 | def file_size(self) -> int:
67 | return self._file_size
68 |
69 | @property
70 | def block_number(self) -> int:
71 | return self._block_number
72 |
73 | @property
74 | def compression_algo(self) -> int:
75 | return self._compression_algo
76 |
77 | def get_header_size(self) -> int:
78 | return self.HEADER_FIX_SIZE
79 |
80 | def _valid_header(self) -> None:
81 | if self._file_size > MAX_BLOCK_NUM * self._block_size:
82 | raise Exception(
83 | f"The size of file {self._file_size} is out of the max capability of container ({MAX_BLOCK_NUM} * {self._block_size})."
84 | )
85 | if self._version < CURRENT_OLAH_CACHE_VERSION:
86 | raise Exception(
87 | f"This Olah Cache file is created by older version Olah. Please remove cache files and retry."
88 | )
89 |
90 | if self._version > CURRENT_OLAH_CACHE_VERSION:
91 | raise Exception(
92 | f"This Olah Cache file is created by newer version Olah. Please remove cache files and retry."
93 | )
94 |
95 | @staticmethod
96 | def read(stream) -> "OlahCacheHeader":
97 | obj = OlahCacheHeader()
98 | try:
99 | magic = struct.unpack(
100 | "<4s", stream.read(4)
101 | )
102 | except struct.error:
103 | raise Exception("File is not a Olah cache file.")
104 | if magic[0] != OlahCacheHeader.MAGIC_NUMBER:
105 | raise Exception("File is not a Olah cache file.")
106 |
107 | version, block_size, file_size, compression_algo = struct.unpack(
108 | "