├── .env.miner.template
├── .env.validator.template
├── .gitignore
├── LICENSE
├── README.md
├── VERSION
├── bitmind
├── __init__.py
├── autoupdater.py
├── cache
│ ├── __init__.py
│ ├── cache_fs.py
│ ├── cache_system.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── dataset_registry.py
│ │ └── datasets.py
│ ├── sampler
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── image_sampler.py
│ │ ├── sampler_registry.py
│ │ └── video_sampler.py
│ ├── updater
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── image_updater.py
│ │ ├── updater_registry.py
│ │ └── video_updater.py
│ └── util
│ │ ├── __init__.py
│ │ ├── download.py
│ │ ├── extract.py
│ │ ├── filesystem.py
│ │ └── video.py
├── config.py
├── encoding.py
├── epistula.py
├── generation
│ ├── __init__.py
│ ├── generation_pipeline.py
│ ├── model_registry.py
│ ├── models.py
│ ├── prompt_generator.py
│ └── util
│ │ ├── __init__.py
│ │ ├── image.py
│ │ ├── model.py
│ │ └── prompt.py
├── metagraph.py
├── scoring
│ ├── __init__.py
│ ├── eval_engine.py
│ └── miner_history.py
├── transforms.py
├── types.py
├── utils.py
└── wandb_utils.py
├── docs
├── Incentive.md
├── Mining.md
├── Validating.md
└── static
│ ├── Bitmind-Logo.png
│ ├── Join-BitMind-Discord.png
│ └── Subnet-Arch.png
├── min_compute.yml
├── neurons
├── __init__.py
├── base.py
├── generator.py
├── miner.py
├── proxy.py
└── validator.py
├── pyproject.toml
├── requirements-git.txt
├── requirements.txt
├── setup.sh
├── start_miner.sh
└── start_validator.sh
/.env.miner.template:
--------------------------------------------------------------------------------
1 | # ======= Miner Configuration (FILL IN) =======
2 | # Wallet
3 | WALLET_NAME=
4 | WALLET_HOTKEY=
5 |
6 | # Network
7 | CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
8 | # OTF public finney endpoint: wss://entrypoint-finney.opentensor.ai:443
9 | # OTF public testnet endpoint: wss://test.finney.opentensor.ai:443/
10 |
11 | # Axon port and (optionally) ip
12 | AXON_PORT=8091
13 | AXON_EXTERNAL_IP=[::]
14 |
15 | FORCE_VPERMIT=true
16 |
17 | # Device for detection models
18 | DEVICE=cpu
19 |
20 | # Logging
21 | LOGLEVEL=trace
--------------------------------------------------------------------------------
/.env.validator.template:
--------------------------------------------------------------------------------
1 | # ======= Validator Configuration (FILL IN) =======
2 | # Wallet
3 | WALLET_NAME=
4 | WALLET_HOTKEY=
5 |
6 | # API Keys
7 | WANDB_API_KEY=
8 | HUGGING_FACE_TOKEN=
9 |
10 | # Network
11 | CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
12 | # OTF public finney endpoint: wss://entrypoint-finney.opentensor.ai:443
13 | # OTF public testnet endpoint: wss://test.finney.opentensor.ai:443/
14 |
15 | # Validator Proxy
16 | PROXY_PORT=10913
17 | PROXY_EXTERNAL_PORT=10913
18 |
19 | # Cache config
20 | SN34_CACHE_DIR=~/.cache/sn34
21 | HEARTBEAT=true
22 |
23 | # Generator config
24 | GENERATION_BATCH_SIZE=3
25 | DEVICE=cuda
26 |
27 | # Other
28 | LOGLEVEL=info
29 | AUTO_UPDATE=true
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | testing/
163 | data/
164 | checkpoints/
165 | .requirements_installed
166 | base_miner/NPR/weights/*
167 | base_miner/NPR/logs/*
168 | base_miner/DFB/weights/*
169 | base_miner/DFB/logs/*
170 | miner_eval.py
171 | *.env
172 | *~
173 | wandb/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | Copyright © 2023 Yuma Rao
3 | Copyright © 2025 BitMind
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6 | documentation files (the "Software"), to deal in the Software without restriction, including without limitation
7 | the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
8 | and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of
11 | the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
14 | THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
16 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
17 | DEALINGS IN THE SOFTWARE.
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | SN34
Deepfake Detection
5 |
6 |
9 |
10 |
15 |
16 |
21 |
22 |
25 |
26 | ## Decentralized Detection of AI Generated Content
27 | The explosive growth of generative AI technology has unleashed an unprecedented wave of synthetic media creation. AI-generated audiovisual content has become remarkably sophisticated, oftentimes indistinguishable from authentic media. This development presents a critical challenge to information integrity and societal trust in the digital age, as the line between real and synthetic content continues to blur.
28 |
29 | To address this growing challenge, SN34 aims to create the most accurate fully-generalized detection system. Here, fully-generalized means that the system is capable of detecting both synthetic and semi-synthetic media with high degrees of accuracy regardless of their content or what model generated them. Our incentive mechanism evolves alongside state-of-the-art generative AI, rewarding miners whose detection algorithms best adapt to new forms of synthetic content.
30 |
31 |
32 | ## Core Components
33 |
34 | > This documentation assumes basic familiarity with [Bittensor concepts](https://docs.bittensor.com/learn/bittensor-building-blocks).
35 |
36 | Miners
37 |
38 | - Miners are tasked with running binary classifiers that discern between genuine and AI-generated content, and are rewarded based on their accuracy.
39 | - For each challenge, a miner is presented an image or video and is required to respond with a multiclass prediction [$p_{real}$, $p_{synthetic}$, $p_{semisynthetic}$] indicating whether the media is real, fully generated, or partially modified by AI.
40 |
41 |
42 | Validators
43 | - Validators challenge miners with a balanced mix of real and synthetic media drawn from a diverse pool of sources.
44 | - We continually add new datasets and generative models to our validators in order to evolve the subnet's detection capabilities alongside advances in generative AI.
45 |
46 |
47 | ## Subnet Architecture
48 |
49 | Overview of the validator neuron, miner neuron, and other components external to the subnet.
50 |
51 | 
52 |
53 |
54 | Challenge Generation and Scoring (Peach Arrows)
55 |
56 | - The validator first randomly samples an image or video from its local media cache.
57 | - The sampled media can be real, synthetic, or semisynthetic, and was either downloaded from a dataset on HuggingFace or generated locally by one of many generative models.
58 | - The sampled media is then augmented by a pipeline of random transformations, adding to the challenge difficulty and mitigating incentive mechanism gaming via lookups.
59 | - The augmented media is then sent to miners for classification.
60 | - The validator scores the miners responses and logs comprehensive challenge results to Weights and Biases, including the generated media, original prompt, miner responses and rewards, and other challenge metadata.
61 |
62 |
63 |
64 |
65 | Data Generation and Downloads (Blue Arrows)
66 | The blue arrows show how the validator media cache is maintained by two parallel tracks:
67 |
68 | - The synthetic data generator coordinates a VLM and LLM to generate prompts for our suite of text-to-image, image-to-image, and text-to-video models. Each generated image/video is written to the cache along with the prompt, generation parameters, and other metadata.
69 | - The real data fetcher performs partial dataset downloads, fetching random compressed chunks of datasets from HuggingFace and unpacking random portions of these chunks into the cache along with their metadata. Partial downloads avoid requiring TBs of space for large video datasets like OpenVid1M.
70 |
71 |
72 |
73 |
74 | Organic Traffic (Green Arrows)
75 |
76 | Application requests are distributed to validators by an API server and load balancer in BitMind's cloud. A vector database caches subnet responses to avoid unnecessary repetitive calls coming from salient images on the internet.
77 |
78 |
79 |
80 |
81 | ## Community
82 |
83 |
84 |
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 3.0.10
2 |
--------------------------------------------------------------------------------
/bitmind/__init__.py:
--------------------------------------------------------------------------------
__version__ = "3.0.10"

# Spec version packs major/minor/patch into a single integer:
# major*100000 + minor*1000 + patch*10 (e.g. "3.0.10" -> 300100).
version_split = __version__.split(".")
__spec_version__ = sum(
    weight * int(part)
    for weight, part in zip((100000, 1000, 10), version_split)
)
9 |
--------------------------------------------------------------------------------
/bitmind/autoupdater.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | # Copyright © 2023 Yuma Rao
3 | # Copyright © 2024 Manifold Labs
4 | # Copyright © 2025 BitMind
5 |
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
7 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation
8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
9 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
10 |
11 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
12 | # the Software.
13 |
14 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
15 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
17 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import signal
21 | import time
22 | import os
23 | import requests
24 | import bittensor as bt
25 | import bitmind
26 |
27 |
def _parse_version(version: str) -> tuple:
    """Parse a dotted version string (e.g. "3.0.10") into a tuple of ints.

    Tuples compare component-wise, so (3, 0, 10) < (3, 1, 0). The previous
    digit-concatenation scheme (int("".join(v.split(".")))) ordered
    "3.0.10" (3010) AFTER "3.1.0" (310), which is wrong.
    """
    return tuple(int(part) for part in version.strip().split("."))


def autoupdate(branch: str = "main", force: bool = False):
    """
    Automatically updates the codebase to the latest version available on the specified branch.

    This function checks the remote repository for the latest version by fetching the VERSION file from the specified branch.
    If the local version is older than the remote version, it performs a git pull to update the local codebase to the latest version.
    After successfully updating, it restarts the application with the updated code.

    Args:
        branch: The name of the branch to check for updates. Defaults to "main".
        force: When True, update even if the local version is current.

    Note:
        - The function assumes that the local codebase is a git repository and has the same structure as the remote repository.
        - It requires git to be installed and accessible from the command line.
        - The function will restart the application using the same command-line arguments it was originally started with.
        - If the update fails, manual intervention is required to resolve the issue and restart the application.
    """
    # Local import: only needed on the (rare) update path, and the original
    # module did not import subprocess at the top level.
    import subprocess

    bt.logging.info("Checking for updates...")
    try:
        # Timestamp query param + no-cache headers defeat CDN/proxy caching.
        github_url = (
            "https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/"
            f"{branch}/VERSION?ts={time.time()}"
        )
        response = requests.get(
            github_url,
            headers={
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Pragma": "no-cache",
                "Expires": "0",
            },
        )
        response.raise_for_status()
        repo_version = response.content.decode().strip()

        latest_version = _parse_version(repo_version)
        local_version = _parse_version(bitmind.__version__)

        bt.logging.info(f"Local version: {bitmind.__version__}")
        bt.logging.info(f"Latest version: {repo_version}")

        if latest_version > local_version or force:
            bt.logging.info("A newer version is available. Updating...")

            # Walk up from this file to the repository root directory.
            base_path = os.path.abspath(__file__)
            while os.path.basename(base_path) != "bitmind-subnet":
                parent = os.path.dirname(base_path)
                if parent == base_path:
                    # Reached the filesystem root without finding the repo;
                    # bail out instead of looping forever.
                    bt.logging.error(
                        "Could not locate bitmind-subnet directory. Manual update required."
                    )
                    return
                base_path = parent

            os.system(f"cd {base_path} && git pull && chmod +x setup.sh && ./setup.sh")

            with open(os.path.join(base_path, "VERSION")) as f:
                new_version = _parse_version(f.read())

            if new_version == latest_version:
                bt.logging.info("Updated successfully.")

                bt.logging.info("Restarting generator...")
                subprocess.run(["pm2", "reload", "sn34-generator"], check=True)

                bt.logging.info("Restarting proxy...")
                subprocess.run(["pm2", "reload", "sn34-proxy"], check=True)

                bt.logging.info("Restarting validator")
                # SIGINT lets the validator's own shutdown/restart handling run.
                os.kill(os.getpid(), signal.SIGINT)
            else:
                bt.logging.error("Update failed. Manual update required.")
    except Exception as e:
        bt.logging.error(f"Update check failed: {e}")
91 |
--------------------------------------------------------------------------------
/bitmind/cache/__init__.py:
--------------------------------------------------------------------------------
1 | from .cache_system import CacheSystem
2 |
--------------------------------------------------------------------------------
/bitmind/cache/cache_system.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Type
2 | import traceback
3 |
4 | import asyncio
5 | import bittensor as bt
6 |
7 | from bitmind.types import CacheUpdaterConfig, CacheConfig, Modality, MediaType
8 | from bitmind.cache.datasets import DatasetRegistry, initialize_dataset_registry
9 | from bitmind.cache.updater import (
10 | BaseUpdater,
11 | UpdaterRegistry,
12 | ImageUpdater,
13 | VideoUpdater,
14 | )
15 | from bitmind.cache.sampler import (
16 | BaseSampler,
17 | SamplerRegistry,
18 | ImageSampler,
19 | VideoSampler,
20 | )
21 |
22 |
class CacheSystem:
    """
    Main facade for the caching system.

    Owns the dataset, updater, and sampler registries and wires them
    together: `initialize` builds one sampler (and, except for synthetic
    video, one updater) per (modality, media_type) pair, while the
    `initialize_caches`/`update_*` methods fan work out to all registered
    updaters in parallel.
    """

    def __init__(self):
        self.dataset_registry = DatasetRegistry()
        self.updater_registry = UpdaterRegistry()
        self.sampler_registry = SamplerRegistry()

    async def initialize(
        self,
        base_dir,
        max_compressed_gb,
        max_media_gb,
        media_files_per_source,
    ):
        """
        Register all known datasets, create samplers and updaters for each
        (modality, media_type) pair, and populate any empty caches.

        Errors are logged rather than raised so a cache failure does not
        crash the caller.

        Args:
            base_dir: Root directory of the on-disk cache
            max_compressed_gb: Size budget for compressed source archives
            max_media_gb: Size budget for extracted media files
            media_files_per_source: Media files to extract per compressed source
        """
        try:
            dataset_registry = initialize_dataset_registry()
            for dataset in dataset_registry.datasets:
                self.register_dataset(dataset)

            for modality in Modality:
                for media_type in MediaType:
                    cache_config = CacheConfig(
                        base_dir=base_dir,
                        modality=modality.value,
                        media_type=media_type.value,
                        max_compressed_gb=max_compressed_gb,
                        max_media_gb=max_media_gb,
                    )
                    sampler_class = (
                        ImageSampler if modality == Modality.IMAGE else VideoSampler
                    )
                    self.create_sampler(
                        name=f"{media_type.value}_{modality.value}_sampler",
                        sampler_class=sampler_class,
                        cache_config=cache_config,
                    )

                    # synthetic video updater not currently used, only generate locally
                    if not (
                        modality == Modality.VIDEO and media_type == MediaType.SYNTHETIC
                    ):
                        updater_config = CacheUpdaterConfig(
                            num_sources_per_dataset=1,  # one compressed source per dataset for initialization
                            num_items_per_source=media_files_per_source,
                        )
                        updater_class = (
                            ImageUpdater if modality == Modality.IMAGE else VideoUpdater
                        )
                        self.create_updater(
                            name=f"{media_type.value}_{modality.value}_updater",
                            updater_class=updater_class,
                            cache_config=cache_config,
                            updater_config=updater_config,
                        )

            # Initialize caches (populate if empty)
            bt.logging.info("Starting initial cache population")
            await self.initialize_caches()
            bt.logging.info("Initial cache population complete")

        except Exception as e:
            bt.logging.error(f"Error initializing caches: {e}")
            bt.logging.error(traceback.format_exc())

    def register_dataset(self, dataset) -> None:
        """
        Register a dataset with the system.

        Args:
            dataset: Dataset configuration to register
        """
        self.dataset_registry.register(dataset)

    def register_datasets(self, datasets: List[Any]) -> None:
        """
        Register multiple datasets with the system.

        Args:
            datasets: List of dataset configurations to register
        """
        self.dataset_registry.register_all(datasets)

    def create_updater(
        self,
        name: str,
        updater_class: Type[BaseUpdater],
        cache_config: CacheConfig,
        updater_config: CacheUpdaterConfig,
    ) -> BaseUpdater:
        """
        Create and register an updater.

        Args:
            name: Unique name for the updater
            updater_class: Updater class to instantiate
            cache_config: Cache configuration
            updater_config: Updater configuration

        Returns:
            The created updater instance
        """
        updater = updater_class(
            cache_config=cache_config,
            updater_config=updater_config,
            data_manager=self.dataset_registry,
        )
        self.updater_registry.register(name, updater)
        return updater

    def create_sampler(
        self, name: str, sampler_class: Type[BaseSampler], cache_config: CacheConfig
    ) -> BaseSampler:
        """
        Create and register a sampler.

        Args:
            name: Unique name for the sampler
            sampler_class: Sampler class to instantiate
            cache_config: Cache configuration

        Returns:
            The created sampler instance
        """
        sampler = sampler_class(cache_config=cache_config)
        self.sampler_registry.register(name, sampler)
        return sampler

    async def _run_updaters(self, method_name: str) -> None:
        """Invoke the named zero-arg coroutine method on every registered
        updater and await them all concurrently."""
        updaters = self.updater_registry.get_all()
        tasks = [getattr(updater, method_name)() for updater in updaters.values()]
        if tasks:
            await asyncio.gather(*tasks)

    async def initialize_caches(self) -> None:
        """
        Initialize all caches to ensure they have content.
        This is typically called during system startup.
        """
        updaters = self.updater_registry.get_all()
        bt.logging.debug(f"Initializing {len(updaters)} caches: {list(updaters)}")
        await self._run_updaters("initialize_cache")

    async def update_compressed_caches(self) -> None:
        """
        Update all compressed caches in parallel.
        This is typically called from a block callback.
        """
        updaters = self.updater_registry.get_all()
        bt.logging.trace(f"Updating {len(updaters)} compressed caches: {list(updaters)}")
        await self._run_updaters("update_compressed_cache")

    async def update_media_caches(self) -> None:
        """
        Update all media caches in parallel.
        This is typically called from a block callback.
        """
        updaters = self.updater_registry.get_all()
        bt.logging.debug(f"Updating {len(updaters)} media caches: {list(updaters)}")
        await self._run_updaters("update_media_cache")

    async def sample(self, name: str, count: int, **kwargs) -> Optional[Dict[str, Any]]:
        """
        Sample from a specific sampler.

        Args:
            name: Name of the sampler to use
            count: Number of items to sample

        Returns:
            The sampled items or None if sampler not found
        """
        return await self.sampler_registry.sample(name, count, **kwargs)

    async def sample_all(self, count: int = 1) -> Dict[str, Dict[str, Any]]:
        """
        Sample from all samplers.

        Args:
            count: Number of items to sample from each sampler

        Returns:
            Dictionary mapping sampler names to their samples
        """
        return await self.sampler_registry.sample_all(count)

    @property
    def samplers(self):
        """
        Get all registered samplers.

        Returns:
            Dictionary of sampler names to sampler instances
        """
        return self.sampler_registry.get_all()

    @property
    def updaters(self):
        """
        Get all registered updaters.

        Returns:
            Dictionary of updater names to updater instances
        """
        return self.updater_registry.get_all()
245 |
--------------------------------------------------------------------------------
/bitmind/cache/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import initialize_dataset_registry
2 | from .dataset_registry import DatasetRegistry
3 |
--------------------------------------------------------------------------------
/bitmind/cache/datasets/dataset_registry.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from bitmind.types import DatasetConfig, MediaType, Modality
4 |
5 |
class DatasetRegistry:
    """
    Registry for dataset configurations with filtering capabilities.
    """

    def __init__(self):
        # All registered dataset configurations, in registration order.
        self.datasets: List[DatasetConfig] = []

    def register(self, dataset: DatasetConfig) -> None:
        """
        Add a single dataset configuration to the registry.

        Args:
            dataset: Dataset configuration to register
        """
        self.datasets.append(dataset)

    def register_all(self, datasets: List[DatasetConfig]) -> None:
        """
        Add several dataset configurations to the registry.

        Args:
            datasets: List of dataset configurations to register
        """
        for ds in datasets:
            self.register(ds)

    def get_datasets(
        self,
        modality: Optional[Modality] = None,
        media_type: Optional[MediaType] = None,
        tags: Optional[List[str]] = None,
        exclude_tags: Optional[List[str]] = None,
        enabled_only: bool = True,
    ) -> List[DatasetConfig]:
        """
        Return the datasets matching every supplied filter.

        Args:
            modality: Keep only datasets of this modality (str or Modality)
            media_type: Keep only datasets of this media type (str or MediaType)
            tags: Keep only datasets carrying ALL of these tags
            exclude_tags: Drop datasets carrying ANY of these tags
            enabled_only: When True, drop disabled datasets

        Returns:
            List of matching datasets
        """
        matches = self.datasets

        if enabled_only:
            matches = [ds for ds in matches if ds.enabled]

        if modality:
            # Accept either an enum member or its lowercase string value.
            wanted = Modality(modality.lower()) if isinstance(modality, str) else modality
            matches = [ds for ds in matches if ds.type == wanted]

        if media_type:
            wanted_media = (
                MediaType(media_type.lower())
                if isinstance(media_type, str)
                else media_type
            )
            matches = [ds for ds in matches if ds.media_type == wanted_media]

        if tags:
            matches = [ds for ds in matches if all(t in ds.tags for t in tags)]

        if exclude_tags:
            matches = [
                ds for ds in matches if not any(t in ds.tags for t in exclude_tags)
            ]

        return matches

    def enable_dataset(self, path: str, enabled: bool = True) -> bool:
        """
        Enable or disable the dataset registered under the given path.

        Args:
            path: Dataset path to enable/disable
            enabled: Whether to enable or disable

        Returns:
            True if successful, False if dataset not found
        """
        dataset = self.get_dataset_by_path(path)
        if dataset is None:
            return False
        dataset.enabled = enabled
        return True

    def get_dataset_by_path(self, path: str) -> Optional[DatasetConfig]:
        """
        Look up a dataset by its path.

        Args:
            path: Dataset path to find

        Returns:
            Dataset config or None if not found
        """
        return next((ds for ds in self.datasets if ds.path == path), None)
109 |
--------------------------------------------------------------------------------
/bitmind/cache/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset definitions for the validator cache system
3 | """
4 |
5 | from typing import List
6 |
7 | from bitmind.types import Modality, MediaType, DatasetConfig
8 |
9 |
def get_image_datasets() -> List[DatasetConfig]:
    """
    Get the list of image datasets used by the validator.

    Entries are grouped by media type: real photographs, fully
    AI-generated (synthetic) images, and partially AI-modified
    (semisynthetic) images. Tags annotate content domain (e.g. "faces",
    "objects") or provenance (e.g. "midjourney") and support filtered
    selection via DatasetRegistry.get_datasets.

    Returns:
        List of image dataset configurations
    """
    return [
        # Real image datasets
        DatasetConfig(
            path="bitmind/bm-eidon-image",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["frontier"],
        ),
        DatasetConfig(
            path="bitmind/bm-real",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
        ),
        DatasetConfig(
            path="bitmind/open-image-v7-256",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["diverse"],
        ),
        DatasetConfig(
            path="bitmind/celeb-a-hq",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["faces", "high-quality"],
        ),
        DatasetConfig(
            path="bitmind/ffhq-256",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["faces", "high-quality"],
        ),
        DatasetConfig(
            path="bitmind/MS-COCO-unique-256",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["diverse"],
        ),
        DatasetConfig(
            path="bitmind/AFHQ",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["animals", "high-quality"],
        ),
        DatasetConfig(
            path="bitmind/lfw",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["faces"],
        ),
        DatasetConfig(
            path="bitmind/caltech-256",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["objects", "categorized"],
        ),
        DatasetConfig(
            path="bitmind/caltech-101",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["objects", "categorized"],
        ),
        DatasetConfig(
            path="bitmind/dtd",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["textures"],
        ),
        DatasetConfig(
            path="bitmind/idoc-mugshots-images",
            type=Modality.IMAGE,
            media_type=MediaType.REAL,
            tags=["faces"],
        ),
        # Synthetic image datasets (fully AI-generated)
        DatasetConfig(
            path="bitmind/JourneyDB",
            type=Modality.IMAGE,
            media_type=MediaType.SYNTHETIC,
            tags=["midjourney"],
        ),
        DatasetConfig(
            path="bitmind/GenImage_MidJourney",
            type=Modality.IMAGE,
            media_type=MediaType.SYNTHETIC,
            tags=["midjourney"],
        ),
        DatasetConfig(
            path="bitmind/bm-aura-imagegen",
            type=Modality.IMAGE,
            media_type=MediaType.SYNTHETIC,
            tags=["sora"],
        ),
        # Semisynthetic image datasets (real images partially modified by AI)
        DatasetConfig(
            path="bitmind/face-swap",
            type=Modality.IMAGE,
            media_type=MediaType.SEMISYNTHETIC,
            tags=["faces", "manipulated"],
        ),
    ]
117 |
118 |
def get_video_datasets() -> List[DatasetConfig]:
    """
    Get the list of video datasets used by the validator.

    All entries ship as compressed archives (compressed_format="zip");
    the cache updaters download and unpack portions of these archives.
    The "large-zips" tag marks datasets whose archives are notably big.

    Returns:
        List of video dataset configurations
    """
    return [
        # Real video datasets
        DatasetConfig(
            path="bitmind/bm-eidon-video",
            type=Modality.VIDEO,
            media_type=MediaType.REAL,
            tags=["frontier"],
            compressed_format="zip",
        ),
        DatasetConfig(
            path="shangxd/imagenet-vidvrd",
            type=Modality.VIDEO,
            media_type=MediaType.REAL,
            tags=["diverse"],
            compressed_format="zip",
        ),
        DatasetConfig(
            path="nkp37/OpenVid-1M",
            type=Modality.VIDEO,
            media_type=MediaType.REAL,
            tags=["diverse", "large-zips"],
            compressed_format="zip",
        ),
        # Semisynthetic video datasets (real videos partially modified by AI)
        DatasetConfig(
            path="bitmind/semisynthetic-video",
            type=Modality.VIDEO,
            media_type=MediaType.SEMISYNTHETIC,
            tags=["faces"],
            compressed_format="zip",
        ),
    ]
155 |
156 |
def initialize_dataset_registry():
    """
    Build a DatasetRegistry pre-loaded with every image and video dataset.

    Returns:
        Fully populated DatasetRegistry instance
    """
    # Imported here to avoid a module-level import cycle.
    from bitmind.cache.datasets.dataset_registry import DatasetRegistry

    registry = DatasetRegistry()
    for dataset_group in (get_image_datasets(), get_video_datasets()):
        registry.register_all(dataset_group)
    return registry
172 |
--------------------------------------------------------------------------------
/bitmind/cache/sampler/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseSampler
2 | from .image_sampler import ImageSampler
3 | from .video_sampler import VideoSampler
4 | from .sampler_registry import SamplerRegistry
5 |
--------------------------------------------------------------------------------
/bitmind/cache/sampler/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from typing import Any, Dict, List
4 |
5 |
6 | from bitmind.cache.cache_fs import CacheFS
7 | from bitmind.types import CacheConfig
8 |
9 |
class BaseSampler(ABC):
    """
    Base class for samplers that provide access to cached media.

    Subclasses declare the file extensions they handle and implement the
    async ``sample`` method; this base wires up a CacheFS over the given
    cache configuration and provides file-listing helpers.
    """

    def __init__(self, cache_config: CacheConfig):
        # Filesystem view of the cache described by cache_config.
        self.cache_fs = CacheFS(cache_config)

    @property
    @abstractmethod
    def media_file_extensions(self) -> List[str]:
        """List of file extensions supported by this sampler"""
        pass

    @abstractmethod
    async def sample(self, count: int) -> Dict[str, Any]:
        """
        Sample items from the media cache.

        Args:
            count: Number of items to sample

        Returns:
            Dictionary with sampled items information
        """
        pass

    def get_available_files(self, use_index: bool = True) -> List[Path]:
        """
        Get list of available files in the media cache.

        Args:
            use_index: Forwarded to CacheFS.get_files; presumably toggles a
                cached file index vs. a fresh scan — confirm against CacheFS.
        """
        return self.cache_fs.get_files(
            cache_type="media",
            file_extensions=self.media_file_extensions,
            use_index=use_index,
        )

    def get_available_count(self, use_index: bool = True) -> int:
        """Get count of available files in the media cache"""
        return len(self.get_available_files(use_index))
48 |
--------------------------------------------------------------------------------
/bitmind/cache/sampler/image_sampler.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import io
4 | from pathlib import Path
5 | from typing import Dict, List, Any
6 |
7 | import cv2
8 | import numpy as np
9 |
10 | from bitmind.cache.sampler.base import BaseSampler
11 | from bitmind.cache.cache_fs import CacheConfig
12 |
13 |
class ImageSampler(BaseSampler):
    """
    Sampler for cached image data.

    This class provides access to images in the media cache,
    allowing sampling with or without metadata.
    """

    def __init__(self, cache_config: CacheConfig):
        super().__init__(cache_config)

    @property
    def media_file_extensions(self) -> List[str]:
        """List of file extensions supported by this sampler"""
        return [".jpg", ".jpeg", ".png", ".webp"]

    async def sample(
        self,
        count: int = 1,
        remove_from_cache: bool = False,
        as_float32: bool = False,
        channels_first: bool = False,
        as_rgb: bool = True,
    ) -> Dict[str, Any]:
        """
        Sample random images and their metadata from the cache.

        Args:
            count: Number of images to sample
            remove_from_cache: Whether to remove sampled images from cache
            as_float32: Return pixels as float32 in [0, 1] instead of uint8
            channels_first: Return arrays shaped (C, H, W) instead of (H, W, C)
            as_rgb: Return channels in RGB order (default); BGR when False

        Returns:
            Dictionary containing:
                - count: Number of images successfully sampled
                - items: List of dictionaries containing:
                    - image: Image as numpy array; RGB by default (see as_rgb),
                      shape (H, W, C) unless channels_first is set
                    - path: Path to the image file
                    - metadata_path: Path to the sidecar .json metadata file
                    - metadata: Additional metadata (empty dict if none found)
                    - source / index: Only present when recorded in metadata
        """
        # Files keyed by source dataset so sampling is uniform over sources
        # first, then uniform over files within the chosen source.
        cached_files = self.cache_fs.get_files(
            cache_type="media",
            file_extensions=self.media_file_extensions,
            group_by_source=True,
        )

        if not cached_files:
            self.cache_fs._log_warning("No images available in cache")
            return {"count": 0, "items": []}

        sampled_items = []

        # Allow a few failed loads per requested image before giving up.
        attempts = 0
        max_attempts = count * 3

        while len(sampled_items) < count and attempts < max_attempts:
            attempts += 1

            source = random.choice(list(cached_files.keys()))
            if not cached_files[source]:
                # Source exhausted: drop it and stop early if nothing is left.
                del cached_files[source]
                if not cached_files:
                    break
                continue

            image_path = random.choice(cached_files[source])

            try:
                # Read image directly as numpy array using cv2 (loads as BGR uint8)
                image = cv2.imread(str(image_path))
                if image is None:
                    raise ValueError(f"Failed to load image {image_path}")

                if as_float32:  # else np.uint8
                    image = image.astype(np.float32) / 255.0

                if as_rgb:  # else bgr
                    image = image[:, :, [2, 1, 0]]

                if channels_first:  # else channels last
                    image = np.transpose(image, (2, 0, 1))

                # Metadata lives in a sidecar JSON next to the image; a missing
                # or unreadable file is tolerated and yields an empty dict.
                metadata = {}
                metadata_path = image_path.with_suffix(".json")
                if metadata_path.exists():
                    try:
                        with open(metadata_path, "r") as f:
                            metadata = json.load(f)
                    except Exception as e:
                        self.cache_fs._log_warning(
                            f"Error loading metadata for {image_path}: {e}"
                        )

                item = {
                    "image": image,
                    "path": str(image_path),
                    "metadata_path": str(metadata_path),
                    "metadata": metadata,
                }

                if "source_parquet" in metadata:
                    item["source"] = metadata["source_parquet"]

                if "original_index" in metadata:
                    item["index"] = metadata["original_index"]

                sampled_items.append(item)

                if remove_from_cache:
                    # Best-effort deletion; the sampled item is kept even if
                    # removal fails.
                    try:
                        image_path.unlink(missing_ok=True)
                        metadata_path.unlink(missing_ok=True)
                        cached_files[source].remove(image_path)
                    except Exception as e:
                        self.cache_fs._log_warning(
                            f"Failed to remove {image_path}: {e}"
                        )

            except Exception as e:
                # Unreadable file: log, drop it from the candidate pool, retry.
                self.cache_fs._log_warning(f"Failed to load image {image_path}: {e}")
                cached_files[source].remove(image_path)
                continue

        return {"count": len(sampled_items), "items": sampled_items}
138 |
--------------------------------------------------------------------------------
/bitmind/cache/sampler/sampler_registry.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Any
2 |
3 | import bittensor as bt
4 |
5 | from .base import BaseSampler
6 |
7 |
class SamplerRegistry:
    """
    Name-keyed collection of cache samplers with dispatch helpers.
    """

    def __init__(self):
        self._samplers: Dict[str, BaseSampler] = {}

    def register(self, name: str, sampler: BaseSampler) -> None:
        """Add a sampler under `name`, replacing (with a warning) any existing one."""
        if name in self._samplers:
            bt.logging.warning(f"Sampler {name} already registered, will be replaced")
        self._samplers[name] = sampler

    def get(self, name: str) -> Optional[BaseSampler]:
        """Return the sampler registered under `name`, or None when absent."""
        try:
            return self._samplers[name]
        except KeyError:
            return None

    def get_all(self) -> Dict[str, BaseSampler]:
        """Return a shallow copy of the registry contents."""
        return {**self._samplers}

    def deregister(self, name: str) -> None:
        """Remove `name` from the registry; unknown names are ignored."""
        self._samplers.pop(name, None)

    async def sample(self, name: str, count: int, **kwargs) -> Optional[Dict[str, Any]]:
        """
        Sample from a specific sampler.

        Args:
            name: Name of the sampler to use
            count: Number of items to sample

        Returns:
            The sampled items or None if sampler not found
        """
        sampler = self.get(name)
        if sampler:
            return await sampler.sample(count, **kwargs)
        bt.logging.error(f"Sampler {name} not found")
        return None

    async def sample_all(self, count_per_sampler: int = 1) -> Dict[str, Dict[str, Any]]:
        """
        Sample from every registered sampler.

        Args:
            count_per_sampler: Number of items to sample from each sampler

        Returns:
            Dictionary mapping sampler names to their samples
        """
        return {
            name: await sampler.sample(count_per_sampler)
            for name, sampler in self._samplers.items()
        }
63 |
--------------------------------------------------------------------------------
/bitmind/cache/sampler/video_sampler.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import math
4 | import random
5 | import tempfile
6 |
7 | from pathlib import Path
8 | from typing import Dict, List, Any, Optional
9 | from io import BytesIO
10 |
11 | import ffmpeg
12 | import numpy as np
13 | from PIL import Image
14 |
15 | from bitmind.cache.sampler.base import BaseSampler
16 | from bitmind.cache.cache_fs import CacheConfig
17 | from bitmind.cache.util.video import get_video_metadata
18 |
19 |
class VideoSampler(BaseSampler):
    """
    Sampler for cached video data.

    This class provides access to videos in the media cache, sampling
    short segments and returning their frames as numpy arrays.
    """

    def __init__(self, cache_config: CacheConfig):
        super().__init__(cache_config)

    @property
    def media_file_extensions(self) -> List[str]:
        """List of file extensions supported by this sampler"""
        return [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".webm"]

    async def sample(
        self,
        count: int = 1,
        remove_from_cache: bool = False,
        min_duration: float = 1.0,
        max_duration: float = 6.0,
        max_frames: int = 144,
    ) -> Dict[str, Any]:
        """
        Sample random video segments from the cache.

        Args:
            count: Number of videos to sample
            remove_from_cache: Whether to remove sampled videos from cache
            min_duration: Minimum segment duration to extract, in seconds
            max_duration: Maximum segment duration to extract, in seconds
            max_frames: Hard cap on frames extracted per segment

        Returns:
            Dictionary containing:
                - count: Number of videos successfully sampled
                - items: List of per-segment dictionaries as produced by
                  _sample_frames (frames as numpy arrays plus metadata)
        """
        # Files keyed by source dataset; _sample_frames picks uniformly
        # over sources, then over files within the chosen source.
        cached_files = self.cache_fs.get_files(
            cache_type="media",
            file_extensions=self.media_file_extensions,
            group_by_source=True,
        )

        if not cached_files:
            self.cache_fs._log_warning("No videos available in cache")
            return {"count": 0, "items": []}

        sampled_items = []
        for _ in range(count):
            if not cached_files:
                break

            video_result = await self._sample_frames(
                files=cached_files,
                min_duration=min_duration,
                max_duration=max_duration,
                max_frames=max_frames,
                remove_from_cache=remove_from_cache,
            )

            if video_result:
                sampled_items.append(video_result)

        return {"count": len(sampled_items), "items": sampled_items}

    async def _sample_frames(
        self,
        files: Dict[str, List[Path]],
        min_duration: float = 1.0,
        max_duration: float = 6.0,
        max_fps: float = 30.0,
        max_frames: int = 144,
        remove_from_cache: bool = False,
        as_float32: bool = False,
        channels_first: bool = False,
        as_rgb: bool = True,
    ) -> Optional[Dict[str, Any]]:
        """
        Sample a random video segment and return it as a numpy array.

        Mutates `files` in place: unreadable or exhausted entries are removed
        so repeated calls do not retry known-bad files.

        Args:
            files: Dict mapping source names to lists of video file paths
            min_duration: Minimum duration of video segment to extract in seconds
            max_duration: Maximum duration of video segment to extract in seconds
            max_fps: Maximum frame rate to use when sampling frames
            max_frames: Maximum number of frames to extract
            remove_from_cache: Whether to remove the source video from cache
            as_float32: Whether to return frames as float32 (0-1) instead of uint8 (0-255)
            channels_first: Whether to return frames with channels first (TCHW) instead of channels last (THWC)
            as_rgb: Whether to return frames in RGB format (True) or BGR format (False)

        Returns:
            Dictionary containing:
                - video: Video frames as numpy array with shape (T,H,W,C)
                  (or (T,C,H,W) when channels_first)
                - path / metadata_path: Source file locations
                - metadata: Video metadata loaded from the sidecar JSON
                - segment: Information about the extracted segment
            Or None if sampling fails after 5 attempts
        """
        # Up to 5 attempts, each with a freshly chosen candidate file.
        for _ in range(5):
            if not files:
                self.cache_fs._log_warning("No more videos available to try")
                return None

            source = random.choice(list(files.keys()))
            if not files[source]:
                del files[source]
                continue

            video_path = random.choice(files[source])

            try:
                if not video_path.exists():
                    files[source].remove(video_path)
                    continue

                try:
                    video_info = get_video_metadata(str(video_path))
                    total_duration = video_info.get("duration", 0)
                    width = int(video_info.get("width", 256))
                    height = int(video_info.get("height", 256))
                    reported_fps = float(video_info.get("fps", max_fps))
                except Exception as e:
                    self.cache_fs._log_error(
                        f"Unable to extract video metadata from {str(video_path)}: {e}"
                    )
                    files[source].remove(video_path)
                    continue

                # Guard against bogus container metadata (0, inf, or absurd fps).
                if (
                    reported_fps > max_fps
                    or reported_fps <= 0
                    or not math.isfinite(reported_fps)
                ):
                    self.cache_fs._log_warning(
                        f"Unreasonable FPS ({reported_fps}) detected in {video_path}, capping at {max_fps}"
                    )
                    frame_rate = max_fps
                else:
                    frame_rate = reported_fps

                # Pick a segment length, clamped to the video's duration.
                target_duration = random.uniform(min_duration, max_duration)
                target_duration = min(target_duration, total_duration)

                num_frames = int(target_duration * frame_rate) + 1
                num_frames = min(num_frames, max_frames)

                # Duration actually covered after the frame-count cap.
                actual_duration = (num_frames - 1) / frame_rate

                max_start = max(0, total_duration - actual_duration)
                start_time = random.uniform(0, max_start)

                frames = []
                no_data = []  # timestamps where ffmpeg returned no bytes

                for i in range(num_frames):
                    timestamp = start_time + (i / frame_rate)

                    try:
                        # Seek to the timestamp and decode exactly one frame
                        # as PNG bytes on stdout; stderr is captured but unused.
                        out_bytes, err = (
                            ffmpeg.input(str(video_path), ss=str(timestamp))
                            .filter("select", "eq(n,0)")
                            .output(
                                "pipe:",
                                vframes=1,
                                format="image2",
                                vcodec="png",
                                loglevel="error",
                            )
                            .run(capture_stdout=True, capture_stderr=True)
                        )

                        if not out_bytes:
                            no_data.append(timestamp)
                            continue

                        try:
                            frame = Image.open(BytesIO(out_bytes))
                            frame.load()  # Verify image can be loaded
                            frames.append(np.array(frame))
                        except Exception as e:
                            self.cache_fs._log_error(
                                f"Failed to process frame at {timestamp}s: {e}"
                            )
                            continue

                    except ffmpeg.Error as e:
                        self.cache_fs._log_error(
                            f"FFmpeg error at {timestamp}s: {e.stderr.decode()}"
                        )
                        continue

                if len(no_data) > 0:
                    tmin, tmax = min(no_data), max(no_data)
                    self.cache_fs._log_warning(
                        f"No data received for {len(no_data)} frames between {tmin} and {tmax}"
                    )

                if not frames:
                    self.cache_fs._log_warning(
                        f"No frames successfully extracted from {video_path}"
                    )
                    files[source].remove(video_path)
                    continue

                # (T, H, W, C); PIL decodes PNG in RGB order.
                frames = np.stack(frames, axis=0)

                if as_float32:
                    frames = frames.astype(np.float32) / 255.0

                if not as_rgb:
                    frames = frames[:, :, :, [2, 1, 0]]  # RGB to BGR

                if channels_first:
                    frames = np.transpose(frames, (0, 3, 1, 2))

                # Sidecar JSON metadata; missing/unreadable files are tolerated.
                metadata = {}
                metadata_path = video_path.with_suffix(".json")
                if metadata_path.exists():
                    try:
                        with open(metadata_path, "r") as f:
                            metadata = json.load(f)
                    except Exception as e:
                        self.cache_fs._log_warning(
                            f"Error loading metadata for {video_path}: {e}"
                        )

                result = {
                    "video": frames,
                    "path": str(video_path),
                    "metadata_path": str(metadata_path),
                    "metadata": metadata,
                    "segment": {
                        "start_time": start_time,
                        "duration": actual_duration,
                        "fps": frame_rate,
                        "width": width,
                        "height": height,
                        "num_frames": len(frames),
                    },
                }

                if remove_from_cache:
                    # Best-effort deletion; the sampled result is returned
                    # even if removal fails.
                    try:
                        video_path.unlink(missing_ok=True)
                        metadata_path.unlink(missing_ok=True)
                        files[source].remove(video_path)
                    except Exception as e:
                        self.cache_fs._log_warning(
                            f"Failed to remove {video_path}: {e}"
                        )

                self.cache_fs._log_debug(
                    f"Successfully sampled {actual_duration}s segment ({len(frames)} frames)"
                )
                return result

            except Exception as e:
                self.cache_fs._log_error(f"Error sampling from {video_path}: {e}")
                files[source].remove(video_path)

        self.cache_fs._log_error("Failed to sample any video after multiple attempts")
        return None
282 |
--------------------------------------------------------------------------------
/bitmind/cache/updater/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseUpdater
2 | from .image_updater import ImageUpdater
3 | from .video_updater import VideoUpdater
4 | from .updater_registry import UpdaterRegistry
5 |
--------------------------------------------------------------------------------
/bitmind/cache/updater/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from typing import Any, List, Optional
4 |
5 | import numpy as np
6 |
7 | from bitmind.cache.cache_fs import CacheFS
8 | from bitmind.cache.datasets import DatasetRegistry
9 | from bitmind.types import CacheUpdaterConfig, CacheConfig, CacheType
10 | from bitmind.cache.util.download import list_hf_files, download_files
11 | from bitmind.cache.util.filesystem import (
12 | filter_ready_files,
13 | wait_for_downloads_to_complete,
14 | is_source_complete,
15 | )
16 |
17 |
class BaseUpdater(ABC):
    """
    Base class for cache updaters that handle downloading and extracting data.

    This version is designed to work with block callbacks rather than having
    its own internal timing logic. Subclasses define the media/compressed
    file extensions and how to extract items from one compressed source.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        updater_config: CacheUpdaterConfig,
        data_manager: DatasetRegistry,
    ):
        self.cache_fs = CacheFS(cache_config)
        self.updater_config = updater_config
        self.dataset_registry = data_manager
        self._datasets = self._get_filtered_datasets()
        # Compressed files downloaded by the most recent call to
        # update_compressed_cache(); awaited on by initialize_cache().
        self._recently_downloaded_files = []

    def _get_filtered_datasets(
        self,
        modality: Optional[str] = None,
        media_type: Optional[str] = None,
        tags: Optional[List[str]] = None,
        exclude_tags: Optional[List[str]] = None,
    ) -> List[Any]:
        """
        Get datasets that match the cache configuration.

        Args:
            modality: Optional override for the configured modality
            media_type: Optional override for the configured media type
            tags: Optional override for the configured tags
            exclude_tags: Datasets carrying any of these tags are excluded

        Returns:
            List of matching dataset configurations
        """
        modality = self.cache_fs.config.modality if modality is None else modality
        media_type = (
            self.cache_fs.config.media_type if media_type is None else media_type
        )
        tags = self.cache_fs.config.tags if tags is None else tags

        # BUG FIX: pass the (possibly overridden) local values through; the
        # previous code re-read the config and silently ignored all overrides.
        return self.dataset_registry.get_datasets(
            modality=modality,
            media_type=media_type,
            tags=tags,
            exclude_tags=exclude_tags,
        )

    @property
    @abstractmethod
    def media_file_extensions(self) -> List[str]:
        """File extensions of extracted media items (e.g. [".jpg"])."""
        pass

    @property
    @abstractmethod
    def compressed_file_extension(self) -> str:
        """File extension of compressed source files (e.g. ".zip")."""
        pass

    @abstractmethod
    async def _extract_items_from_source(
        self, source_path: Path, count: int
    ) -> List[Path]:
        """Extract up to `count` media items from one compressed source."""
        pass

    async def initialize_cache(self) -> None:
        """
        This performs a one-time initialization to ensure the cache has
        content available, particularly useful during first startup.
        """
        self.cache_fs._log_debug("Setting up cache")

        if self.cache_fs.is_empty(CacheType.MEDIA):
            if self.cache_fs.is_empty(CacheType.COMPRESSED):
                self.cache_fs._log_debug("Compressed cache empty; populating")
                # Keep first startup light: one source from one dataset,
                # skipping datasets known to ship very large archives.
                await self.update_compressed_cache(
                    n_sources_per_dataset=1,
                    n_datasets=1,
                    exclude_tags=["large-zips"],
                    maybe_prune=False,
                )

                self.cache_fs._log_debug(
                    "Waiting for compressed files to finish downloading..."
                )
                await wait_for_downloads_to_complete(
                    self._recently_downloaded_files,
                )
                self._recently_downloaded_files = []

                self.cache_fs._log_debug(
                    "Compressed files downloaded. Updating media cache."
                )
                await self.update_media_cache(maybe_prune=False)
            else:
                self.cache_fs._log_debug(
                    "Compressed sources available; Media cache empty; populating"
                )
                await self.update_media_cache()

    async def update_compressed_cache(
        self,
        n_sources_per_dataset: Optional[int] = None,
        n_datasets: Optional[int] = None,
        exclude_tags: Optional[List[str]] = None,
        maybe_prune: bool = True,
    ) -> None:
        """
        Update the compressed cache by downloading new files.

        Args:
            n_sources_per_dataset: Optional override for number of sources per dataset
            n_datasets: Optional limit on number of datasets to process
            exclude_tags: Datasets carrying any of these tags are skipped
            maybe_prune: Whether to prune the compressed cache first
        """
        if n_sources_per_dataset is None:
            n_sources_per_dataset = self.updater_config.num_sources_per_dataset

        if maybe_prune:
            await self.cache_fs.maybe_prune_cache(
                cache_type=CacheType.COMPRESSED,
                file_extensions=[self.compressed_file_extension],
            )

        # Reset tracking list before new downloads
        self._recently_downloaded_files = []

        datasets = self._get_filtered_datasets(exclude_tags=exclude_tags)
        # BUG FIX: shuffle before applying the limit so n_datasets selects a
        # random subset; the old shuffle-after-slice only reordered the first
        # few datasets, so the same ones were always chosen.
        np.random.shuffle(datasets)
        if n_datasets is not None and n_datasets > 0:
            datasets = datasets[:n_datasets]

        new_files = []
        for dataset in datasets:
            try:
                filenames = self._list_remote_dataset_files(dataset.path)
                if not filenames:
                    self.cache_fs._log_warning(f"No files found for {dataset.path}")
                    continue

                remote_paths = self._get_download_urls(dataset.path, filenames)
                to_download = self._select_files_to_download(
                    remote_paths, n_sources_per_dataset
                )

                # One subdirectory per dataset, named after the repo.
                output_dir = self.cache_fs.compressed_dir / dataset.path.split("/")[-1]

                self.cache_fs._log_debug(
                    f"Downloading {len(to_download)} files from {dataset.path}"
                )
                batch_files = await self._download_files(to_download, output_dir)

                # Track downloaded files
                self._recently_downloaded_files.extend(batch_files)
                new_files.extend(batch_files)
            except Exception as e:
                self.cache_fs._log_error(f"Error downloading from {dataset.path}: {e}")

        if new_files:
            self.cache_fs._log_debug(f"Added {len(new_files)} new compressed files")
        else:
            self.cache_fs._log_warning("No new files were added to compressed cache")

    async def update_media_cache(
        self, n_items_per_source: Optional[int] = None, maybe_prune: bool = True
    ) -> None:
        """
        Update the media cache by extracting from compressed sources.

        Args:
            n_items_per_source: Optional override for number of items per source
            maybe_prune: Whether to prune the media cache first
        """
        if n_items_per_source is None:
            n_items_per_source = self.updater_config.num_items_per_source

        if maybe_prune:
            await self.cache_fs.maybe_prune_cache(
                cache_type=CacheType.MEDIA, file_extensions=self.media_file_extensions
            )

        all_compressed_files = self.cache_fs.get_files(
            cache_type=CacheType.COMPRESSED,
            file_extensions=[self.compressed_file_extension],
            use_index=False,
        )

        if not all_compressed_files:
            self.cache_fs._log_warning("No compressed sources available")
            return

        compressed_files = filter_ready_files(all_compressed_files)

        if not compressed_files:
            self.cache_fs._log_warning(
                "No ready compressed sources available. Files may still be downloading."
            )
            return

        # Drop sources that fail the integrity check and delete them from
        # disk so they are not retried forever.
        valid_compressed_files = []
        for path in compressed_files:
            if not is_source_complete(path):
                try:
                    Path(path).unlink()
                except Exception as del_err:
                    self.cache_fs._log_error(
                        f"Failed to delete corrupted file {path}: {del_err}"
                    )
            else:
                valid_compressed_files.append(path)

        # Cap work per update cycle by sampling at most 10 sources.
        if len(valid_compressed_files) > 10:
            valid_compressed_files = np.random.choice(
                valid_compressed_files, size=10, replace=False
            ).tolist()

        new_files = []
        for source in valid_compressed_files:
            try:
                items = await self._extract_items_from_source(
                    source, n_items_per_source
                )
                new_files.extend(items)
            except Exception as e:
                self.cache_fs._log_error(f"Error extracting from {source}: {e}")

        if new_files:
            self.cache_fs._log_debug(f"Added {len(new_files)} new items to media cache")
        else:
            self.cache_fs._log_warning("No new items were added to media cache")

    def num_media_files(self) -> int:
        """Number of media files currently in the cache."""
        # BUG FIX: previously returned `count == 0` (a bool), contradicting
        # both the method name and the declared int return type.
        return self.cache_fs.num_files(CacheType.MEDIA, self.media_file_extensions)

    def num_compressed_files(self) -> int:
        """Number of compressed source files currently in the cache."""
        # BUG FIX: previously returned `count == 0` (a bool) instead of count.
        return self.cache_fs.num_files(
            CacheType.COMPRESSED, [self.compressed_file_extension]
        )

    def _select_files_to_download(self, urls: List[str], count: int) -> List[str]:
        """Select up to `count` random files to download"""
        return np.random.choice(
            urls, size=min(count, len(urls)), replace=False
        ).tolist()

    def _list_remote_dataset_files(self, dataset_path: str) -> List[str]:
        """List dataset files matching this updater's compressed extension"""
        return list_hf_files(
            repo_id=dataset_path, extension=self.compressed_file_extension
        )

    def _get_download_urls(self, dataset_path: str, filenames: List[str]) -> List[str]:
        """Get Hugging Face download URLs for data files"""
        return [
            f"https://huggingface.co/datasets/{dataset_path}/resolve/main/{f}"
            for f in filenames
        ]

    async def _download_files(self, urls: List[str], output_dir: Path) -> List[Path]:
        """Download a subset of a remote dataset's compressed files"""
        return await download_files(urls, output_dir)
271 |
--------------------------------------------------------------------------------
/bitmind/cache/updater/image_updater.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List
3 | import traceback
4 |
5 | from bitmind.cache.updater import BaseUpdater
6 | from bitmind.cache.datasets import DatasetRegistry
7 | from bitmind.cache.util.filesystem import is_parquet_complete
8 | from bitmind.types import CacheUpdaterConfig, CacheConfig
9 |
10 |
class ImageUpdater(BaseUpdater):
    """
    Updater that sources images from parquet archives.

    Downloads parquet files from Hugging Face datasets and unpacks
    individual images into the media cache.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        updater_config: CacheUpdaterConfig,
        data_manager: DatasetRegistry,
    ):
        super().__init__(
            cache_config=cache_config,
            updater_config=updater_config,
            data_manager=data_manager,
        )

    @property
    def media_file_extensions(self) -> List[str]:
        """Image extensions this updater places in the media cache."""
        return [".jpg", ".jpeg", ".png", ".webp"]

    @property
    def compressed_file_extension(self) -> str:
        """Extension of the compressed source files handled here."""
        return ".parquet"

    async def _extract_items_from_source(
        self, source_path: Path, count: int
    ) -> List[Path]:
        """
        Pull up to `count` images out of one parquet file.

        Args:
            source_path: Path to the parquet file
            count: Number of images to extract

        Returns:
            Paths of the extracted image files; empty list on failure.
        """
        self.cache_fs._log_trace(f"Extracting up to {count} images from {source_path}")

        # Group extracted media under the dataset's directory name,
        # falling back to the file stem when there is no parent name.
        dataset_name = source_path.parent.name or source_path.stem

        dest_dir = self.cache_fs.cache_dir / dataset_name
        dest_dir.mkdir(parents=True, exist_ok=True)

        try:
            # Imported lazily to avoid a circular import at module load.
            from ..util import extract_images_from_parquet

            saved_files = extract_images_from_parquet(
                parquet_path=source_path, dest_dir=dest_dir, num_images=count
            )
            self.cache_fs._log_trace(
                f"Extracted {len(saved_files)} images from {source_path}"
            )
            return [Path(f) for f in saved_files]
        except Exception as e:
            self.cache_fs._log_error(f"Error extracting images from {source_path}: {e}")
            self.cache_fs._log_error(traceback.format_exc())
            return []
79 |
--------------------------------------------------------------------------------
/bitmind/cache/updater/updater_registry.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | import bittensor as bt
4 |
5 | from bitmind.cache.updater import BaseUpdater
6 |
7 |
class UpdaterRegistry:
    """
    Name-keyed store of cache updater instances.
    """

    def __init__(self):
        self._updaters: Dict[str, BaseUpdater] = {}

    def register(self, name: str, updater: BaseUpdater) -> None:
        """Add an updater under `name`, replacing (with a warning) any existing one."""
        if name in self._updaters:
            bt.logging.warning(f"Updater {name} already registered, will be replaced")
        self._updaters[name] = updater

    def get(self, name: str) -> Optional[BaseUpdater]:
        """Return the updater registered under `name`, or None when absent."""
        try:
            return self._updaters[name]
        except KeyError:
            return None

    def get_all(self) -> Dict[str, BaseUpdater]:
        """Return a shallow copy of the registry contents."""
        return {**self._updaters}

    def deregister(self, name: str) -> None:
        """Remove `name` from the registry; unknown names are ignored."""
        self._updaters.pop(name, None)
30 |
--------------------------------------------------------------------------------
/bitmind/cache/updater/video_updater.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 | from pathlib import Path
3 | from typing import List
4 |
5 | from bitmind.types import CacheUpdaterConfig, CacheConfig
6 | from bitmind.cache.updater import BaseUpdater
7 | from bitmind.cache.datasets import DatasetRegistry
8 |
9 |
class VideoUpdater(BaseUpdater):
    """
    Updater for video data from zip files.

    This class handles downloading zip files from Hugging Face datasets
    and extracting videos from them into the media cache.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        updater_config: CacheUpdaterConfig,
        data_manager: DatasetRegistry,
    ):
        super().__init__(
            cache_config=cache_config,
            updater_config=updater_config,
            data_manager=data_manager,
        )

    @property
    def media_file_extensions(self) -> List[str]:
        """List of file extensions supported by this updater"""
        return [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".webm"]

    @property
    def compressed_file_extension(self) -> str:
        """File extension for compressed source files"""
        return ".zip"

    async def _extract_items_from_source(
        self, source_path: Path, count: int
    ) -> List[Path]:
        """
        Extract videos from a zip file.

        Args:
            source_path: Path to the zip file
            count: Number of videos to extract

        Returns:
            List of paths to extracted video files (empty list on failure)
        """
        self.cache_fs._log_trace(f"Extracting up to {count} videos from {source_path}")

        # Group extracted media under the dataset's directory name; fall
        # back to the archive's stem when there is no parent name.
        dataset_name = source_path.parent.name
        if not dataset_name:
            dataset_name = source_path.stem

        dest_dir = self.cache_fs.cache_dir / dataset_name
        dest_dir.mkdir(parents=True, exist_ok=True)

        try:
            # Imported lazily to avoid a circular import at module load.
            from ..util import extract_videos_from_zip

            extracted_pairs = extract_videos_from_zip(
                zip_path=source_path,
                dest_dir=dest_dir,
                num_videos=count,
                file_extensions=set(self.media_file_extensions),
            )

            # extract_videos_from_zip returns pairs of (video_path, metadata_path)
            # We just need the video paths for our return value
            video_paths = [Path(pair[0]) for pair in extracted_pairs]

            self.cache_fs._log_trace(
                f"Extracted {len(video_paths)} videos from {source_path}"
            )
            return video_paths

        except Exception as e:
            # BUG FIX: extraction failures were logged at trace level, making
            # them easy to miss; log at error level, consistent with
            # ImageUpdater's error handling.
            self.cache_fs._log_error(f"Error extracting videos from {source_path}: {e}")
            return []
84 |
--------------------------------------------------------------------------------
/bitmind/cache/util/__init__.py:
--------------------------------------------------------------------------------
1 | from bitmind.cache.util.filesystem import (
2 | is_source_complete,
3 | is_zip_complete,
4 | is_parquet_complete,
5 | get_most_recent_update_time,
6 | )
7 |
8 | from bitmind.cache.util.download import (
9 | download_files,
10 | list_hf_files,
11 | openvid1m_err_handler,
12 | )
13 |
14 | from bitmind.cache.util.video import (
15 | get_video_duration,
16 | get_video_metadata,
17 | seconds_to_str,
18 | )
19 |
20 | from bitmind.cache.util.extract import (
21 | extract_videos_from_zip,
22 | extract_images_from_parquet,
23 | )
24 |
25 | __all__ = [
26 | # Filesystem
27 | "is_source_complete",
28 | "is_zip_complete",
29 | "is_parquet_complete",
30 | "get_most_recent_update_time",
31 | # Download
32 | "download_files",
33 | "list_hf_files",
34 | "openvid1m_err_handler",
35 | # Video
36 | "get_video_duration",
37 | "get_video_metadata",
38 | "seconds_to_str",
39 | # Extraction
40 | "extract_videos_from_zip",
41 | "extract_images_from_parquet",
42 | ]
43 |
--------------------------------------------------------------------------------
/bitmind/cache/util/download.py:
--------------------------------------------------------------------------------
import asyncio
import os
import shutil
import traceback
from pathlib import Path
from typing import List, Optional, Union

import aiohttp
import bittensor as bt
import huggingface_hub as hf_hub
import requests
from requests.exceptions import RequestException
11 |
12 |
def list_hf_files(repo_id, repo_type="dataset", extension=None):
    """List files from a Hugging Face repository.

    Args:
        repo_id: Repository ID
        repo_type: Type of repository ('dataset', 'model', etc.)
        extension: Filter files by extension

    Returns:
        List of files in the repository; empty list on error.
    """
    try:
        repo_files = hf_hub.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        if extension:
            return [name for name in repo_files if name.endswith(extension)]
        return list(repo_files)
    except Exception as e:
        bt.logging.error(f"Failed to list files of type {extension} in {repo_id}: {e}")
        return []
32 |
33 |
async def download_files(
    urls: List[str], output_dir: Union[str, Path], chunk_size: int = 8192
) -> List[Path]:
    """Download multiple files asynchronously.

    Args:
        urls: List of URLs to download
        output_dir: Directory to save the files
        chunk_size: Size of chunks to download at a time

    Returns:
        List of successfully downloaded file paths
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # A single long-lived session with a generous total timeout is shared
    # across all downloads.
    timeout = aiohttp.ClientTimeout(total=3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        tasks = [
            download_single_file(session, url, output_dir, chunk_size)
            for url in urls
        ]
        # return_exceptions=True so one failed download does not cancel
        # the others; failures surface as exception objects.
        results = await asyncio.gather(*tasks, return_exceptions=True)

    # Keep only real paths (successful downloads).
    return [item for item in results if isinstance(item, Path)]
67 |
68 |
async def download_single_file(
    session: aiohttp.ClientSession, url: str, output_dir: Path, chunk_size: int
) -> Path:
    """Download a single file asynchronously.

    Args:
        session: aiohttp ClientSession to use for requests
        url: URL to download
        output_dir: Directory to save the file
        chunk_size: Size of chunks to download at a time

    Returns:
        Path to the downloaded file

    Raises:
        Exception: On a non-200 response, or any network/filesystem error
            (the error is logged with a traceback before being re-raised).
    """
    try:
        bt.logging.info(f"Downloading {url}")

        async with session.get(url) as response:
            if response.status != 200:
                bt.logging.error(f"Failed to download {url}: Status {response.status}")
                raise Exception(f"HTTP error {response.status}")

            # The saved filename is the last path component of the URL.
            filename = os.path.basename(url)
            filepath = output_dir / filename

            bt.logging.info(f"Writing to {filepath}")

            # NOTE(review): this is synchronous (blocking) file I/O inside an
            # async function — each chunk write briefly blocks the event loop.
            with open(filepath, "wb") as f:
                # Download and write in chunks
                async for chunk in response.content.iter_chunked(chunk_size):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)

            return filepath

    except Exception as e:
        bt.logging.error(f"Error downloading {url}: {str(e)}")
        bt.logging.error(traceback.format_exc())
        raise
109 |
110 |
def openvid1m_err_handler(
    base_zip_url: str,
    output_path: Path,
    part_index: int,
    chunk_size: int = 8192,
    timeout: int = 300,
) -> Optional[Path]:
    """Synchronous error handler for OpenVid1M downloads that handles split files.

    Some OpenVid1M archives are published as two split parts
    (``..._partaa`` / ``..._partab``); this downloads both parts and
    concatenates them into a single zip.

    Args:
        base_zip_url: Base URL for the zip parts
        output_path: Directory to save files
        part_index: Index of the part to download
        chunk_size: Size of download chunks
        timeout: Download timeout in seconds

    Returns:
        Path to combined file if successful, None otherwise
    """
    part_urls = [
        f"{base_zip_url}{part_index}_partaa",
        f"{base_zip_url}{part_index}_partab",
    ]
    error_log_path = output_path / "download_log.txt"
    downloaded_parts = []

    # Download each part
    for part_url in part_urls:
        part_file_path = output_path / Path(part_url).name

        # Reuse a part that was already downloaded on a previous attempt.
        if part_file_path.exists():
            bt.logging.warning(f"File {part_file_path} exists.")
            downloaded_parts.append(part_file_path)
            continue

        try:
            # Bug fix: `requests` was used here but never imported at module
            # level (only requests.exceptions was), raising NameError.
            response = requests.get(part_url, stream=True, timeout=timeout)
            if response.status_code != 200:
                raise RequestException(
                    f"HTTP {response.status_code}: {response.reason}"
                )

            with open(part_file_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)

            bt.logging.info(f"File {part_url} saved to {part_file_path}")
            downloaded_parts.append(part_file_path)

        except Exception as e:
            error_message = f"File {part_url} download failed: {str(e)}\n"
            bt.logging.error(error_message)
            with open(error_log_path, "a") as error_log_file:
                error_log_file.write(error_message)
            return None

    if len(downloaded_parts) == len(part_urls):
        try:
            combined_file = output_path / f"OpenVid_part{part_index}.zip"
            # Stream each part into the combined file instead of buffering
            # both parts in memory (the parts can be very large).
            with open(combined_file, "wb") as out_file:
                for part_path in downloaded_parts:
                    with open(part_path, "rb") as part_file:
                        shutil.copyfileobj(part_file, out_file)

            for part_path in downloaded_parts:
                part_path.unlink()

            bt.logging.info(f"Successfully combined parts into {combined_file}")
            return combined_file

        except Exception as e:
            error_message = (
                f"Failed to combine parts for index {part_index}: {str(e)}\n"
            )
            bt.logging.error(error_message)
            with open(error_log_path, "a") as error_log_file:
                error_log_file.write(error_message)
            return None

    return None
195 |
--------------------------------------------------------------------------------
/bitmind/cache/util/extract.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import hashlib
3 | import json
4 | import os
5 | import random
6 | import shutil
7 | from datetime import datetime
8 | from io import BytesIO
9 | from pathlib import Path
10 | from typing import List, Optional, Set, Tuple
11 | from zipfile import ZipFile
12 |
13 | import bittensor as bt
14 | import pyarrow.parquet as pq
15 | from PIL import Image
16 |
17 |
def extract_videos_from_zip(
    zip_path: Path,
    dest_dir: Path,
    num_videos: int,
    file_extensions: Set[str] = {".mp4", ".avi", ".mov", ".mkv", ".wmv"},
    include_checksums: bool = True,
) -> List[Tuple[str, str]]:
    """Extract random videos and their metadata from a zip file and save them to disk.

    For each selected video a sibling ``<stem>.json`` metadata file is
    written containing provenance, zip-entry details and (optionally)
    md5/sha256 checksums.

    Args:
        zip_path: Path to the zip file
        dest_dir: Directory to save videos and metadata
        num_videos: Number of videos to extract
        file_extensions: Set of valid video file extensions
        include_checksums: Whether to calculate and include file checksums in metadata

    Returns:
        List of tuples containing (video_path, metadata_path)
    """
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)

    extracted_files = []
    try:
        with ZipFile(zip_path) as zip_file:
            # Candidate videos: matching extension, skipping macOS resource
            # fork entries ("__MACOSX/...").
            video_files = [
                f
                for f in zip_file.namelist()
                if any(f.lower().endswith(ext) for ext in file_extensions)
                and "MACOSX" not in f
            ]
            if not video_files:
                bt.logging.warning(f"No video files found in {zip_path}")
                return extracted_files

            bt.logging.debug(f"{len(video_files)} video files found in {zip_path}")
            selected_videos = random.sample(
                video_files, min(num_videos, len(video_files))
            )

            bt.logging.debug(
                f"Extracting {len(selected_videos)} randomly sampled video files from {zip_path}"
            )
            for video in selected_videos:
                try:
                    # Stream the archive member to disk without loading it
                    # fully into memory.
                    video_path = dest_dir / Path(video).name
                    with zip_file.open(video) as source:
                        with open(video_path, "wb") as target:
                            shutil.copyfileobj(source, target)

                    video_info = zip_file.getinfo(video)
                    metadata = {
                        "dataset": Path(zip_path).parent.name,
                        "source_zip": str(zip_path),
                        "path_in_zip": video,
                        "extraction_date": datetime.now().isoformat(),
                        "file_size": os.path.getsize(video_path),
                        "zip_metadata": {
                            "compress_size": video_info.compress_size,
                            "file_size": video_info.file_size,
                            "compress_type": video_info.compress_type,
                            "date_time": datetime.strftime(
                                datetime(*video_info.date_time), "%Y-%m-%d %H:%M:%S"
                            ),
                        },
                    }

                    if include_checksums:
                        # Improvement: hash in fixed-size chunks instead of
                        # reading the whole (potentially large) video into
                        # memory at once; digests are identical.
                        md5 = hashlib.md5()
                        sha256 = hashlib.sha256()
                        with open(video_path, "rb") as f:
                            for block in iter(lambda: f.read(1024 * 1024), b""):
                                md5.update(block)
                                sha256.update(block)
                        metadata["checksums"] = {
                            "md5": md5.hexdigest(),
                            "sha256": sha256.hexdigest(),
                        }

                    metadata_filename = f"{video_path.stem}.json"
                    metadata_path = dest_dir / metadata_filename

                    with open(metadata_path, "w", encoding="utf-8") as f:
                        json.dump(metadata, f, indent=2, ensure_ascii=False)

                    extracted_files.append((str(video_path), str(metadata_path)))

                except Exception as e:
                    bt.logging.warning(f"Error extracting {video}: {e}")
                    continue

    except Exception as e:
        bt.logging.warning(f"Error processing zip file {zip_path}: {e}")

    return extracted_files
110 |
111 |
def extract_images_from_parquet(
    parquet_path: Path, dest_dir: Path, num_images: int, seed: Optional[int] = None
) -> List[str]:
    """Extract random images and their metadata from a parquet file and save them to disk.

    For each sampled row the image payload is decoded (directly, or after a
    base64 fallback), written as an image file, and a sibling ``.json``
    metadata file is written with the remaining row columns.

    Args:
        parquet_path: Path to the parquet file
        dest_dir: Directory to save images and metadata
        num_images: Number of images to extract
        seed: Random seed for sampling

    Returns:
        List of image file paths
    """
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)

    # read parquet file, sample random image rows
    table = pq.read_table(parquet_path)
    df = table.to_pandas()
    sample_df = df.sample(n=min(num_images, len(df)), random_state=seed)
    # Heuristic: the image column is the first whose name contains "image".
    # NOTE(review): if no such column exists, image_col is None; every row
    # lookup below then fails, is caught per-row, and an empty list results.
    image_col = next((col for col in sample_df.columns if "image" in col.lower()), None)
    metadata_cols = [c for c in sample_df.columns if c != image_col]

    saved_files = []
    parquet_prefix = parquet_path.stem
    for idx, row in sample_df.iterrows():
        try:
            img_data = row[image_col]
            if isinstance(img_data, dict):
                # Some datasets store images as dicts (e.g. {"bytes": ...});
                # pick the first key that looks like it holds the payload.
                key = next(
                    (
                        k
                        for k in img_data
                        if "bytes" in k.lower() or "image" in k.lower()
                    ),
                    None,
                )
                img_data = img_data[key]

            # Try decoding raw bytes first; fall back to base64-encoded data.
            try:
                img = Image.open(BytesIO(img_data))
            except Exception:
                img_data = base64.b64decode(img_data)
                img = Image.open(BytesIO(img_data))

            base_filename = f"{parquet_prefix}__image_{idx}"
            # Preserve the decoded format when known, otherwise default to png.
            image_format = img.format.lower() if img.format else "png"
            img_filename = f"{base_filename}.{image_format}"
            img_path = dest_dir / img_filename
            img.save(img_path)

            metadata = {
                "dataset": Path(parquet_path).parent.name,
                "source_parquet": str(parquet_path),
                "original_index": str(idx),
                "image_format": image_format,
                "image_size": img.size,
                "image_mode": img.mode,
            }

            for col in metadata_cols:
                # Convert any non-serializable types to strings
                try:
                    json.dumps({col: row[col]})
                    metadata[col] = row[col]
                except (TypeError, OverflowError):
                    metadata[col] = str(row[col])

            metadata_filename = f"{base_filename}.json"
            metadata_path = dest_dir / metadata_filename
            with open(metadata_path, "w", encoding="utf-8") as f:
                json.dump(metadata, f, indent=2, ensure_ascii=False)

            saved_files.append(str(img_path))

        except Exception as e:
            bt.logging.warning(f"Failed to extract/save image {idx}: {e}")
            continue

    return saved_files
193 |
--------------------------------------------------------------------------------
/bitmind/cache/util/filesystem.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Callable, Dict, List, Optional, Tuple, Union, Any
3 | import bittensor as bt
4 | import pyarrow.parquet as pq
5 | from zipfile import ZipFile, BadZipFile
6 | import asyncio
7 | import sys
8 | import time
9 |
10 | from bitmind.types import FileType
11 |
12 |
def get_most_recent_update_time(directory: Path) -> float:
    """Return the newest modification time (epoch seconds) among the direct
    children of `directory`; 0 for an empty directory or on error."""
    try:
        newest = 0
        for entry in directory.iterdir():
            newest = max(newest, entry.stat().st_mtime)
        return newest
    except Exception as e:
        bt.logging.error(f"Error getting modification times: {e}")
        return 0
21 |
22 |
def is_source_complete(path: Union[str, Path]) -> Optional[bool]:
    """Check the integrity of a compressed data source file.

    Args:
        path: Path to a ``.parquet`` or ``.zip`` file.

    Returns:
        True if the file is readable/complete, False if it is corrupt,
        or None when the extension is not a supported source type.

    Note:
        The original return annotation (``Callable[[Path], bool]``) was
        incorrect — the function returns a bool (or None), not a callable.
    """
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix == ".parquet":
        return is_parquet_complete(path)
    if suffix == ".zip":
        return is_zip_complete(path)
    # Unknown source type: integrity cannot be judged here.
    return None
33 |
34 |
def is_zip_complete(zip_path: Union[str, Path], testzip=False) -> bool:
    """Check whether a zip archive is structurally complete.

    Args:
        zip_path: Path to the zip file.
        testzip: When True, CRC-check every member (slower but thorough);
            when False, only verify that the central directory is readable.

    Returns:
        True if the archive passes the chosen check, False otherwise.
    """
    try:
        with ZipFile(zip_path) as zf:
            if testzip:
                # Bug fix: ZipFile.testzip() does not raise on CRC errors —
                # it returns the name of the first corrupt member (or None).
                # The previous code ignored the return value, so corrupt
                # archives passed the "thorough" check.
                first_bad = zf.testzip()
                if first_bad is not None:
                    bt.logging.error(
                        f"Zip file {zip_path} is invalid: bad file {first_bad}"
                    )
                    return False
            else:
                zf.namelist()
            return True
    # BadZipFile is an Exception subclass; a single broad handler suffices.
    except Exception as e:
        bt.logging.error(f"Zip file {zip_path} is invalid: {e}")
        return False
46 |
47 |
def is_parquet_complete(path: Path) -> bool:
    """Return True when the parquet metadata footer at `path` can be read.

    A successful metadata read is used here as a cheap completeness
    check; failures are logged and reported as False.
    """
    try:
        with open(path, "rb") as handle:
            pq.read_metadata(handle)
    except Exception as e:
        bt.logging.error(f"Parquet file {path} is incomplete or corrupted: {e}")
        return False
    return True
56 |
57 |
def get_dir_size(
    path: Union[str, Path], exclude_dirs: Optional[List[str]] = None
) -> Tuple[int, int]:
    """Recursively compute (total_bytes, file_count) under `path`.

    Directories whose *name* appears in `exclude_dirs` are skipped
    entirely. Files that vanish or are unreadable are silently ignored;
    an unreadable directory is reported to stderr and contributes zero.
    """
    excluded = exclude_dirs if exclude_dirs is not None else []
    size_total = 0
    files_total = 0

    try:
        for entry in Path(path).iterdir():
            if entry.is_dir():
                if entry.name in excluded:
                    continue
                sub_size, sub_files = get_dir_size(entry, excluded)
                size_total += sub_size
                files_total += sub_files
            elif entry.is_file():
                try:
                    size_total += entry.stat().st_size
                    files_total += 1
                except (OSError, PermissionError):
                    # File disappeared or is unreadable; skip it.
                    pass
    except (PermissionError, OSError) as e:
        print(f"Error accessing {path}: {e}", file=sys.stderr)

    return size_total, files_total
86 |
87 |
def scale_size(size: float, from_unit: str = "B", to_unit: str = "GB") -> float:
    """Convert `size` between byte units (powers of 1024).

    Args:
        size: Quantity expressed in `from_unit`.
        from_unit: Source unit (B/KB/MB/GB/TB/PB, case-insensitive).
        to_unit: Target unit.

    Raises:
        ValueError: if either unit is not recognized.
    """
    if size == 0:
        return 0.0

    units = ["B", "KB", "MB", "GB", "TB", "PB"]
    src, dst = from_unit.upper(), to_unit.upper()
    if src not in units or dst not in units:
        raise ValueError(f"Units must be one of: {', '.join(units)}")

    steps = units.index(src) - units.index(dst)
    if steps > 0:
        # Converting to a smaller unit: multiply.
        return size * (1024**steps)
    if steps < 0:
        # Converting to a larger unit: divide.
        return size / (1024 ** (-steps))
    return size
106 |
107 |
def format_size(
    size: float, from_unit: str = "B", to_unit: Optional[str] = None
) -> str:
    """Render a size as a human-readable string with two decimals.

    With `to_unit` given, converts to that unit; otherwise auto-selects
    the largest unit that keeps the value below 1024.
    """
    if size == 0:
        return "0 B"

    units = ["B", "KB", "MB", "GB", "TB", "PB"]
    src = from_unit.upper()
    if src not in units:
        raise ValueError(f"From unit must be one of: {', '.join(units)}")

    if to_unit is not None:
        dst = to_unit.upper()
        if dst not in units:
            raise ValueError(f"To unit must be one of: {', '.join(units)}")
        return f"{scale_size(size, src, dst):.2f} {dst}"

    # Auto-scale: normalize to bytes, then walk up the unit ladder until
    # the value drops below 1024.
    value = scale_size(size, src, "B")
    idx = 0
    while value >= 1024 and idx < len(units) - 1:
        value /= 1024
        idx += 1
    return f"{value:.2f} {units[idx]}"
135 |
136 |
def analyze_directory(
    root_path: Union[str, Path],
    exclude_dirs: Optional[List[str]] = None,
    min_file_count: int = 1,
    log_func=None,
) -> Dict[str, Any]:
    """Recursively summarize a directory tree for reporting.

    Builds a dict with the directory's name/path, total size and file
    count (totals exclude `exclude_dirs`), a list of analyzed
    subdirectories, and a separate list for the excluded directories
    themselves (each analyzed without exclusions and tagged
    ``"excluded": True``). Subtrees with fewer than `min_file_count`
    files are omitted from both lists.

    Args:
        root_path: Directory to analyze.
        exclude_dirs: Directory names excluded from the size/count totals
            but still reported under ``"excluded_dirs"``.
        min_file_count: Minimum file count for a subtree to be included.
        log_func: Optional callable for error messages; defaults to stderr.

    Returns:
        Nested dict describing the tree (see above).
    """
    if exclude_dirs is None:
        exclude_dirs = []

    path_obj = Path(root_path)
    result = {
        "name": path_obj.name or str(path_obj),
        "path": str(path_obj),
        "subdirs": [],
        "excluded_dirs": [],
    }

    # Totals for this level skip the excluded directory names.
    size, count = get_dir_size(path_obj, exclude_dirs)
    result["size"] = size
    result["count"] = count

    try:
        subdirs = [d for d in path_obj.iterdir() if d.is_dir()]

        for subdir in sorted(subdirs):
            if subdir.name in exclude_dirs:
                # Excluded dirs are analyzed with no exclusions so their
                # own totals are complete.
                _, excluded_count = get_dir_size(subdir, [])
                if excluded_count < min_file_count:
                    continue

                excluded_data = analyze_directory(subdir, [], min_file_count, log_func)
                excluded_data["excluded"] = True
                result["excluded_dirs"].append(excluded_data)
            else:
                subdir_data = analyze_directory(
                    subdir, exclude_dirs, min_file_count, log_func
                )
                if subdir_data["count"] < min_file_count:
                    continue

                result["subdirs"].append(subdir_data)
    except (PermissionError, OSError) as e:
        error_msg = f"Error accessing {path_obj}: {e}"
        if log_func:
            log_func(error_msg)
        else:
            print(error_msg, file=sys.stderr)

    return result
186 |
187 |
def print_directory_tree(
    tree_data: Dict[str, Any],
    indent: str = "",
    is_last: bool = True,
    prefix: str = "",
    log_func=None,
) -> None:
    """Render the nested dict produced by ``analyze_directory`` as a tree.

    Args:
        tree_data: Node dict with name/count/size/subdirs/excluded_dirs.
        indent: Accumulated indentation for this depth.
        is_last: Whether this node is the last sibling (controls the glyph).
        prefix: Extra label prepended to the node name (e.g. "(SOURCE) ").
        log_func: Optional callable for output lines; defaults to print().
    """
    # Skip completely empty nodes.
    if (
        tree_data["count"] == 0
        and not tree_data["subdirs"]
        and not tree_data["excluded_dirs"]
    ):
        return

    if is_last:
        branch = "└── "
        next_indent = indent + "    "
    else:
        branch = "├── "
        next_indent = indent + "│   "

    name = tree_data["name"]
    count = tree_data["count"]
    # Bug fix: "size" is a byte count (summed st_size); format_size renders
    # it with a unit (e.g. "1.23 MB"). The previous scale_size call printed
    # a bare unlabeled number silently converted to GB.
    size = format_size(tree_data["size"])

    tree_line = f"{indent}{prefix}{branch}[{name}] - {count} files, {size}"
    if log_func:
        log_func(tree_line)
    else:
        print(tree_line)

    num_subdirs = len(tree_data["subdirs"])

    for i, subdir in enumerate(tree_data["subdirs"]):
        # The last regular subdir is only "last" if no excluded dirs follow.
        is_subdir_last = (i == num_subdirs - 1) and not tree_data["excluded_dirs"]
        print_directory_tree(subdir, next_indent, is_subdir_last, "", log_func)

    for i, excluded in enumerate(tree_data["excluded_dirs"]):
        is_excluded_last = i == len(tree_data["excluded_dirs"]) - 1
        print_directory_tree(
            excluded, next_indent, is_excluded_last, "(SOURCE) ", log_func
        )
230 |
231 |
def is_file_older_than(file_path: Union[str, Path], seconds: float = 1.0) -> bool:
    """Return True when `file_path` was last modified at least `seconds` ago.

    Missing or unreadable files yield False.
    """
    try:
        age = time.time() - Path(file_path).stat().st_mtime
    except (FileNotFoundError, PermissionError):
        return False
    return age >= seconds
239 |
240 |
def has_stable_size(file_path: Union[str, Path], wait_time: float = 0.1) -> bool:
    """Return True when the file's size does not change across `wait_time`.

    Used as a heuristic for "download finished". Missing or unreadable
    files yield False.
    """
    target = Path(file_path)
    try:
        before = target.stat().st_size
        time.sleep(wait_time)
        return before == target.stat().st_size
    except (FileNotFoundError, PermissionError):
        return False
251 |
252 |
def is_file_locked(file_path: Union[str, Path]) -> bool:
    """Return True when `file_path` cannot be opened for read/write.

    Opening with "rb+" fails when the file is missing or inaccessible
    (and, on platforms with mandatory locking, when another process
    holds it exclusively — presumably the intent here).
    """
    try:
        open(file_path, "rb+").close()
    except (PermissionError, OSError):
        return True
    return False
261 |
262 |
def is_file_ready(
    file_path: Union[str, Path],
    min_age_seconds: float = 1.0,
    check_size_stability: bool = False,
    check_file_lock: bool = True,
    stability_wait_time: float = 0.1,
) -> bool:
    """
    Determine if a file is ready for processing (not being downloaded/written to).

    Args:
        file_path: Path to the file to check
        min_age_seconds: Minimum age in seconds since last modification
        check_size_stability: Whether to check if file size is stable
        check_file_lock: Whether to check if file is locked by another process
        stability_wait_time: Time to wait when checking size stability

    Returns:
        bool: True if the file appears ready for processing
    """
    path = Path(file_path)

    # Guard clauses: each failed check means "not ready yet".
    if not (path.exists() and path.is_file()):
        return False
    if not is_file_older_than(path, min_age_seconds):
        return False
    if check_size_stability and not has_stable_size(path, stability_wait_time):
        return False
    if check_file_lock and is_file_locked(path):
        return False
    return True
298 |
299 |
def filter_ready_files(
    file_list: List[Union[str, Path]], **kwargs
) -> List[Union[str, Path]]:
    """
    Filter a list of files to only include those that are ready for processing.

    Args:
        file_list: List of file paths
        **kwargs: Additional arguments to pass to is_file_ready()

    Returns:
        list: Filtered list containing only ready files
    """
    ready = []
    for candidate in file_list:
        if is_file_ready(candidate, **kwargs):
            ready.append(candidate)
    return ready
314 |
315 |
async def wait_for_downloads_to_complete(
    files: List[Path], min_age_seconds: float = 2.0, timeout_seconds: int = 180
) -> bool:
    """Poll until every file in `files` looks ready, or time out.

    Returns True when all files pass the readiness check within
    `timeout_seconds`; False (with an error log) otherwise. An empty
    list is trivially ready.
    """
    if not files:
        return True

    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        pending = [
            f for f in files if not is_file_ready(f, min_age_seconds=min_age_seconds)
        ]
        if not pending:
            return True
        # yield to event loop between polls
        await asyncio.sleep(5)

    bt.logging.error(f"Timeout waiting for {files} after {timeout_seconds} seconds")
    return False
335 |
--------------------------------------------------------------------------------
/bitmind/cache/util/video.py:
--------------------------------------------------------------------------------
1 | import json
2 | import subprocess
3 | import math
4 | import ffmpeg
5 | from typing import Dict, Any, Optional, Union, Tuple
6 |
7 |
def get_video_duration(video_path: str) -> float:
    """Get the duration of a video file in seconds.

    Tries ``ffmpeg.probe`` first and falls back to invoking the
    ``ffprobe`` binary directly when that fails.

    Args:
        video_path: Path to the video file

    Returns:
        Duration in seconds

    Raises:
        Exception: If the duration cannot be determined by either method
    """
    try:
        probe_info = ffmpeg.probe(video_path)
        return float(probe_info["format"]["duration"])
    except Exception as e:
        # Fallback: shell out to ffprobe and parse its JSON output.
        try:
            completed = subprocess.run(
                [
                    "ffprobe",
                    "-v",
                    "error",
                    "-show_entries",
                    "format=duration",
                    "-of",
                    "json",
                    video_path,
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
            payload = json.loads(completed.stdout)
            return float(payload["format"]["duration"])
        except Exception as sub_e:
            raise Exception(f"Failed to get video duration: {e}, {sub_e}")
46 |
47 |
def get_video_metadata(video_path: str, max_fps: float = 30.0) -> Dict[str, Any]:
    """Get comprehensive metadata from a video file with sanity checks.

    Probes the file with ffprobe and assembles format-, video- and
    audio-level fields; frame rates outside (0, max_fps] are replaced
    with a 30.0 default and flagged via ``fps_corrected``.

    Args:
        video_path: Path to the video file
        max_fps: Maximum reasonable FPS value (default: 30.0)

    Returns:
        Dictionary containing metadata with sanity-checked values; on
        failure, a minimal dict with an "error" key (see
        ``_create_error_metadata``).
    """
    try:
        ffprobe_fields = (
            "format=duration,size,bit_rate,format_name:"
            "stream=width,height,codec_name,codec_type,"
            "r_frame_rate,avg_frame_rate,pix_fmt,sample_rate,channels"
        )
        result = subprocess.run(
            [
                "ffprobe",
                "-v",
                "error",
                "-show_entries",
                ffprobe_fields,
                "-of",
                "json",
                video_path,
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,  # This will raise CalledProcessError if ffprobe fails
        )

        data = json.loads(result.stdout)

        # Extract basic format information
        format_info = data.get("format", {})
        streams = data.get("streams", [])

        # Find video and audio streams
        video_stream = next(
            (s for s in streams if s.get("codec_type") == "video"), None
        )
        audio_stream = next(
            (s for s in streams if s.get("codec_type") == "audio"), None
        )

        # Build base metadata
        metadata = {
            "duration": float(format_info.get("duration", 0)),
            "size_bytes": int(format_info.get("size", 0)),
            "bit_rate": (
                int(format_info.get("bit_rate", 0))
                if "bit_rate" in format_info
                else None
            ),
            "format": format_info.get("format_name"),
            "has_video": video_stream is not None,
            "has_audio": audio_stream is not None,
        }

        # Add video stream details if present
        if video_stream:
            # fps is sanitized against max_fps; implausible values are
            # replaced and the raw value is preserved in original_fps.
            fps, fps_corrected, original_fps = _get_sanitized_fps(video_stream, max_fps)

            metadata.update(
                {
                    "fps": fps,
                    "width": int(video_stream.get("width", 0)),
                    "height": int(video_stream.get("height", 0)),
                    "codec": video_stream.get("codec_name"),
                    "pix_fmt": video_stream.get("pix_fmt"),
                }
            )

            if fps_corrected:
                metadata["original_fps"] = original_fps
                metadata["fps_corrected"] = True

        # Add audio stream details if present
        if audio_stream:
            metadata.update(
                {
                    "audio_codec": audio_stream.get("codec_name"),
                    "sample_rate": audio_stream.get("sample_rate"),
                    "channels": int(audio_stream.get("channels", 0)),
                }
            )

        return metadata

    except subprocess.CalledProcessError as e:
        return _create_error_metadata(f"ffprobe process failed: {e.stderr.strip()}")
    except json.JSONDecodeError:
        return _create_error_metadata("Failed to parse ffprobe output as JSON")
    except Exception as e:
        return _create_error_metadata(f"Unexpected error: {str(e)}")
145 |
146 |
def _get_sanitized_fps(
    video_stream: Dict[str, Any], max_fps: float = 60.0
) -> Tuple[float, bool, Optional[float]]:
    """Parse and sanitize frame rate from video stream information.

    Returns:
        Tuple of (sanitized_fps, was_corrected, original_fps_if_corrected)
    """
    # Prefer r_frame_rate, falling back to avg_frame_rate.
    fps = _parse_frame_rate_string(video_stream.get("r_frame_rate"))
    if fps is None:
        fps = _parse_frame_rate_string(video_stream.get("avg_frame_rate"))

    # Remember the raw value before any correction (None if unparsable).
    original_fps = fps

    is_plausible = fps is not None and math.isfinite(fps) and 0 < fps <= max_fps
    if is_plausible:
        return fps, False, None

    # Unparsable or implausible rate: fall back to a standard 30 fps and
    # report the raw value (if any) so callers can record it.
    return 30.0, True, original_fps
175 |
176 |
177 | def _parse_frame_rate_string(frame_rate_str: Optional[str]) -> Optional[float]:
178 | """Safely parse a frame rate string in format 'num/den'."""
179 | if not frame_rate_str:
180 | return None
181 |
182 | try:
183 | if "/" in frame_rate_str:
184 | num, den = frame_rate_str.split("/")
185 | num, den = float(num), float(den)
186 | if den <= 0: # Avoid division by zero
187 | return None
188 | return num / den
189 | else:
190 | # Handle case where frame rate is just a number
191 | return float(frame_rate_str)
192 | except (ValueError, ZeroDivisionError):
193 | return None
194 |
195 |
196 | def _create_error_metadata(error_message: str) -> Dict[str, Any]:
197 | """Create a metadata dictionary for error cases."""
198 | return {
199 | "duration": 0,
200 | "has_video": False,
201 | "has_audio": False,
202 | "error": error_message,
203 | }
204 |
205 |
def seconds_to_str(seconds):
    """Convert seconds to formatted time string (HH:MM:SS)."""
    total = int(float(seconds))
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02}:{minutes:02}:{secs:02}"
213 |
--------------------------------------------------------------------------------
/bitmind/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import bittensor as bt
3 |
4 | MAINNET_UID = 34
5 |
6 |
def validate_config_and_neuron_path(config):
    r"""Checks/validates the config namespace object.

    Builds the neuron's state/logging directory path from the logging dir,
    wallet name/hotkey, netuid, and neuron name, stores it on
    ``config.neuron.full_path``, and creates the directory if missing.

    Args:
        config: bittensor config namespace with ``logging``, ``wallet``,
            ``netuid``, and ``neuron`` sections populated.

    Returns:
        The same config object, with ``neuron.full_path`` set.
    """
    full_path = os.path.expanduser(
        "{}/{}/{}/netuid{}/{}".format(
            config.logging.logging_dir,
            config.wallet.name,
            config.wallet.hotkey,
            config.netuid,
            config.neuron.name,
        )
    )
    bt.logging.info(f"Logging path: {full_path}")
    # full_path is already user-expanded above; the previous second
    # expanduser call was redundant, as was the exists() check given
    # makedirs(..., exist_ok=True).
    config.neuron.full_path = full_path
    os.makedirs(config.neuron.full_path, exist_ok=True)
    return config
23 |
24 |
def add_args(parser):
    """
    Adds relevant arguments to the parser for operation.
    """
    # (flag, add_argument kwargs) pairs, registered in declaration order
    arg_specs = [
        ("--netuid", dict(type=int, help="Subnet netuid", default=34)),
        ("--neuron.name", dict(type=str, help="Neuron Name", default="bitmind")),
        (
            "--epoch-length",
            dict(
                type=int,
                help="The default epoch length (how often we set weights, measured in 12 second blocks).",
                default=360,
            ),
        ),
        ("--mock", dict(action="store_true", help="Run in mock mode", default=False)),
        (
            "--autoupdate-off",
            dict(
                action="store_false",
                dest="autoupdate",
                help="Disable automatic updates on latest version on Main.",
                default=True,
            ),
        ),
        ("--wandb.entity", dict(type=str, default="bitmindai")),
        ("--wandb.off", dict(action="store_true", default=False)),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
63 |
64 |
def add_miner_args(parser):
    """Add miner specific arguments to the parser."""
    # Validator-permit enforcement is opt-out: set this flag to accept
    # requests from hotkeys without a permit.
    parser.add_argument(
        "--no-force-validator-permit",
        action="store_true",
        default=False,
        help="If set, we will not force incoming requests to have a permit.",
    )
    # Inference device for the detection models
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="Device to use for detection models (cuda/cpu)",
    )
81 |
82 |
def add_validator_args(parser):
    """Add validator specific arguments to the parser.

    Arguments are declared as (flag, kwargs) pairs and registered in order,
    so the --help output keeps the historical ordering.
    """
    arg_specs = [
        (
            "--vpermit-tao-limit",
            dict(
                type=int,
                help="The maximum number of TAO allowed to query a validator with a vpermit.",
                default=20000,
            ),
        ),
        (
            "--compressed-cache-update-interval",
            dict(
                type=int,
                help="How often to download new zip/parquet files, measured in 12 second blocks",
                default=720,
            ),
        ),
        (
            "--media-cache-update-interval",
            dict(
                type=int,
                help="How often to unpack random media files, measured in 12 second blocks",
                default=300,
            ),
        ),
        (
            "--challenge-interval",
            dict(
                type=int,
                help="How often we set challenge miners, measured in 12 second blocks.",
                default=5,
            ),
        ),
        (
            "--wandb-restart-interval",
            dict(
                type=int,
                help="How often we restart wandb run to avoid log truncation",
                default=2000,
            ),
        ),
        (
            "--cache.base-dir",
            dict(
                type=str,
                default=os.path.expanduser("~/.cache/sn34"),
                help="Base directory for cache storage",
            ),
        ),
        (
            "--cache.max-compressed-gb",
            dict(
                type=float,
                default=50.0,
                help="Maximum size in GB for compressed cache",
            ),
        ),
        (
            "--cache.max-media-gb",
            dict(
                type=float,
                default=5.0,
                help="Maximum size in GB for media cache",
            ),
        ),
        (
            "--cache.media-files-per-source",
            dict(
                type=int,
                default=50,
                help="Number of media files to keep per source",
            ),
        ),
        (
            "--neuron.max-state-backup-hours",
            dict(
                type=float,
                help="The oldest backup of validator state to load in the case of a failure to load most recent",
                default=1,
            ),
        ),
        (
            "--neuron.miner-total-timeout",
            dict(
                type=float,
                help="Total timeout for miner requests in seconds",
                default=11.0,
            ),
        ),
        (
            "--neuron.miner-connect-timeout",
            dict(
                type=float,
                help="TCP connection timeout for miner requests in seconds",
                default=4.0,
            ),
        ),
        (
            "--neuron.miner-sock-connect-timeout",
            dict(
                type=float,
                help="Socket connection timeout for miner requests in seconds",
                default=3.0,
            ),
        ),
        (
            "--neuron.heartbeat",
            dict(
                action="store_true",
                help="Run validator heartbeat thread",
                default=False,
            ),
        ),
        (
            "--neuron.heartbeat-interval-seconds",
            dict(
                type=float,
                help="Interval between heartbeat checks in seconds",
                default=60.0,
            ),
        ),
        (
            "--neuron.lock-sleep-seconds",
            dict(
                type=float,
                help="Sleep duration when lock is held in seconds",
                default=5.0,
            ),
        ),
        (
            "--neuron.max-stuck-count",
            dict(
                type=int,
                help="Number of consecutive heartbeats with no progress before restart",
                default=5,
            ),
        ),
        (
            "--neuron.sample-size",
            dict(
                type=int,
                help="Number of miners to query per challenge",
                default=50,
            ),
        ),
        (
            "--scoring.moving-average-alpha",
            dict(
                type=float,
                help="Alpha for miner score EMA",
                default=0.05,
            ),
        ),
        (
            "--scoring.image-weight",
            dict(
                type=float,
                help="Weight for image modality scoring",
                default=0.6,
            ),
        ),
        (
            "--scoring.video-weight",
            dict(
                type=float,
                help="Weight for video modality scoring",
                default=0.4,
            ),
        ),
        (
            "--scoring.binary-weight",
            dict(
                type=float,
                help="Weight for binary classification scoring",
                default=0.75,
            ),
        ),
        (
            "--scoring.multiclass-weight",
            dict(
                type=float,
                help="Weight for multiclass classification scoring",
                default=0.25,
            ),
        ),
        (
            "--challenge.image-prob",
            dict(
                type=float,
                help="Probability of selecting image modality for challenges",
                default=0.5,
            ),
        ),
        (
            "--challenge.video-prob",
            dict(
                type=float,
                help="Probability of selecting video modality for challenges",
                default=0.5,
            ),
        ),
        (
            "--challenge.real-prob",
            dict(
                type=float,
                help="Probability of selecting real media for challenges",
                default=0.5,
            ),
        ),
        (
            "--challenge.synthetic-prob",
            dict(
                type=float,
                help="Probability of selecting synthetic media for challenges",
                default=0.3,
            ),
        ),
        (
            "--challenge.semisynthetic-prob",
            dict(
                type=float,
                help="Probability of selecting semisynthetic media for challenges",
                default=0.2,
            ),
        ),
        (
            "--challenge.multi-video-prob",
            dict(
                type=float,
                help="Probability of stitching together two videos of the same media type",
                default=0.2,
            ),
        ),
        (
            "--challenge.min-clip-duration",
            dict(
                type=float,
                help="Minimum video clip duration in seconds",
                default=1.0,
            ),
        ),
        (
            "--challenge.max-clip-duration",
            dict(
                type=float,
                help="Maximum video clip duration in seconds",
                default=6.0,
            ),
        ),
        (
            "--challenge.max-frames",
            dict(
                type=int,
                help="Maximum number of video frames to sample for a challenge",
                default=144,
            ),
        ),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
309 |
310 |
def add_data_generator_args(parser):
    """Add data-generator specific arguments to the parser."""
    arg_specs = [
        (
            "--cache-dir",
            dict(
                type=str,
                default=os.path.expanduser("~/.cache/sn34"),
                help="Directory for caching data",
            ),
        ),
        (
            "--batch-size",
            dict(type=int, default=3, help="Batch size for generation"),
        ),
        (
            "--tasks",
            dict(
                nargs="+",
                choices=["t2v", "t2i", "i2i", "i2v"],
                default=["t2v", "t2i", "i2i", "i2v"],
                help="List of tasks to run (t2v, t2i, i2i, i2v). Defaults to all.",
            ),
        ),
        (
            "--device",
            dict(
                type=str,
                default="cuda",
                help="Device to use for generation (cuda/cpu)",
            ),
        ),
        (
            "--wandb.num-batches-per-run",
            dict(
                type=int,
                default=50,
                help="Number of batches to generate before starting new W&B run (avoids log truncation)",
            ),
        ),
        ("--wandb.process-name", dict(type=str, default="generator")),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
346 |
347 |
def add_proxy_args(parser):
    """Add proxy specific arguments to the parser.

    Args:
        parser: argparse.ArgumentParser (or compatible) to extend.
    """
    # Both spellings are accepted as aliases of one argument; previously this
    # option was registered twice ("--proxy.sample-size" and
    # "--proxy.sample_size") onto the same dest, which was redundant.
    parser.add_argument(
        "--proxy.sample-size",
        "--proxy.sample_size",
        type=int,
        default=50,
        help="Number of miners to query for organics",
    )

    parser.add_argument(
        "--proxy.client-url",
        type=str,
        default="https://subnet-api.bitmindlabs.ai",
        help="URL for the proxy client authentication service",
    )

    parser.add_argument(
        "--proxy.host",
        type=str,
        default="0.0.0.0",
        help="Network interface to listen on",
    )

    parser.add_argument(
        "--proxy.port",
        type=int,
        default=10913,
        help="Port for the proxy server",
    )

    # Help text fixed: this previously duplicated --proxy.port's description.
    parser.add_argument(
        "--proxy.external_port",
        type=int,
        default=10913,
        help="Externally advertised port for the proxy server",
    )

    parser.add_argument(
        "--miner-health-sync-interval",
        type=int,
        default=2,
        help="How frequently to check miner health (in blocks)",
    )
397 |
398 |
--------------------------------------------------------------------------------
/bitmind/encoding.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import ffmpeg
4 | import os
5 | import tempfile
6 | from typing import List
7 | from io import BytesIO
8 | from PIL import Image
9 |
10 |
def image_to_bytes(img):
    """Convert image array to bytes using JPEG encoding with PIL.
    Args:
        img (np.ndarray): Image array of shape (C, H, W) or (H, W, C)
            Can be float32 [0,1] or uint8 [0,255]
    Returns:
        bytes: JPEG encoded image bytes
        str: Content type 'image/jpeg'
    """
    # Normalize dtype to uint8
    if img.dtype == np.float32:
        img = (img * 255).astype(np.uint8)
    elif img.dtype != np.uint8:
        raise ValueError(f"Image must be float32 or uint8, got {img.dtype}")

    # Heuristic CHW detection: 3-D array whose leading axis is the 3 channels
    if img.ndim == 3 and img.shape[0] == 3:
        img = np.transpose(img, (1, 2, 0))

    # Normalize channel count to exactly 3 (RGB)
    if img.ndim == 2:
        # Grayscale (H, W) -> replicate into 3 channels
        img = np.stack([img] * 3, axis=2)
    elif img.shape[2] == 1:
        # Single-channel (H, W, 1) -> replicate into 3 channels
        img = np.concatenate([img] * 3, axis=2)
    elif img.shape[2] == 4:
        # RGBA -> drop alpha
        img = img[:, :, :3]
    elif img.shape[2] != 3:
        raise ValueError(f"Expected 1, 3 or 4 channels, got {img.shape[2]}")

    pil_img = Image.fromarray(img)
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    encoded = BytesIO()
    pil_img.save(encoded, format="JPEG", quality=75)
    encoded.seek(0)
    return encoded.getvalue(), "image/jpeg"
51 |
52 |
def video_to_bytes(video: np.ndarray, fps: int | None = None) -> tuple[bytes, str]:
    """
    Convert a (T, H, W, C) uint8/float32 video to MP4, but *first* pass each frame
    through Pillow JPEG → adds normal JPEG artefacts, then encodes losslessly.

    Args:
        video: Frame stack, (T, H, W, C) or (T, C, H, W); float32 values must
            be in [0, 1] (asserted below), uint8 in [0, 255]. C must be 1 or 3
            after layout normalisation.
        fps: Output frame rate; falls back to 30 when None/0.

    Returns:
        bytes: In‑memory MP4 file.
        str: MIME‑type ("video/mp4").

    Raises:
        ValueError: On unsupported dtype, shape, or a frame that decodes to an
            unexpected size.
        RuntimeError: If the FFmpeg encode subprocess fails.
    """
    # ------------- 0. validation / normalisation -------------------------------
    if video.dtype == np.float32:
        assert video.max() <= 1.0, video.max()
        video = (video * 255).clip(0, 255).astype(np.uint8)
    elif video.dtype != np.uint8:
        raise ValueError(f"Unsupported dtype: {video.dtype}")

    fps = fps or 30

    # TCHW → THWC
    # Heuristic: axis 1 small (<=4) and axis 3 large (>4) is assumed to mean
    # channels-first input. NOTE(review): tiny frames (W <= 4) would defeat
    # this check — confirm upstream callers never produce such inputs.
    if video.shape[1] <= 4 and video.shape[3] > 4:
        video = np.transpose(video, (0, 2, 3, 1))

    if video.ndim != 4 or video.shape[3] not in (1, 3):
        raise ValueError(f"Expected shape (T, H, W, C), got {video.shape}")

    T, H, W, C = video.shape

    # ------------- 1. apply Pillow JPEG to every frame -------------------------
    # Round-trip each frame through an in-memory JPEG so the final (lossless)
    # MP4 carries ordinary JPEG compression artefacts.
    jpeg_degraded_frames: List[np.ndarray] = []
    for idx, frame in enumerate(video):
        buf = BytesIO()
        Image.fromarray(frame).save(
            buf,
            format="JPEG",
            quality=75,
            subsampling=2,  # 0=4:4:4, 1=4:2:2, 2=4:2:0 (Pillow default = 2)
            optimize=False,
            progressive=False,
        )
        buf.seek(0)
        # decode back to RGB so FFmpeg sees the artefact‑laden pixels
        degraded = np.array(Image.open(buf).convert("RGB"), dtype=np.uint8)
        if degraded.shape != (H, W, 3):
            raise ValueError(f"Decoded shape mismatch at frame {idx}: {degraded.shape}")
        jpeg_degraded_frames.append(degraded)

    degraded_video = np.stack(jpeg_degraded_frames, axis=0)  # (T,H,W,3)

    # ------------- 2. write raw RGB + encode losslessly ------------------------
    with tempfile.TemporaryDirectory() as tmpdir:
        raw_path = os.path.join(tmpdir, "input.raw")
        video_path = os.path.join(tmpdir, "output.mp4")

        degraded_video.tofile(raw_path)  # write as one big rawvideo blob

        try:
            (
                ffmpeg.input(
                    raw_path,
                    format="rawvideo",
                    pix_fmt="rgb24",
                    s=f"{W}x{H}",
                    r=fps,
                )
                .output(
                    video_path,
                    vcodec="libx264rgb",
                    crf=0,  # mathematically lossless
                    preset="veryfast",
                    pix_fmt="rgb24",
                    movflags="+faststart",
                )
                .global_args("-y", "-hide_banner", "-loglevel", "error")
                .run()
            )
        except ffmpeg.Error as e:
            raise RuntimeError(
                f"FFmpeg encoding failed:\n{e.stderr.decode(errors='ignore')}"
            ) from e

        with open(video_path, "rb") as f:
            video_bytes = f.read()

    return video_bytes, "video/mp4"
137 |
138 |
def media_to_bytes(media, fps=30):
    """Encode an image or video array to bytes, dispatching on array rank.

    3-D arrays are treated as images and JPEG-encoded via ``image_to_bytes``;
    4-D arrays are treated as videos and MP4-encoded via ``video_to_bytes``.
    (Docstring fixed: this previously claimed PNG/'video/avi' output, which
    did not match the encoders actually called.)

    Args:
        media (np.ndarray): Either:
            - Image array of shape (C, H, W) or (H, W, C)
            - Video array of shape (T, C, H, W) or (T, H, W, C)
            Can be float32 [0,1] or uint8 [0,255]
        fps (int): Frames per second for video encoding (default: 30).
            Ignored for images.

    Returns:
        bytes: Encoded media bytes
        str: Content type ('image/jpeg' or 'video/mp4')

    Raises:
        ValueError: If the array is neither 3-D nor 4-D.
    """
    if media.ndim == 3:  # image
        return image_to_bytes(media)
    if media.ndim == 4:  # video
        return video_to_bytes(media, fps)
    raise ValueError(
        f"Invalid media shape: {media.shape}. Expected (C,H,W) for image or (T,C,H,W) for video."
    )
161 |
--------------------------------------------------------------------------------
/bitmind/epistula.py:
--------------------------------------------------------------------------------
1 | import json
2 | from hashlib import sha256
3 | from uuid import uuid4
4 | from math import ceil
5 | from typing import Annotated, Any, Dict, Optional
6 |
7 | import bittensor as bt
8 | import numpy as np
9 | import asyncio
10 | import ast
11 | import time
12 | import httpx
13 | import aiohttp
14 | from substrateinterface import Keypair
15 |
16 | from bitmind.types import Modality
17 |
18 |
19 | EPISTULA_VERSION = str(2)
20 |
21 |
def generate_header(
    hotkey: Keypair,
    body: Any,
    signed_for: Optional[str] = None,
) -> Dict[str, Any]:
    """Build Epistula v2 authentication headers for a request.

    Signs a SHA-256 digest of *body* (raw bytes, or JSON-serialized
    otherwise) with *hotkey*. When *signed_for* is given, also emits
    receiver-bound "secret" signatures over three adjacent 10-second
    timestamp buckets.
    """
    timestamp = round(time.time() * 1000)
    # 10-second bucket used for the receiver-bound secret signatures
    interval = ceil(timestamp / 1e4) * 1e4
    request_uuid = str(uuid4())

    payload = body if isinstance(body, bytes) else json.dumps(body).encode("utf-8")
    digest = sha256(payload).hexdigest()

    signed_message = f"{digest}.{request_uuid}.{timestamp}.{signed_for or ''}"
    headers = {
        "Epistula-Version": EPISTULA_VERSION,
        "Epistula-Timestamp": str(timestamp),
        "Epistula-Uuid": request_uuid,
        "Epistula-Signed-By": hotkey.ss58_address,
        "Epistula-Request-Signature": "0x" + hotkey.sign(signed_message).hex(),
    }
    if signed_for:
        headers["Epistula-Signed-For"] = signed_for
        # Secret-Signature-0/1/2 cover interval-1, interval, interval+1
        for slot, offset in enumerate((-1, 0, 1)):
            secret_message = str(interval + offset) + "." + signed_for
            headers[f"Epistula-Secret-Signature-{slot}"] = (
                "0x" + hotkey.sign(secret_message).hex()
            )
    return headers
56 |
57 |
def verify_signature(
    signature, body: bytes, timestamp, uuid, signed_for, signed_by, now
) -> Optional[Annotated[str, "Error Message"]]:
    """Verify an Epistula request signature.

    Args:
        signature: Hex signature string from the request headers.
        body: Raw request body bytes whose digest was signed.
        timestamp: Request timestamp in ms (int or numeric string).
        uuid: Request UUID string.
        signed_for: Receiver hotkey address.
        signed_by: Sender hotkey ss58 address.
        now: Current time in ms used for the staleness check.

    Returns:
        None if the signature verifies; otherwise a short error message.
    """
    if not isinstance(signature, str):
        return "Invalid Signature"
    # Headers arrive as strings, so coerce defensively. Previously a
    # non-numeric value raised ValueError here (and the isinstance check
    # after the conversion was dead code); now it returns the error message.
    try:
        timestamp = int(timestamp)
    except (TypeError, ValueError):
        return "Invalid Timestamp"
    if not isinstance(signed_by, str):
        return "Invalid Sender key"
    if not isinstance(signed_for, str):
        return "Invalid receiver key"
    if not isinstance(uuid, str):
        return "Invalid uuid"
    if not isinstance(body, bytes):
        return "Body is not of type bytes"
    ALLOWED_DELTA_MS = 8000
    if timestamp + ALLOWED_DELTA_MS < now:
        return "Request is too stale"
    # Construct the keypair only after all cheap validation has passed
    keypair = Keypair(ss58_address=signed_by)
    message = f"{sha256(body).hexdigest()}.{uuid}.{timestamp}.{signed_for}"
    if not keypair.verify(message, signature):
        return "Signature Mismatch"
    return None
83 |
84 |
def create_header_hook(hotkey, axon_hotkey, model):
    """Return an async httpx request hook that attaches Epistula headers."""

    async def add_headers(request: httpx.Request):
        # Sign the fully-serialized request body for the target hotkey
        generated = generate_header(hotkey, request.read(), axon_hotkey)
        for name, value in generated.items():
            request.headers[name] = value

    return add_headers
91 |
92 |
async def query_miner(
    uid: int,
    media: bytes,
    content_type: str,
    modality: Modality,
    axon_info: bt.AxonInfo,
    session: aiohttp.ClientSession,
    hotkey: bt.Keypair,
    total_timeout: float,
    connect_timeout: Optional[float] = None,
    sock_connect_timeout: Optional[float] = None,
    testnet_metadata: Optional[dict] = None,
) -> Dict[str, Any]:
    """
    Query a miner with media data.

    Args:
        uid: miner uid
        media: encoded media
        content_type: determined by media_to_bytes
        modality: Type of media ('image' or 'video')
        axon_info: miner AxonInfo
        session: aiohttp client session
        hotkey: validator hotkey Keypair for signing the request
        total_timeout: Total timeout for the request
        connect_timeout: Connection timeout
        sock_connect_timeout: Socket connection timeout
        testnet_metadata: Optional mapping forwarded to the miner as
            X-Testnet-<key> headers (values stringified)

    Returns:
        Dictionary containing the response.
        prediction field will be None if any error is encountered, including
        if the response contains a prediction that doesn't sum to ~1.
    """
    # Pessimistic defaults: status 500 / prediction None unless the request
    # below succeeds and yields a valid prediction vector.
    response = {
        "uid": uid,
        "hotkey": axon_info.hotkey,
        "status": 500,
        "prediction": None,
        "error": "",
    }

    try:

        # Route is detect_image or detect_video depending on modality
        url = f"http://{axon_info.ip}:{axon_info.port}/detect_{modality}"
        headers = generate_header(hotkey, media, axon_info.hotkey)

        headers = {
            "Content-Type": content_type,
            "X-Media-Type": modality,
            **headers,
        }

        if testnet_metadata:
            testnet_headers = {f"X-Testnet-{k}": str(v) for k, v in testnet_metadata.items()}
            headers.update(testnet_headers)

        async with session.post(
            url,
            headers=headers,
            data=media,
            timeout=aiohttp.ClientTimeout(
                total=total_timeout,
                connect=connect_timeout,
                sock_connect=sock_connect_timeout,
            ),
        ) as res:
            response["status"] = res.status
            if res.status != 200:
                response["error"] = f"HTTP {res.status} error"
                return response
            try:
                data = await res.json()
                if "prediction" not in data:
                    response["error"] = "Missing prediction in response"
                    return response

                # Any non-numeric entry raises and is caught below
                pred = [float(p) for p in data["prediction"]]

                # handle binary predictions, assume [real, fake]
                if len(pred) == 2:
                    pred = pred + [0.0]

                pred = np.array(pred)

                # error on predictions that don't sum to ~1 or contain values outside of [0., 1.]
                if abs(sum(pred) - 1.0) > 1e-6 or np.any((pred < 0.0) | (pred > 1.0)):
                    raise ValueError

                response["prediction"] = pred
                return response

            except json.JSONDecodeError:
                response["error"] = "Failed to decode JSON response"
                return response

            except (TypeError, ValueError) as e:
                # Covers non-numeric entries and the range/sum check above
                response["error"] = (
                    f"Invalid prediction value {data.get('prediction')}"
                )
                return response

    except asyncio.TimeoutError:
        response["status"] = 408
        response["error"] = "Request timed out"
    except aiohttp.ClientConnectorError as e:
        response["status"] = 503
        response["error"] = f"Connection error: {str(e)}"
    except aiohttp.ClientError as e:
        response["status"] = 400
        response["error"] = f"Network error: {str(e)}"
    except Exception as e:
        response["error"] = f"Unknown error: {str(e)}"

    return response
207 |
--------------------------------------------------------------------------------
/bitmind/generation/__init__.py:
--------------------------------------------------------------------------------
1 | from .generation_pipeline import GenerationPipeline
2 | from .prompt_generator import PromptGenerator
3 | from .models import initialize_model_registry
4 |
--------------------------------------------------------------------------------
/bitmind/generation/model_registry.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict, Union, Any, List
2 | import random
3 |
4 | from bitmind.types import ModelConfig, ModelTask
5 |
6 |
class ModelRegistry:
    """
    Registry for managing generative models.

    Models are keyed by their ``path`` attribute and can be looked up by
    task or tag; helpers support random and task-interleaved selection.
    """

    def __init__(self):
        # path -> ModelConfig
        self.models: Dict[str, ModelConfig] = {}

    def register(self, model_config: ModelConfig) -> None:
        """Register a single model config, keyed by its path."""
        self.models[model_config.path] = model_config

    def register_all(self, model_configs: List[ModelConfig]) -> None:
        """Register each config in *model_configs*."""
        for config in model_configs:
            self.register(config)

    def get_model(self, path: str) -> Optional[ModelConfig]:
        """Return the config registered under *path*, or None if absent."""
        return self.models.get(path)

    def get_all_models(self) -> Dict[str, ModelConfig]:
        """Return a shallow copy of the path -> config mapping."""
        return self.models.copy()

    def get_models_by_task(self, task: ModelTask) -> Dict[str, ModelConfig]:
        """Return the path -> config mapping of models matching *task*."""
        return {
            path: config for path, config in self.models.items() if config.task == task
        }

    def get_models_by_tag(self, tag: str) -> Dict[str, ModelConfig]:
        """Return the path -> config mapping of models carrying *tag*."""
        return {
            path: config for path, config in self.models.items() if tag in config.tags
        }

    def get_model_names_by_task(self, task: ModelTask) -> List[str]:
        """Return the paths of all models matching *task*.

        Note: this method was previously defined twice (the earlier copy,
        with a wrong Dict annotation, was shadowed at class-creation time);
        the duplicate has been removed.
        """
        return list(self.get_models_by_task(task).keys())

    @property
    def t2i_models(self) -> Dict[str, ModelConfig]:
        return self.get_models_by_task(ModelTask.TEXT_TO_IMAGE)

    @property
    def t2v_models(self) -> Dict[str, ModelConfig]:
        return self.get_models_by_task(ModelTask.TEXT_TO_VIDEO)

    @property
    def i2i_models(self) -> Dict[str, ModelConfig]:
        return self.get_models_by_task(ModelTask.IMAGE_TO_IMAGE)

    @property
    def i2v_models(self) -> Dict[str, ModelConfig]:
        # Annotation fixed: returns a dict like the other *_models
        # properties (was mislabeled List[str]).
        return self.get_models_by_task(ModelTask.IMAGE_TO_VIDEO)

    @property
    def t2i_model_names(self) -> List[str]:
        return list(self.t2i_models.keys())

    @property
    def t2v_model_names(self) -> List[str]:
        return list(self.t2v_models.keys())

    @property
    def i2i_model_names(self) -> List[str]:
        return list(self.i2i_models.keys())

    @property
    def i2v_model_names(self) -> List[str]:
        return list(self.i2v_models.keys())

    @property
    def model_names(self) -> List[str]:
        return list(self.models.keys())

    def select_random_model(self, task: Optional[Union[ModelTask, str]] = None) -> str:
        """Return a random model path for *task* (random task if None).

        Raises:
            ValueError: If no models are registered for the task.
        """
        if isinstance(task, str):
            task = ModelTask(task.lower())

        if task is None:
            task = random.choice(list(ModelTask))

        model_names = self.get_model_names_by_task(task)
        if not model_names:
            raise ValueError(f"No models available for task: {task}")

        return random.choice(model_names)

    def get_model_dict(self, model_name: str) -> Dict[str, Any]:
        """Return the dict form of a registered model's config."""
        return self._require_model(model_name).to_dict()

    def get_interleaved_model_names(self, tasks=None) -> List[str]:
        """Return model paths interleaved round-robin across tasks.

        Each task's model list is shuffled, then models are drawn one per
        task in turn until every list is exhausted.
        """
        from itertools import zip_longest

        model_names = []
        if tasks is None:
            model_names = [
                self.t2i_model_names,
                self.t2v_model_names,
                self.i2i_model_names,
                self.i2v_model_names,
            ]
        else:
            for task in tasks:
                model_names.append(self.get_model_names_by_task(task))

        shuffled_model_names = (
            random.sample(names, len(names)) for names in model_names
        )
        return [
            m
            for quad in zip_longest(*shuffled_model_names)
            for m in quad
            if m is not None
        ]

    def get_modality(self, model_name: str) -> str:
        """Return 'video' or 'image' for a registered model.

        NOTE(review): only TEXT_TO_VIDEO maps to 'video' here, yet
        IMAGE_TO_VIDEO models also produce video output — confirm this is
        intentional before relying on it for i2v models (behavior kept
        unchanged; see get_output_media_type for the media_type-based answer).
        """
        model = self._require_model(model_name)
        return "video" if model.task == ModelTask.TEXT_TO_VIDEO else "image"

    def get_task(self, model_name: str) -> str:
        """Return the task value string for a registered model."""
        return self._require_model(model_name).task.value

    def get_output_media_type(self, model_name: str) -> str:
        """Return the media type value string for a registered model."""
        return self._require_model(model_name).media_type.value

    def _require_model(self, model_name: str) -> ModelConfig:
        """Return the config for *model_name* or raise ValueError if missing."""
        model = self.get_model(model_name)
        if model is None:
            raise ValueError(f"Model not found: {model_name}")
        return model
--------------------------------------------------------------------------------
/bitmind/generation/prompt_generator.py:
--------------------------------------------------------------------------------
1 | import re
2 | import gc
3 | import bittensor as bt
4 | import torch
5 | from PIL import Image
6 | from transformers import (
7 | AutoModelForCausalLM,
8 | AutoTokenizer,
9 | Blip2ForConditionalGeneration,
10 | Blip2Processor,
11 | pipeline,
12 | )
13 |
14 |
15 | class PromptGenerator:
16 | """
17 | A class for generating and moderating image annotations using transformer models.
18 |
19 | This class provides functionality to generate descriptive captions for images
20 | using BLIP2 models and optionally moderate the generated text using a separate
21 | language model.
22 | """
23 |
    def __init__(
        self,
        vlm_name: str,
        llm_name: str,
        device: str = "cuda",
    ) -> None:
        """
        Initialize the PromptGenerator with model names and device settings.

        Models are loaded lazily (see load_vlm/load_llm), not here.

        Args:
            vlm_name: Name of the BLIP2 vision-language model used to
                generate image captions.
            llm_name: Name of the causal language model used to moderate
                (rephrase) generated captions.
            device: Device string the models are loaded onto
                (e.g. "cuda", "cuda:1", "cpu").
        """
        self.vlm_name = vlm_name
        self.llm_name = llm_name
        self.vlm_processor = None  # Blip2Processor, set by load_vlm()
        self.vlm = None  # Blip2ForConditionalGeneration, set by load_vlm()
        self.llm = None  # text-generation pipeline, set by load_llm()
        self.device = device
47 |
48 | def load_vlm(self) -> None:
49 | """
50 | Load the vision-language model for image annotation.
51 | """
52 | bt.logging.debug(f"Loading caption generation model {self.vlm_name}")
53 | self.vlm_processor = Blip2Processor.from_pretrained(
54 | self.vlm_name, torch_dtype=torch.float32
55 | )
56 | self.vlm = Blip2ForConditionalGeneration.from_pretrained(
57 | self.vlm_name, torch_dtype=torch.float32
58 | )
59 | self.vlm.to(self.device)
60 | bt.logging.info(f"Loaded image annotation model {self.vlm_name}")
61 |
62 | def load_llm(self) -> None:
63 | """
64 | Load the language model for text moderation.
65 | """
66 | bt.logging.debug(f"Loading caption moderation model {self.llm_name}")
67 | m = re.match(r"cuda:(\d+)", self.device)
68 | gpu_id = int(m.group(1)) if m else 0
69 | llm = AutoModelForCausalLM.from_pretrained(
70 | self.llm_name,
71 | torch_dtype=torch.bfloat16,
72 | device_map={"": gpu_id}
73 | )
74 | tokenizer = AutoTokenizer.from_pretrained(self.llm_name)
75 | self.llm = pipeline("text-generation", model=llm, tokenizer=tokenizer)
76 | bt.logging.info(f"Loaded caption moderation model {self.llm_name}")
77 |
78 | def load_models(self) -> None:
79 | """
80 | Load the necessary models for image annotation and text moderation onto
81 | the specified device.
82 | """
83 | if self.vlm is None:
84 | self.load_vlm()
85 | else:
86 | bt.logging.warning(f"vlm already loaded")
87 |
88 | if self.llm is None:
89 | self.load_llm()
90 | else:
91 | bt.logging.warning(f"llm already loaded")
92 |
93 | def clear_gpu(self) -> None:
94 | """
95 | Clear GPU memory by moving models back to CPU and deleting them,
96 | followed by collecting garbage.
97 | """
98 | bt.logging.debug("Clearing GPU memory after prompt generation")
99 | if self.vlm:
100 | del self.vlm
101 | self.vlm = None
102 |
103 | if self.llm:
104 | del self.llm
105 | self.llm = None
106 |
107 | gc.collect()
108 | torch.cuda.empty_cache()
109 |
    def generate(
        self, image: Image.Image, downstream_task: str = None, max_new_tokens: int = 20
    ) -> str:
        """
        Generate a string description for a given image using prompt-based
        captioning and building conversational context.

        Args:
            image: The image for which the description is to be generated.
            downstream_task: The generation task ('t2i', 't2v', 'i2i', 'i2v').
                For video tasks ('t2v', 'i2v') the moderated description is
                additionally passed through enhance().
            max_new_tokens: The maximum number of tokens to generate for each
                prompt.

        Returns:
            A generated (and moderated) description of the image.
        """
        # Captioning requires both the model and its processor
        if self.vlm is None or self.vlm_processor is None:
            self.load_vlm()

        description = ""
        # Each prompt is appended to the running description so the VLM
        # continues the text with setting/background/style details in turn.
        prompts = [
            "An image of",
            "The setting is",
            "The background is",
            "The image type/style is",
        ]

        for i, prompt in enumerate(prompts):
            description += prompt + " "
            inputs = self.vlm_processor(
                image, text=description, return_tensors="pt"
            ).to(self.device, torch.float32)

            generated_ids = self.vlm.generate(**inputs, max_new_tokens=max_new_tokens)
            answer = self.vlm_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0].strip()

            bt.logging.trace(f"{i}. Prompt: {prompt}")
            bt.logging.trace(f"{i}. Answer: {answer}")

            if answer:
                # Normalize trailing punctuation so each answer ends with "."
                answer = answer.rstrip(" ,;!?")
                if not answer.endswith("."):
                    answer += "."
                description += answer + " "
            else:
                # No continuation produced: strip the prompt (and its trailing
                # space) that was appended above.
                description = description[: -len(prompt) - 1]

        # Drop the leading "An image of" scaffold if it survived
        if description.startswith(prompts[0]):
            description = description[len(prompts[0]) :]

        description = description.strip()
        if not description.endswith("."):
            description += "."

        moderated_description = self.moderate(description)

        # Video tasks get extra motion-oriented enhancement
        if downstream_task in ["t2v", "i2v"]:
            return self.enhance(moderated_description)
        return moderated_description
172 |
173 | def moderate(self, description: str, max_new_tokens: int = 80) -> str:
174 | """
175 | Use the text moderation pipeline to make the description more concise
176 | and neutral.
177 |
178 | Args:
179 | description: The text description to be moderated.
180 | max_new_tokens: Maximum number of new tokens to generate in the
181 | moderated text.
182 |
183 | Returns:
184 | The moderated description text, or the original description if
185 | moderation fails.
186 | """
187 | if self.llm is None:
188 | self.load_llm()
189 |
190 | messages = [
191 | {
192 | "role": "system",
193 | "content": (
194 | "[INST]You always concisely rephrase given descriptions, "
195 | "eliminate redundancy, and remove all specific references to "
196 | "individuals by name. You do not respond with anything other "
197 | "than the revised description.[/INST]"
198 | ),
199 | },
200 | {"role": "user", "content": description},
201 | ]
202 | try:
203 | moderated_text = self.llm(
204 | messages,
205 | max_new_tokens=max_new_tokens,
206 | pad_token_id=self.llm.tokenizer.eos_token_id,
207 | return_full_text=False,
208 | )
209 | return moderated_text[0]["generated_text"]
210 |
211 | except Exception as e:
212 | bt.logging.error(f"An error occurred during moderation: {e}", exc_info=True)
213 | return description
214 |
215 | def enhance(self, description: str, max_new_tokens: int = 80) -> str:
216 | """
217 | Enhance a static image description to make it suitable for video generation
218 | by adding dynamic elements and motion.
219 |
220 | Args:
221 | description: The static image description to enhance.
222 | max_new_tokens: Maximum number of new tokens to generate in the enhanced text.
223 |
224 | Returns:
225 | An enhanced description suitable for video generation, or the original
226 | description if enhancement fails.
227 | """
228 | if self.llm is None:
229 | self.load_llm()
230 |
231 | messages = [
232 | {
233 | "role": "system",
234 | "content": (
235 | "[INST]You are an expert at converting image descriptions into video prompts. "
236 | "Analyze the existing motion in the scene and enhance it naturally:\n"
237 | "1. If motion exists in the image (falling, throwing, running, etc.):\n"
238 | " - Maintain and emphasize that existing motion\n"
239 | " - Add smooth continuation of the movement\n"
240 | "2. If the subject is static (sitting, standing, placed):\n"
241 | " - Keep it stable\n"
242 | " - Add minimal environmental motion if appropriate\n"
243 | "3. Add ONE subtle camera motion that complements the scene\n"
244 | "4. Keep the description concise and natural\n"
245 | "Only respond with the enhanced description.[/INST]"
246 | ),
247 | },
248 | {"role": "user", "content": description},
249 | ]
250 |
251 | try:
252 | enhanced_text = self.llm(
253 | messages,
254 | max_new_tokens=max_new_tokens,
255 | pad_token_id=self.llm.tokenizer.eos_token_id,
256 | return_full_text=False,
257 | )
258 | return enhanced_text[0]["generated_text"]
259 |
260 | except Exception as e:
261 | bt.logging.error(f"An error occurred during motion enhancement: {e}")
262 | return description
263 |
264 | def sanitize(self, prompt: str, max_new_tokens: int = 80) -> str:
265 | """
266 | Use the LLM to make the prompt more SFW (less NSFW).
267 | """
268 |
269 | if self.llm is None:
270 | self.load_llm()
271 |
272 | messages = [
273 | {
274 | "role": "system",
275 | "content": (
276 | "[INST]You are an expert at making prompts safe for work (SFW). "
277 | "Rephrase the following prompt to remove or neutralize any NSFW, sexual, or explicit content. "
278 | "Keep the prompt as close as possible to the original intent, but ensure it is SFW. "
279 | "Only respond with the sanitized prompt.[/INST]"
280 | ),
281 | },
282 | {"role": "user", "content": prompt},
283 | ]
284 | try:
285 | sanitized = self.llm(
286 | messages,
287 | max_new_tokens=max_new_tokens,
288 | pad_token_id=self.llm.tokenizer.eos_token_id,
289 | return_full_text=False,
290 | )
291 | return sanitized[0]["generated_text"]
292 | except Exception as e:
293 | bt.logging.error(f"An error occurred during prompt sanitization: {e}")
294 | return prompt
295 |
--------------------------------------------------------------------------------
/bitmind/generation/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/bitmind/generation/util/__init__.py
--------------------------------------------------------------------------------
/bitmind/generation/util/image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import PIL
3 | import os
4 | from PIL import Image, ImageDraw
5 | from typing import Tuple, Union, List
6 |
7 |
def resize_image(
    image: PIL.Image.Image, max_width: int, max_height: int
) -> PIL.Image.Image:
    """
    Scale an image down to fit within (max_width, max_height) while
    preserving its aspect ratio. Images already within bounds keep
    their width (height is recomputed from the aspect ratio).
    """
    orig_width, orig_height = image.size
    aspect = orig_width / orig_height

    # Constrain width first, then re-check the derived height.
    target_width = min(max_width, orig_width)
    target_height = int(target_width / aspect)
    if target_height > max_height:
        target_height = max_height
        target_width = int(target_height * aspect)

    # LANCZOS: high-quality resampling filter for downscaling.
    return image.resize((target_width, target_height), PIL.Image.LANCZOS)
26 |
27 |
def resize_images_in_directory(directory, target_width, target_height):
    """
    Resize every image file directly inside a directory, in place.

    Args:
        directory (str): Path to the directory containing images.
        target_width (int): Maximum width for the resized images.
        target_height (int): Maximum height for the resized images.
    """
    image_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".gif")
    for filename in os.listdir(directory):
        if not filename.endswith(image_extensions):
            continue
        filepath = os.path.join(directory, filename)
        with PIL.Image.open(filepath) as img:
            # Overwrite the original file with the resized version.
            resized = resize_image(
                img, max_width=target_width, max_height=target_height
            )
            resized.save(filepath)
49 |
50 |
def save_images_to_disk(
    image_dataset, start_index, num_images, save_directory, resize=True
):
    """
    Save a contiguous range of images from a dataset to disk as JPEGs,
    named ``<id>.jpg`` from each record's "id" field.

    NOTE(review): the ``resize`` flag is currently unused — confirm whether
    resizing before save is still intended.
    """
    os.makedirs(save_directory, exist_ok=True)

    for idx in range(start_index, start_index + num_images):
        try:
            record = image_dataset[idx]
            image = record["image"]
            file_path = os.path.join(save_directory, f"{record['id']}.jpg")
            image.save(file_path, "JPEG")
            print(f"Saved: {file_path}")
        except Exception as e:
            # Best-effort: log the failure and continue with remaining images.
            print(f"Failed to save image {idx}: {e}")
71 |
72 |
def create_random_mask(size: Tuple[int, int]) -> Tuple[Image.Image, Tuple[int, int]]:
    """
    Create a random mask for i2i transformation.

    Randomly picks (50/50) either a rounded rectangle or a circle with an
    attempted feathered edge, drawn in white on a black RGB image.

    Args:
        size: (width, height) of the mask to create.

    Returns:
        Tuple of (mask image, (x, y)) — NOTE: for the rectangle branch
        (x, y) is the rectangle's top-left corner; for the circle branch
        it is the circle's center. Callers should be aware of this
        inconsistency.
    """
    w, h = size
    mask = Image.new("RGB", size, "black")

    if np.random.rand() < 0.5:
        # Rectangular mask with smoother edges
        width = np.random.randint(w // 4, w // 2)
        height = np.random.randint(h // 4, h // 2)

        # Center the rectangle with some random offset
        x = (w - width) // 2 + np.random.randint(-width // 4, width // 4)
        y = (h - height) // 2 + np.random.randint(-height // 4, height // 4)

        # Create mask with PIL draw for smoother edges
        draw = ImageDraw.Draw(mask)
        draw.rounded_rectangle(
            [x, y, x + width, y + height],
            radius=min(width, height) // 10,  # Smooth corners
            fill="white",
        )
    else:
        # Circular mask with feathered edges
        draw = ImageDraw.Draw(mask)
        x = w // 2
        y = h // 2

        # Make radius proportional to image size
        radius = min(w, h) // 4

        # Add small random offset to center
        x += np.random.randint(-radius // 4, radius // 4)
        y += np.random.randint(-radius // 4, radius // 4)

        # Draw multiple circles with decreasing opacity for feathered edge
        # NOTE(review): the mask is RGB, so the alpha component of the RGBA
        # fill is ignored — the "feathering" likely has no effect. Confirm
        # whether an "RGBA"/"L" mode mask was intended.
        for r in range(radius, radius - 10, -1):
            opacity = int(255 * (r - (radius - 10)) / 10)
            draw.ellipse([x - r, y - r, x + r, y + r], fill=(255, 255, 255, opacity))

    return mask, (x, y)
115 |
116 |
def is_black_output(
    modality: str, output: dict, threshold: int = 10
) -> bool:
    """
    Returns True if the image or frames are (almost) completely black.

    Args:
        modality: "image" or "video"; selects which entry of ``output`` to check.
        output: Mapping from modality to a generation pipeline result object
            (``.images`` for images, ``.frames`` for videos).
            NOTE(review): the previous annotation claimed
            Union[List[Image.Image], Image.Image], but the code indexes
            ``output[modality]`` — confirm against callers.
        threshold: Mean pixel intensity below which content counts as black.

    Returns:
        True if mean intensity is below threshold. Implicitly returns None
        for unrecognized modalities; an empty video frame list yields True.
    """
    if modality == "image":
        arr = np.array(output[modality].images[0])
        return np.mean(arr) < threshold
    elif modality == "video":
        # All frames must be near-black for a video to count as black.
        return np.all([np.mean(np.array(arr)) < threshold for arr in output[modality].frames[0]])
128 |
--------------------------------------------------------------------------------
/bitmind/generation/util/prompt.py:
--------------------------------------------------------------------------------
def get_tokenizer_with_min_len(model):
    """
    Return the pipeline tokenizer with the smallest maximum token length.

    Args:
        model: Single pipeline or dict of pipeline stages (uses "stage1").

    Returns:
        tuple: (tokenizer, max_token_length)
    """
    pipeline = model["stage1"] if isinstance(model, dict) else model

    primary = pipeline.tokenizer
    if not hasattr(pipeline, "tokenizer_2"):
        return primary, primary.model_max_length

    # Two tokenizers: pick whichever enforces the tighter limit.
    secondary = pipeline.tokenizer_2
    if secondary.model_max_length < primary.model_max_length:
        return secondary, secondary.model_max_length
    return primary, primary.model_max_length
25 |
26 |
def truncate_prompt_if_too_long(prompt: str, model):
    """
    Truncates the input string if it exceeds the maximum token length when tokenized.

    Args:
        prompt (str): The text prompt that may need to be truncated.
        model: Single pipeline or dict of pipeline stages; the tokenizer with
            the smallest max length is used as the limiting factor.

    Returns:
        str: The original prompt if within the token limit; otherwise, a
        truncated version decoded back from the clipped token ids.
    """
    tokenizer, max_token_len = get_tokenizer_with_min_len(model)
    tokens = tokenizer(prompt, verbose=False)  # Suppress token max exceeded warnings
    if len(tokens["input_ids"]) < max_token_len:
        return prompt

    # Clip to max_token_len - 1 tokens and decode back to a string.
    # (Removed a dead re-tokenization of the truncated prompt whose result
    # was never used.)
    truncated_prompt = tokenizer.decode(
        token_ids=tokens["input_ids"][: max_token_len - 1], skip_special_tokens=True
    )
    return truncated_prompt
48 |
--------------------------------------------------------------------------------
/bitmind/metagraph.py:
--------------------------------------------------------------------------------
1 | import time
2 | import asyncio
3 | from typing import Callable, List, Tuple
4 | import numpy as np
5 | import bittensor as bt
6 | from bittensor.utils.weight_utils import process_weights_for_netuid
7 |
8 | from bitmind.utils import fail_with_none
9 |
10 | import threading
11 |
12 |
13 | def get_miner_uids(
14 | metagraph: "bt.metagraph", self_uid: int, vpermit_tao_limit: int
15 | ) -> List[int]:
16 | available_uids = []
17 | for uid in range(int(metagraph.n.item())):
18 | if uid == self_uid:
19 | continue
20 |
21 | # Filter non serving axons.
22 | if not metagraph.axons[uid].is_serving:
23 | continue
24 | # Filter validator permit > 1024 stake.
25 | if metagraph.validator_permit[uid]:
26 | if metagraph.S[uid] > vpermit_tao_limit:
27 | continue
28 | available_uids.append(uid)
29 | continue
30 | return available_uids
31 |
32 |
def create_set_weights(version: int, netuid: int):
    """
    Build a ``set_weights`` closure bound to a weights version key and netuid.

    The returned function processes raw weights for the subnet and submits
    them on chain, retrying up to 3 times. Errors are swallowed (returning
    None) by the ``fail_with_none`` decorator.
    """

    @fail_with_none("Failed setting weights")
    def set_weights(
        wallet: "bt.wallet",
        metagraph: "bt.metagraph",
        subtensor: "bt.subtensor",
        weights: Tuple[List[int], List[float]],
    ):
        uids, raw_weights = weights
        if not len(uids):
            bt.logging.info("No UIDS to score")
            return

        # Normalize/clip the raw weights per subnet constraints before
        # submitting them on chain.
        processed_weight_uids, processed_weights = process_weights_for_netuid(
            uids=np.asarray(uids),
            weights=np.asarray(raw_weights),
            netuid=netuid,
            subtensor=subtensor,
            metagraph=metagraph,
        )

        bt.logging.info("Setting Weights: " + str(processed_weights))
        bt.logging.info("Weight Uids: " + str(processed_weight_uids))
        for _attempt in range(3):
            result, message = subtensor.set_weights(
                wallet=wallet,
                netuid=netuid,
                uids=processed_weight_uids,  # type: ignore
                weights=processed_weights,
                wait_for_finalization=False,
                wait_for_inclusion=False,
                version_key=version,
                max_retries=1,
            )
            if result is True:
                bt.logging.success("set_weights on chain successfully!")
                break
            bt.logging.error(f"set_weights failed {message}")
            time.sleep(15)

    return set_weights
79 |
80 |
def create_subscription_handler(substrate, callback: Callable):
    """
    Wrap an async ``callback`` for use with substrate block-header subscriptions.

    The returned handler fetches the full block for each header and, from the
    second update onward (update_nr >= 1), runs ``callback`` with the block
    number on a fresh event loop.

    Args:
        substrate: Substrate interface used to fetch full blocks.
        callback: Async callable invoked with the block number.

    Returns:
        A handler suitable for ``substrate.subscribe_block_headers``.
    """

    def inner(obj, update_nr, _):
        substrate.get_block(block_number=obj["header"]["number"])

        if update_nr >= 1:
            # Run the async callback on a dedicated loop; close the loop
            # afterwards so we don't leak one per block header.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(callback(obj["header"]["number"]))
            finally:
                loop.close()

    return inner
91 |
92 |
def start_subscription(substrate, callback: Callable):
    """Subscribe to block headers, dispatching each header to ``callback``."""
    handler = create_subscription_handler(substrate, callback)
    return substrate.subscribe_block_headers(handler)
97 |
98 |
def run_block_callback_thread(substrate, callback: Callable):
    """
    Start a daemon thread that subscribes to block headers and dispatches
    each new block to ``callback``.

    Args:
        substrate: Substrate interface to subscribe with.
        callback: Async callable invoked with each new block number.

    Returns:
        The started thread, or None if startup failed.
    """
    try:
        subscription_thread = threading.Thread(
            target=start_subscription, args=[substrate, callback], daemon=True
        )
        subscription_thread.start()
        bt.logging.info("Block subscription started in background thread.")
        return subscription_thread
    except Exception as e:
        # Fixed garbled log message (was "faaailuuure ...").
        bt.logging.error(f"Failed to start block subscription for {callback} - {e}")
109 |
--------------------------------------------------------------------------------
/bitmind/scoring/__init__.py:
--------------------------------------------------------------------------------
1 | from .eval_engine import EvalEngine
2 |
--------------------------------------------------------------------------------
/bitmind/scoring/miner_history.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from collections import deque
3 | import bittensor as bt
4 | import numpy as np
5 | import joblib
6 | import traceback
7 | import os
8 |
9 | from bitmind.types import Modality
10 |
11 |
class MinerHistory:
    """Tracks recent miner performance to facilitate reward computation.

    Keeps per-UID, per-modality deques of predictions and labels, plus a
    simple health flag per UID. Will be replaced with Redis in a future
    release.
    """

    # Bump when the pickled state layout changes.
    VERSION = 2

    def __init__(self, store_last_n_predictions: int = 100):
        # uid -> modality -> deque of prediction arrays (bounded length)
        self.predictions: Dict[int, Dict[Modality, deque]] = {}
        # uid -> modality -> deque of integer labels (parallel to predictions)
        self.labels: Dict[int, Dict[Modality, deque]] = {}
        # uid -> hotkey currently associated with that uid
        self.miner_hotkeys: Dict[int, str] = {}
        # uid -> 1 if last challenge succeeded, 0 if it errored
        # (fixed malformed annotation: was Dict[int: int])
        self.health: Dict[int, int] = {}
        self.store_last_n_predictions = store_last_n_predictions
        self.version = self.VERSION

    def update(
        self,
        uid: int,
        prediction: np.ndarray,
        error: str,
        label: int,
        modality: Modality,
        miner_hotkey: str,
    ):
        """Update the miner prediction history.

        Args:
            uid: Miner UID being updated.
            prediction: numpy array of shape (3,) containing probabilities for
                [real, synthetic, semi-synthetic]
            error: Error string; if truthy, the miner is marked unhealthy and
                no prediction is recorded.
            label: integer label (0 for real, 1 for synthetic, 2 for semi-synthetic)
            modality: Modality of the challenge (image/video).
            miner_hotkey: Hotkey observed for this UID; a change resets history.
        """
        # A hotkey change means a new miner took over this UID; start fresh.
        if uid not in self.miner_hotkeys or self.miner_hotkeys[uid] != miner_hotkey:
            self.reset_miner_history(uid, miner_hotkey)
            bt.logging.info(f"Reset history for {uid} {miner_hotkey}")

        if not error:
            self.predictions[uid][modality].append(np.array(prediction))
            self.labels[uid][modality].append(label)
            self.health[uid] = 1
        else:
            self.health[uid] = 0

    def _reset_predictions(self, uid: int):
        """Reinitialize the bounded prediction deques for a UID."""
        self.predictions[uid] = {
            Modality.IMAGE: deque(maxlen=self.store_last_n_predictions),
            Modality.VIDEO: deque(maxlen=self.store_last_n_predictions),
        }

    def _reset_labels(self, uid: int):
        """Reinitialize the bounded label deques for a UID."""
        self.labels[uid] = {
            Modality.IMAGE: deque(maxlen=self.store_last_n_predictions),
            Modality.VIDEO: deque(maxlen=self.store_last_n_predictions),
        }

    def reset_miner_history(self, uid: int, miner_hotkey: str):
        """Clear stored predictions/labels for a UID and record its hotkey."""
        self._reset_predictions(uid)
        self._reset_labels(uid)
        self.miner_hotkeys[uid] = miner_hotkey

    def get_prediction_count(self, uid: int) -> Dict[Modality, int]:
        """Return the count of valid stored predictions per modality.

        (Fixed return annotation: previously declared ``-> int`` while
        returning a dict.)
        """
        counts = {}
        for modality in [Modality.IMAGE, Modality.VIDEO]:
            counts[modality] = len(self.get_recent_predictions_and_labels(uid, modality)[0])
        return counts

    def get_recent_predictions_and_labels(self, uid, modality):
        """Return (predictions, labels) arrays for entries with valid predictions.

        Entries whose prediction is None, not array-like, or containing None
        elements are filtered out of both sequences.
        """
        if uid not in self.predictions or modality not in self.predictions[uid]:
            return [], []
        preds = self.predictions[uid][modality]
        # Use a set for O(1) membership tests (was an O(n^2) list scan).
        valid_indices = {
            i for i, p in enumerate(preds)
            if p is not None and (isinstance(p, (list, np.ndarray)) and not np.any(p == None))
        }
        valid_preds = np.array([
            p for i, p in enumerate(preds) if i in valid_indices
        ])
        labels_with_valid_preds = np.array([
            p for i, p in enumerate(self.labels[uid][modality]) if i in valid_indices
        ])
        return valid_preds, labels_with_valid_preds

    def get_healthy_miner_uids(self) -> list:
        """UIDs whose most recent challenge completed without error."""
        return [uid for uid, healthy in self.health.items() if healthy]

    def get_unhealthy_miner_uids(self) -> list:
        """UIDs whose most recent challenge errored."""
        return [uid for uid, healthy in self.health.items() if not healthy]

    def save_state(self, save_dir):
        """Serialize history to ``<save_dir>/history.pkl`` via joblib."""
        path = os.path.join(save_dir, "history.pkl")
        state = {
            "version": self.version,
            "store_last_n_predictions": self.store_last_n_predictions,
            "miner_hotkeys": self.miner_hotkeys,
            "predictions": self.predictions,
            "labels": self.labels,
            "health": self.health
        }
        joblib.dump(state, path)

    def load_state(self, save_dir):
        """Load history from ``<save_dir>/history.pkl``.

        Returns:
            True on success, False if the file is missing or unreadable.
        """
        path = os.path.join(save_dir, "history.pkl")
        if not os.path.isfile(path):
            bt.logging.warning(f"No saved state found at {path}")
            return False

        try:
            state = joblib.load(path)
            if state["version"] != self.VERSION:
                bt.logging.warning(
                    f"Loading state from different version: {state['version']} != {self.VERSION}"
                )

            # Fall back to current values for any missing keys.
            self.version = state.get("version", self.VERSION)
            self.store_last_n_predictions = state.get("store_last_n_predictions", self.store_last_n_predictions)
            self.miner_hotkeys = state.get("miner_hotkeys", self.miner_hotkeys)
            self.predictions = state.get("predictions", self.predictions)
            self.labels = state.get("labels", self.labels)
            self.health = state.get("health", self.health)

            if len(self.miner_hotkeys) == 0:
                bt.logging.warning("Loaded state has no miner hotkeys")
            if len(self.predictions) == 0:
                bt.logging.warning("Loaded state has no predictions")
            if len(self.labels) == 0:
                bt.logging.warning("Loaded state has no labels")
            if len(self.health) == 0:
                bt.logging.warning("Loaded state has no health records")

            bt.logging.debug(
                f"Successfully loaded history for {len(self.miner_hotkeys)} miners"
            )
            return True

        except Exception as e:
            bt.logging.error(f"Error deserializing MinerHistory state: {str(e)}")
            bt.logging.error(traceback.format_exc())
            return False
147 |
--------------------------------------------------------------------------------
/bitmind/types.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from dataclasses import dataclass, field
3 | from enum import Enum, auto
4 | from pydantic import BaseModel
5 | from typing import Dict, List, Any, Optional, Union
6 |
7 |
class NeuronType(Enum):
    """Role of a neuron participating in the subnet."""

    VALIDATOR = "VALIDATOR"
    VALIDATOR_PROXY = "VALIDATOR_PROXY"
    MINER = "MINER"
12 |
13 |
class FileType(Enum):
    """Kind of file handled by the cache/download machinery."""

    PARQUET = auto()
    ZIP = auto()
    VIDEO = auto()
    IMAGE = auto()
19 |
20 |
class CacheType(str, Enum):
    """Cache tier: extracted media files vs. compressed source archives."""

    MEDIA = "media"
    COMPRESSED = "compressed"
24 |
25 |
class Modality(str, Enum):
    """Media modality of a challenge or dataset."""

    IMAGE = "image"
    VIDEO = "video"
29 |
30 |
class MediaType(str, Enum):
    """Media provenance label.

    Each member carries a string value (the enum value) and a paired integer
    class index exposed as ``int_value`` (e.g. for numeric labels).
    """

    REAL = "real", 0
    SYNTHETIC = "synthetic", 1
    SEMISYNTHETIC = "semisynthetic", 2

    def __new__(cls, str_value, int_value):
        # Canonical enum value is the string; attach the integer separately.
        obj = str.__new__(cls, str_value)
        obj._value_ = str_value
        obj.int_value = int_value
        return obj
41 |
42 |
@dataclass
class CacheUpdaterConfig:
    """Knobs for how much a cache updater pulls per refresh."""

    # Number of source files sampled from each dataset per update
    num_sources_per_dataset: int = 1
    # Number of media items extracted from each sampled source
    num_items_per_source: int = 100
47 |
48 |
@dataclass
class CacheConfig:
    """Configuration for a cache at base_dir / {modality} / {media_type}"""

    modality: str  # e.g. "image" or "video"
    media_type: str  # e.g. "real", "synthetic", "semisynthetic"
    # NOTE: expanduser() is evaluated once at class-definition time
    base_dir: Path = Path("~/.cache/sn34").expanduser()
    tags: Optional[List[str]] = None
    max_compressed_gb: float = 100.0  # size budget for compressed archives
    max_media_gb: float = 10.0  # size budget for extracted media

    def get_path(self):
        """Return base_dir/modality/media_type, creating it if missing."""
        media_cache_path = Path(self.base_dir) / self.modality / self.media_type
        media_cache_path.mkdir(exist_ok=True, parents=True)
        return media_cache_path
64 |
65 |
@dataclass
class DatasetConfig:
    """Descriptor for a source dataset consumed by the cache system."""

    path: str  # HuggingFace path
    type: Modality
    media_type: MediaType
    tags: List[str] = field(default_factory=list)
    file_format: str = ""
    compressed_format: str = ""
    priority: int = 1  # Optional: priority for sampling (higher is more frequent)
    enabled: bool = True

    def __post_init__(self):
        """Validate and set defaults"""
        # Default archive format is modality-dependent.
        if not self.compressed_format:
            if self.type == Modality.IMAGE:
                self.compressed_format = "parquet"
            elif self.type == Modality.VIDEO:
                self.compressed_format = "zip"

        # Accept a comma-separated string of tags as well as a list.
        if isinstance(self.tags, str):
            self.tags = [t.strip() for t in self.tags.split(",")]

        # Coerce string values into their enum equivalents.
        if isinstance(self.type, str):
            self.type = Modality(self.type.lower())

        if isinstance(self.media_type, str):
            self.media_type = MediaType(self.media_type.lower())
93 |
94 |
class ModelTask(str, Enum):
    """Type of task the model is designed for.

    Values are the short task codes used throughout the codebase.
    """

    TEXT_TO_IMAGE = "t2i"
    TEXT_TO_VIDEO = "t2v"
    IMAGE_TO_IMAGE = "i2i"
    IMAGE_TO_VIDEO = "i2v"
102 |
103 |
class ModelConfig:
    """
    Configuration for a generative AI model.

    Attributes:
        path: The Hugging Face model path or identifier
        task: The primary task of the model (T2I, T2V, I2I, I2V)
        media_type: Type of output (synthetic or semisynthetic); inferred
            from the task when not given
        pipeline_cls: Pipeline class used to load the model (or a dict of
            classes for multi-stage pipelines)
        pretrained_args: Arguments for the from_pretrained method
        generate_args: Default arguments for generation
        tags: List of tags for categorizing the model
        use_autocast: Whether to use autocast during generation
        enable_model_cpu_offload: Offload whole model to CPU between uses
        enable_sequential_cpu_offload: Offload submodules sequentially
        vae_enable_slicing: Enable VAE slicing to reduce memory
        vae_enable_tiling: Enable VAE tiling to reduce memory
        scheduler: Optional scheduler configuration dict
        save_args: Arguments used when saving generated media
        pipeline_stages: Stage configs for multi-stage pipelines
        clear_memory_on_stage_end: Free memory between pipeline stages
        lora_model_id: Optional LoRA weights identifier
        lora_loading_args: Optional args for loading the LoRA weights
    """

    def __init__(
        self,
        path: str,
        task: ModelTask,
        pipeline_cls: Union[Any, Dict[str, Any]],
        media_type: Optional[MediaType] = None,
        pretrained_args: Dict[str, Any] = None,
        generate_args: Dict[str, Any] = None,
        tags: List[str] = None,
        use_autocast: bool = True,
        enable_model_cpu_offload: bool = False,
        enable_sequential_cpu_offload: bool = False,
        vae_enable_slicing: bool = False,
        vae_enable_tiling: bool = False,
        scheduler: Dict[str, Any] = None,
        save_args: Dict[str, Any] = None,
        pipeline_stages: List[Dict[str, Any]] = None,
        clear_memory_on_stage_end: bool = False,
        lora_model_id: str = None,
        lora_loading_args: Dict[str, Any] = None,
    ):
        self.path = path
        self.task = task
        self.pipeline_cls = pipeline_cls
        self.media_type = media_type

        # Default: i2i models produce semisynthetic output (they edit real
        # media); everything else is fully synthetic.
        if self.media_type is None:
            self.media_type = (
                MediaType.SEMISYNTHETIC
                if task == ModelTask.IMAGE_TO_IMAGE
                else MediaType.SYNTHETIC
            )

        # Normalize optional mutable arguments to fresh empty containers.
        self.pretrained_args = pretrained_args or {}
        self.generate_args = generate_args or {}
        self.tags = tags or []
        self.use_autocast = use_autocast
        self.enable_model_cpu_offload = enable_model_cpu_offload
        self.enable_sequential_cpu_offload = enable_sequential_cpu_offload
        self.vae_enable_slicing = vae_enable_slicing
        self.vae_enable_tiling = vae_enable_tiling
        self.scheduler = scheduler
        self.save_args = save_args or {}
        self.pipeline_stages = pipeline_stages
        self.clear_memory_on_stage_end = clear_memory_on_stage_end
        self.lora_model_id = lora_model_id
        self.lora_loading_args = lora_loading_args

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary format.

        NOTE(review): path/task/media_type and the lora_* fields are not
        included here — confirm this omission is intentional for consumers
        of this dict.
        """
        return {
            "pipeline_cls": self.pipeline_cls,
            "from_pretrained_args": self.pretrained_args,
            "generate_args": self.generate_args,
            "use_autocast": self.use_autocast,
            "enable_model_cpu_offload": self.enable_model_cpu_offload,
            "enable_sequential_cpu_offload": self.enable_sequential_cpu_offload,
            "vae_enable_slicing": self.vae_enable_slicing,
            "vae_enable_tiling": self.vae_enable_tiling,
            "scheduler": self.scheduler,
            "save_args": self.save_args,
            "pipeline_stages": self.pipeline_stages,
            "clear_memory_on_stage_end": self.clear_memory_on_stage_end,
        }
186 |
187 |
class ValidatorConfig(BaseModel):
    """Runtime toggles for validator behavior."""

    # Skip setting weights on chain entirely
    skip_weight_set: Optional[bool] = False
    # Set weights immediately at startup instead of waiting for the interval
    set_weights_on_start: Optional[bool] = False
    # Maximum number of organic requests handled concurrently
    max_concurrent_organics: Optional[int] = 2
192 |
--------------------------------------------------------------------------------
/bitmind/utils.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | import bittensor as bt
3 | import functools
4 | import json
5 | import os
6 |
7 |
def print_info(metagraph, hotkey, block, isMiner=True):
    """Log a one-line status summary for the neuron identified by ``hotkey``."""
    uid = metagraph.hotkeys.index(hotkey)
    prefix = f"UID:{uid} | Block:{block} | Consensus:{metagraph.C[uid]} | "
    if isMiner:
        bt.logging.info(
            prefix
            + f"Stake:{metagraph.S[uid]} | Trust:{metagraph.T[uid]} | Incentive:{metagraph.I[uid]} | Emission:{metagraph.E[uid]}"
        )
    else:
        bt.logging.info(prefix + f"VTrust:{metagraph.Tv[uid]} | ")
18 |
19 |
def fail_with_none(message: str = ""):
    """
    Decorator factory: wrap a function so any exception is logged (with the
    given message and a traceback) and None is returned instead of raising.

    Args:
        message: Context string logged before the exception details.
    """

    def outer(func):
        # functools.wraps preserves the wrapped function's name/docstring
        # (was missing, which broke introspection of decorated functions).
        @functools.wraps(func)
        def inner(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                bt.logging.error(message)
                bt.logging.error(str(e))
                bt.logging.error(traceback.format_exc())
                return None

        return inner

    return outer
34 |
35 |
def on_block_interval(interval_attr_name):
    """
    Decorator for async methods that should only execute at specific block intervals.

    The wrapped method runs when ``block`` is 0 (initialization) or a multiple
    of the interval read from ``self.config``; otherwise it is skipped and
    None is returned.

    Args:
        interval_attr_name: String name of the config attribute that specifies the interval
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(self, block, *args, **kwargs):
            interval = getattr(self.config, interval_attr_name)
            if interval is None:
                bt.logging.error(f"No interval found for {interval_attr_name}")
                # Without an interval only the block-0 initialization call can
                # proceed; bail out otherwise instead of crashing on
                # `block % None` (previously raised TypeError).
                if block != 0:
                    return None
            if (
                block == 0 or block % interval == 0
            ):  # Allow execution on block 0 for initialization
                return await func(self, block, *args, **kwargs)
            return None

        return wrapper

    return decorator
59 |
60 |
class ExitContext:
    """
    Shared, truthiness-testable exit flag.

    Using a class instance (rather than a bare bool) lets the same flag be
    passed to and observed from other threads. Intended for use as a signal
    handler: the first call flips the flag, a repeated call hard-exits.
    """

    isExiting: bool = False

    def startExit(self, *_):
        # A second exit request while already shutting down terminates now.
        if self.isExiting:
            exit()
        self.isExiting = True

    def __bool__(self):
        return self.isExiting
75 |
76 |
def get_metadata(media_path):
    """
    Load the sidecar JSON metadata for a media file, if present.

    Looks for ``<media_path minus extension>.json`` next to the media file.
    Returns the parsed contents, or an empty dict when the sidecar is
    missing or cannot be parsed.
    """
    json_path = os.path.splitext(media_path)[0] + ".json"

    if not os.path.exists(json_path):
        return {}

    try:
        with open(json_path, "r") as f:
            return json.load(f)
    except json.JSONDecodeError:
        bt.logging.error(f"Warning: Could not parse JSON file: {json_path}")
        return {}
90 |
91 |
def get_file_modality(filepath: str) -> str:
    """
    Classify a media file by its extension.

    Args:
        filepath: Path to the media file

    Returns:
        "image", "video", or "file" for anything unrecognized
    """
    image_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"}
    video_exts = {".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"}

    ext = os.path.splitext(filepath)[1].lower()
    if ext in image_exts:
        return "image"
    if ext in video_exts:
        return "video"
    return "file"
109 |
--------------------------------------------------------------------------------
/docs/Incentive.md:
--------------------------------------------------------------------------------
1 | # Bitmind Subnet Incentive Mechanism
2 |
3 | This document covers the current state of SN34's incentive mechanism.
4 | 1. [Overview](#overview)
5 | 2. [Rewards](#rewards)
6 | 3. [Scores](#scores)
7 | 4. [Weights](#weights)
5. [Incentives](#incentives)
9 |
10 | ## TLDR
11 |
12 | Miner rewards are a weighted combination of their performance on video and image detection challenges. Validators keep track of miner performance using a score vector, which is updated using an exponential moving average. These scores are used by validators to set weights for miners, which determine their reward distribution, incentivizing high-quality predictions and consistent performance.
13 |
14 |
15 | ## Rewards
>A miner's total reward $C$ combines their performance across both image and video challenges, weighted by configurable parameters $p_m$ that control the emphasis placed on each modality.
17 |
18 | $$
19 | C = \sum_{m \in \{image, video\}} p_m \sum_{k \in \{b,m\}} w_k MCC_k
20 | $$
21 |
22 | The reward for each modality $m$ is a weighted combination of binary and multiclass ($b$ and $m$) Matthews Correlation Coefficient (MCC) scores. The weights $w_k$ allow emphasis to be shifted as needed between the binary distinction between synthetic and authentic, and the more granular separation of fully- and semi-synthetic content.
23 |
24 |
25 | ## Scores
26 |
27 | >Validators set weights based on historical miner performances, tracked by their score vector.
28 |
29 | For each challenge *t*, a validator will randomly sample 50 miners, send them an image/video, and compute their rewards *C* as described above. These reward values are then used to update the validator's score vector *V* using an exponential moving average (EMA) with *α* = 0.02.
30 |
31 | $$
32 | V_t = 0.02 \cdot C_t + 0.98 \cdot V_{t-1}
33 | $$
34 |
35 | A low *α* value places emphasis on a miner's historical performance, adding additional smoothing to avoid having a single prediction cause significant score fluctuations.
36 |
37 |
38 | ## Weights
39 |
40 | > Validators set weights around once per tempo (360 blocks) by sending a normalized score vector to the Bittensor blockchain (in `UINT16` representation).
41 |
42 | Weight normalization by L1 norm:
43 |
44 | $$w = \frac{\text{V}}{\lVert\text{V}\rVert_1}$$
45 |
46 |
47 | ## Incentives
48 | > The [Yuma Consensus algorithm](https://docs.bittensor.com/yuma-consensus) translates the weight matrix *W* into incentives for the subnet miners and dividends for the subnet validators
49 |
50 | Specifically, for each miner *j*, incentive is a function of rank *R*:
51 |
52 | $$I_j = \frac{R_j}{\sum_k R_k}$$
53 |
54 | where rank *R* is *W* (a matrix of validator weight vectors) weighted by validator stake vector *S*.
55 |
56 | $$R_k = \sum_i S_i \cdot W_{ik}$$
57 |
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/docs/Mining.md:
--------------------------------------------------------------------------------
1 | # Miner Setup Guide
2 |
3 | ## Before you proceed ⚠️
4 |
5 | If you are new to Bittensor, we recommend familiarizing yourself with the basics in the [Bittensor Docs](https://docs.bittensor.com/) before proceeding.
6 |
7 | **Run your own local subtensor** to avoid rate limits set on public endpoints. See [Run a Subtensor Node Locally](https://github.com/opentensor/subtensor/blob/main/docs/running-subtensor-locally.md#compiling-your-own-binary) for setup instructions.
8 |
9 | **Understand your minimum compute requirements** for model training and miner deployment, which varies depending on your choice of model. You will likely need at least a consumer grade GPU for training. Many models can be deployed in CPU-only environments for mining.
10 |
11 |
12 | ## Installation
13 |
14 | Download the repository and navigate to the folder.
15 | ```bash
16 | git clone https://github.com/bitmind-ai/bitmind-subnet.git && cd bitmind-subnet
17 | ```
18 |
19 | We recommend using a Conda virtual environment to install the necessary Python packages.
20 | - You can set up Conda with this [quick command-line install](https://docs.anaconda.com/free/miniconda/#quick-command-line-install).
21 | - Note that after you run the last commands in the miniconda setup process, you'll be prompted to start a new shell session to complete the initialization.
22 |
23 | With miniconda installed, you can create your virtual environment with this command:
24 |
25 | ```bash
26 | conda create -y -n bitmind python=3.10
27 | ```
28 |
29 | - Activating your virtual environment: `conda activate bitmind`
30 | - Deactivating your virtual environment `conda deactivate`
31 |
32 | Install the remaining necessary requirements with the following chained command.
33 | ```bash
34 | conda activate bitmind
35 | export PIP_NO_CACHE_DIR=1
36 | chmod +x setup.sh
37 | ./setup.sh
38 | ```
39 |
40 | Before you register a miner on testnet or mainnet, you must first fill out all the necessary fields in `.env.miner`. Make a copy of the template, and fill in your wallet and axon information.
41 |
42 | ```
43 | cp .env.miner.template .env.miner
44 | ```
45 |
46 |
47 | ## Miner Task
48 |
49 | ### Expected Miner Outputs
50 |
51 | > Miners respond to validator queries with a probability vector [$p_{real}$, $p_{synthetic}$, $p_{semisynthetic}$]
52 |
53 | Your task as a SN34 miner is to classify images and videos as real, synthetic, or semisynthetic.
54 | - **Real**: Authentic media, not touched in any way by AI
55 | - **Synthetic**: Fully AI-generated media
56 | - **Semisynthetic**: AI-modified (spatially, not temporally) media. E.g. faceswaps, inpainting, etc.
57 |
58 | Minor details:
59 | - You are scored only on correctness, so rounding these probabilities will not give you extra incentive.
60 | - To maximize incentive, you must respond with the multiclass vector described above.
61 | - If your classifier returns a binary response (e.g. a float in $[0., 1.]$ or a vector [$p_{real}$, $p_{synthetic}$]), you will earn partial credit (as defined by our incentive mechanism)
62 |
63 |
64 | ### Training your Detector
65 |
66 | > [!IMPORTANT]
67 | > The default video and image detection models provided in `neurons/miner.py` serve only to exemplify the desired behavior of the miner neuron, and will not provide competitive performance on mainnet.
68 |
69 | #### Model
70 |
71 | #### Data
72 |
73 |
74 | ## Registration
75 |
76 | To run a miner, you must have a registered hotkey.
77 |
78 | > [!IMPORTANT]
79 | > Registering on a Bittensor subnet burns TAO. To reduce the risk of deregistration due to technical issues or a poor performing model, we recommend the following:
80 | > 1. Test your miner on testnet before you start mining on mainnet.
81 | > 2. Before registering your hotkey on mainnet, make sure your axon port is accepting incoming traffic by running `curl your_ip:your_port`
82 |
83 |
84 | #### Mainnet
85 |
86 | ```bash
87 | btcli s register --netuid 34 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network finney
88 | ```
89 |
90 | #### Testnet
91 |
92 | > For testnet tao, you can make requests in the [Bittensor Discord's "Requests for Testnet Tao" channel](https://discord.com/channels/799672011265015819/1190048018184011867)
93 |
94 |
95 | ```bash
96 | btcli s register --netuid 168 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network test
97 | ```
98 |
99 | #### Mining
100 |
101 | You can now launch your miner with `start_miner.sh`, which will use the configuration you provided in `.env.miner` (see the last step of the [Installation](#installation) section).
102 |
--------------------------------------------------------------------------------
/docs/Validating.md:
--------------------------------------------------------------------------------
1 | # Validator Guide
2 |
3 | ## Before you proceed ⚠️
4 |
5 | If you are new to Bittensor (you're probably not if you're reading the validator guide 😎), we recommend familiarizing yourself with the basics in the [Bittensor Docs](https://docs.bittensor.com/) before proceeding.
6 |
7 | **Run your own local subtensor** to avoid rate limits set on public endpoints. See [Run a Subtensor Node Locally](https://github.com/opentensor/subtensor/blob/main/docs/running-subtensor-locally.md#compiling-your-own-binary) for setup instructions.
8 |
9 | **Understand the minimum compute requirements to run a validator**. Validator neurons on SN34 run a suite of generative (text-to-image, text-to-video, etc.) models that require an **80GB VRAM GPU**. They also maintain a large cache of real and synthetic media to ensure diverse, locally available data for challenging miners. We recommend **1TB of storage**. For more details, please see our [minimum compute documentation](../min_compute.yml)
10 |
11 | ## Required Hugging Face Model Access
12 |
13 | To properly validate, you must gain access to several Hugging Face models used by the subnet. This requires logging in to your Hugging Face account and accepting the terms for each model below:
14 |
15 | - [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
16 | - [DeepFloyd IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)
17 | - [DeepFloyd IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)
18 |
19 | > **Note:** Accepting the terms for any one of the DeepFloyd IF models (e.g., IF-II-L or IF-I-XL) will grant you access to all DeepFloyd IF models.
20 | >
21 | > **If you've been validating with us for a while (prior to V3), you've likely already gotten access to these models and can disregard this step.**
22 |
23 | To do this:
24 | 1. Log in to your Hugging Face account.
25 | 2. Visit each model page above.
26 | 3. Click the "Access repository" or "Agree and access repository" button to accept the terms.
27 |
28 | ## Installation
29 |
30 | Download the repository and navigate to the folder.
31 | ```bash
32 | git clone https://github.com/bitmind-ai/bitmind-subnet.git && cd bitmind-subnet
33 | ```
34 |
35 | We recommend using a Conda virtual environment to install the necessary Python packages.
36 | - You can set up Conda with this [quick command-line install](https://www.anaconda.com/docs/getting-started/miniconda/install#linux).
37 | - Note that after you run the last commands in the miniconda setup process, you'll be prompted to start a new shell session to complete the initialization.
38 |
39 | With miniconda installed, you can create your virtual environment with this command:
40 |
41 | ```bash
42 | conda create -y -n bitmind python=3.10
43 | ```
44 |
45 | - Activating your virtual environment: `conda activate bitmind`
46 | - Deactivating your virtual environment `conda deactivate`
47 |
48 | Install the remaining necessary requirements with the following chained command.
49 | ```bash
50 | conda activate bitmind
51 | export PIP_NO_CACHE_DIR=1
52 | chmod +x setup.sh
53 | ./setup.sh
54 | ```
55 |
56 | Before you register, you should first fill out all the necessary fields in `.env.validator`. Make a copy of the template, and fill in your wallet information.
57 |
58 | ```
59 | cp .env.validator.template .env.validator
60 | ```
61 |
62 | ## Registration
63 |
64 | To validate on our subnet, you must have a registered hotkey.
65 |
66 | #### Mainnet
67 |
68 | ```bash
69 | btcli s register --netuid 34 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network finney
70 | ```
71 |
72 | #### Testnet
73 |
74 | ```bash
75 | btcli s register --netuid 168 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network test
76 | ```
77 |
78 |
79 | ## Validating
80 |
81 | Before starting your validator, please ensure you've populated the empty fields in `.env.validator`, including `WANDB_API_KEY` and `HUGGING_FACE_TOKEN`.
82 |
83 | If you haven't already, you can start by copying the template,
84 | ```
85 | cp .env.validator.template .env.validator
86 | ```
87 |
88 | If you don't have a W&B API key, please reach out to the BitMind team via Discord and we can provide one.
89 |
90 | Now you're ready to run your validator!
91 |
92 | ```bash
93 | conda activate bitmind
94 | ./start_validator.sh
95 | ```
96 |
97 | - Auto updates are enabled by default. To disable, run with `--no-auto-updates`.
98 | - Self-healing restarts are enabled by default (every 6 hours). To disable, run with `--no-self-heal`.
99 |
100 |
101 | The above command will kick off 3 `pm2` processes
102 | ```
103 | ┌────┬───────────────────┬─────────────┬─────────┬─────────┬──────────┬────────┬──────┬───────────┬──────────┬──────────┬──────────┬──────────┐
104 | │ id │ name │ namespace │ version │ mode │ pid │ uptime │ ↺ │ status │ cpu │ mem │ user │ watching │
105 | ├────┼───────────────────┼─────────────┼─────────┼─────────┼──────────┼────────┼──────┼───────────┼──────────┼──────────┼──────────┼──────────┤
106 | │ 0 │ sn34-generator │ default │ N/A │ fork │ 2397505 │ 38m │ 2 │ online │ 100% │ 3.0gb │ user │ disabled │
107 | │ 2 │ sn34-proxy │ default │ N/A │ fork │ 2398000 │ 27m │ 1 │ online │ 0% │ 695.2mb │ user │ disabled │
108 | │ 1 │ sn34-validator │ default │ N/A │ fork │ 2394939 │ 108m │ 0 │ online │ 0% │ 5.8gb │ user │ disabled │
109 | └────┴───────────────────┴─────────────┴─────────┴─────────┴──────────┴────────┴──────┴───────────┴──────────┴──────────┴──────────┴──────────┘
110 | ```
111 | - `sn34-validator` is the validator process
112 | - `sn34-generator` runs our data generation pipeline to produce **synthetic images and videos** (stored in `~/.cache/sn34`)
113 | - `sn34-proxy` routes organic traffic from our applications to miners.
114 |
--------------------------------------------------------------------------------
/docs/static/Bitmind-Logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/docs/static/Bitmind-Logo.png
--------------------------------------------------------------------------------
/docs/static/Join-BitMind-Discord.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/docs/static/Join-BitMind-Discord.png
--------------------------------------------------------------------------------
/docs/static/Subnet-Arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/docs/static/Subnet-Arch.png
--------------------------------------------------------------------------------
/min_compute.yml:
--------------------------------------------------------------------------------
1 | # NOTE FOR MINERS:
2 | # Miner min compute varies based on selected model architecture.
3 | # For model training, you will most likely need a GPU. For miner deployment, depending
4 | # on your model, you may be able to get away with CPU.
5 |
6 | version: '3.0.0'
7 |
8 | compute_spec:
9 |
10 | validator:
11 |
12 | cpu:
13 | min_cores: 4 # Minimum number of CPU cores
14 | min_speed: 2.5 # Minimum speed per core (GHz)
15 | recommended_cores: 8 # Recommended number of CPU cores
16 | recommended_speed: 3.5 # Recommended speed per core (GHz)
17 | architecture: "x86_64" # Architecture type (e.g., x86_64, arm64)
18 |
19 | gpu:
20 | required: True # Does the application require a GPU?
21 | min_vram: 80 # Minimum GPU VRAM (GB)
22 | recommended_vram: 80 # Recommended GPU VRAM (GB)
23 | min_compute_capability: 8.0 # Minimum CUDA compute capability
24 | recommended_compute_capability: 8.0 # Recommended CUDA compute capability
25 | recommended_gpu: "NVIDIA A100 80GB PCIE" # Recommended GPU to purchase/rent
26 | fp64: 9.7 # TFLOPS
27 | fp64_tensor_core: 19.5 # TFLOPS
28 | fp32: 19.5 # TFLOPS
29 | tf32: 156 # TFLOPS*
30 | bfloat16_tensor_core: 312 # TFLOPS*
31 | int8_tensor_core: 624 # TOPS*
32 |
33 | # See NVIDIA A100 datasheet for details:
34 | # https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/
35 | # nvidia-a100-datasheet-nvidia-us-2188504-web.pdf
36 |
37 | # *double with sparsity
38 |
39 | memory:
40 | min_ram: 32 # Minimum RAM (GB)
41 | min_swap: 4 # Minimum swap space (GB)
42 | recommended_swap: 8 # Recommended swap space (GB)
43 |       ram_type: "DDR5"                      # RAM type (e.g., DDR5, DDR4, etc.); "DDR6" was listed previously but does not exist
44 |
45 | storage:
46 | min_space: 1000 # Minimum free storage space (GB)
47 | recommended_space: 1000 # Recommended free storage space (GB)
48 | type: "SSD" # Preferred storage type (e.g., SSD, HDD)
49 | min_iops: 1000 # Minimum I/O operations per second (if applicable)
50 | recommended_iops: 5000 # Recommended I/O operations per second
51 |
52 | os:
53 | name: "Ubuntu" # Name of the preferred operating system(s)
54 | version: 22.04 # Version of the preferred operating system(s)
55 |
56 | network_spec:
57 | bandwidth:
58 | download: 100 # Minimum download bandwidth (Mbps)
59 | upload: 20 # Minimum upload bandwidth (Mbps)
60 |
--------------------------------------------------------------------------------
/neurons/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/neurons/__init__.py
--------------------------------------------------------------------------------
/neurons/base.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from threading import Thread
3 | from typing import Callable, List
4 | import bittensor as bt
5 | import copy
6 | import inspect
7 | import traceback
8 |
9 | from bittensor.core.settings import SS58_FORMAT, TYPE_REGISTRY
10 | from nest_asyncio import asyncio
11 | from substrateinterface import SubstrateInterface
12 | import signal
13 |
14 | from bitmind import (
15 | __spec_version__ as spec_version,
16 | )
17 | from bitmind.metagraph import run_block_callback_thread
18 | from bitmind.types import NeuronType
19 | from bitmind.utils import ExitContext, on_block_interval
20 | from bitmind.config import (
21 | add_args,
22 | add_validator_args,
23 | add_miner_args,
24 | add_proxy_args,
25 | validate_config_and_neuron_path,
26 | )
27 |
28 |
class BaseNeuron:
    """Common bootstrap logic shared by all SN34 neuron types.

    Parses CLI/config, initializes the wallet, subtensor, and metagraph,
    verifies hotkey registration, installs shutdown signal handlers, and
    starts a background thread that invokes registered callbacks on each
    new block.
    """

    config: "bt.config"
    neuron_type: NeuronType
    # Declared here for type visibility only; instantiated per instance
    # in __init__. FIX: these were previously mutable CLASS attributes
    # (`exit_context = ExitContext()`, `block_callbacks = []`) shared by
    # every instance and subclass, so callbacks registered by one neuron
    # leaked into all others.
    exit_context: ExitContext
    block_callbacks: List[Callable]
    next_sync_block = None
    substrate_thread: Thread

    def check_registered(self):
        """Exit the process if this wallet's hotkey is not registered on the configured netuid."""
        if not self.subtensor.is_hotkey_registered(
            netuid=self.config.netuid,
            hotkey_ss58=self.wallet.hotkey.ss58_address,
        ):
            bt.logging.error(
                f"Wallet: {self.wallet} is not registered on netuid {self.config.netuid}."
                f" Please register the hotkey using `btcli subnets register` before trying again"
            )
            exit()

    @on_block_interval("epoch_length")
    async def maybe_sync_metagraph(self, block):
        """Resync the metagraph every `epoch_length` blocks.

        Validators additionally re-sync their scoring state (hotkeys and
        moving averages) to the refreshed metagraph.
        """
        self.check_registered()
        bt.logging.info("Resyncing Metagraph")
        self.metagraph.sync(subtensor=self.subtensor)

        if self.neuron_type == NeuronType.VALIDATOR:
            bt.logging.info("Metagraph updated, re-syncing hotkeys and moving averages")
            self.eval_engine.sync_to_metagraph()

    async def run_callbacks(self, block):
        """Invoke every registered block callback, awaiting coroutines.

        Skipped entirely while a subclass reports initialization in
        progress. A failing callback is logged but does not prevent the
        remaining callbacks from running.
        """
        if (
            hasattr(self, "initialization_complete")
            and not self.initialization_complete
        ):
            bt.logging.debug(
                f"Skipping callbacks at block {block} during initialization"
            )
            return

        for callback in self.block_callbacks:
            try:
                res = callback(block)
                if inspect.isawaitable(res):
                    await res
            except Exception as e:
                bt.logging.error(
                    f"Failed running callback {callback.__name__}: {str(e)}"
                )
                bt.logging.error(traceback.format_exc())

    def __init__(self, config=None):
        bt.logging.info(
            f"Bittensor Version: {bt.__version__} | SN34 Version {spec_version}"
        )

        # Per-instance state (see FIX note on the class attributes above).
        self.exit_context = ExitContext()
        self.block_callbacks = []

        parser = argparse.ArgumentParser()
        bt.wallet.add_args(parser)
        bt.subtensor.add_args(parser)
        bt.logging.add_args(parser)
        add_args(parser)

        # Each neuron type contributes its own CLI arguments.
        if self.neuron_type == NeuronType.VALIDATOR:
            bt.axon.add_args(parser)
            add_validator_args(parser)
        if self.neuron_type == NeuronType.VALIDATOR_PROXY:
            add_validator_args(parser)
            add_proxy_args(parser)
        if self.neuron_type == NeuronType.MINER:
            bt.axon.add_args(parser)
            add_miner_args(parser)

        self.config = bt.config(parser)
        if config:
            # An explicitly supplied config overrides parsed CLI values.
            base_config = copy.deepcopy(config)
            self.config.merge(base_config)

        validate_config_and_neuron_path(self.config)

        ## Add kill signals
        signal.signal(signal.SIGINT, self.exit_context.startExit)
        signal.signal(signal.SIGTERM, self.exit_context.startExit)

        ## LOGGING
        bt.logging(config=self.config, logging_dir=self.config.neuron.full_path)
        bt.logging.set_info()
        if self.config.logging.debug:
            bt.logging.set_debug(True)
        if self.config.logging.trace:
            bt.logging.set_trace(True)

        ## BITTENSOR INITIALIZATION
        bt.logging.success(self.config)
        self.wallet = bt.wallet(config=self.config)
        self.subtensor = bt.subtensor(
            config=self.config, network=self.config.subtensor.chain_endpoint
        )
        self.metagraph = self.subtensor.metagraph(self.config.netuid)

        self.loop = asyncio.get_event_loop()
        bt.logging.debug(f"Wallet: {self.wallet}")
        bt.logging.debug(f"Subtensor: {self.subtensor}")
        bt.logging.debug(f"Metagraph: {self.metagraph}")

        ## CHECK IF REGG'D
        self.check_registered()
        self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address)

        ## Substrate, Subtensor and Metagraph
        self.substrate = SubstrateInterface(
            ss58_format=SS58_FORMAT,
            use_remote_preset=True,
            url=self.config.subtensor.chain_endpoint,
            type_registry=TYPE_REGISTRY,
        )

        # Periodic metagraph resync rides on the block-callback thread.
        self.block_callbacks.append(self.maybe_sync_metagraph)
        self.substrate_thread = run_block_callback_thread(
            self.substrate, self.run_callbacks
        )
148 |
--------------------------------------------------------------------------------
/neurons/generator.py:
--------------------------------------------------------------------------------
import asyncio
import sys
import io
import time
import signal
import traceback
import argparse
from pathlib import Path
from PIL import Image
from typing import List, Dict, Any
import os
import atexit

# Quiet noisy ML libraries. These environment variables must be set
# before transformers / HF hub are imported to take effect, hence the
# unconventional placement between import groups.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

import warnings

# Suppress known-noisy FutureWarnings from these modules only.
for module in ["diffusers", "transformers.tokenization_utils_base"]:
    warnings.filterwarnings("ignore", category=FutureWarning, module=module)

import logging

# Raise log thresholds for chatty third-party loggers.
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("datasets").setLevel(logging.ERROR)

import transformers

transformers.logging.set_verbosity_error()

import bittensor as bt
from bitmind.config import add_args, add_data_generator_args
from bitmind.utils import ExitContext, get_metadata
from bitmind.wandb_utils import init_wandb, clean_wandb_cache
from bitmind.types import CacheConfig, MediaType, Modality
from bitmind.cache.sampler import ImageSampler
from bitmind.generation import (
    GenerationPipeline,
    initialize_model_registry,
)
46 |
47 |
class Generator:
    """Standalone process that continuously generates synthetic media.

    Samples real images from the local cache, feeds them through the
    generation pipeline in batches, and optionally logs runs to W&B.
    Runs under pm2 alongside the validator (see docs/Validating.md).
    """

    def __init__(self):
        self.exit_context = ExitContext()
        self.task = None
        self.generation_pipeline = None
        self.image_sampler = None

        self.setup_signal_handlers()
        atexit.register(self.cleanup)

        parser = argparse.ArgumentParser()
        bt.subtensor.add_args(parser)
        bt.wallet.add_args(parser)
        bt.logging.add_args(parser)
        add_data_generator_args(parser)
        add_args(parser)

        self.config = bt.config(parser)

        bt.logging(config=self.config, logging_dir=self.config.neuron.full_path)
        # FIX: bt.logging.set_trace() was previously called unconditionally
        # here, forcing trace-level logging regardless of the configured log
        # level. Now mirrors neurons/base.py: info by default, debug/trace
        # only when requested.
        bt.logging.set_info()
        if self.config.logging.debug:
            bt.logging.set_debug(True)
        if self.config.logging.trace:
            bt.logging.set_trace(True)

        bt.logging.success(self.config)
        wallet_configured = (
            self.config.wallet.name is not None
            and self.config.wallet.hotkey is not None
        )
        # NOTE(review): `wandb_off` here vs `wandb.off` used later in this
        # file — confirm both flags exist in add_args; if only `wandb.off`
        # is defined, this attribute reads None (falsy) and always proceeds.
        if wallet_configured and not self.config.wandb_off:
            try:
                self.wallet = bt.wallet(config=self.config)
                # Resolve our UID by locating our hotkey in the metagraph.
                self.uid = (
                    bt.subtensor(
                        config=self.config, network=self.config.subtensor.chain_endpoint
                    )
                    .metagraph(self.config.netuid)
                    .hotkeys.index(self.wallet.hotkey.ss58_address)
                )
                self.wandb_dir = str(Path(__file__).parent.parent)
                clean_wandb_cache(self.wandb_dir)
                self.wandb_run = init_wandb(
                    self.config.copy(),
                    self.config.wandb.process_name,
                    self.uid,
                    self.wallet.hotkey,
                )

            except Exception as e:
                # Unregistered hotkeys can't sign runs; disable W&B and continue.
                bt.logging.error("Not registered, can't sign W&B run")
                bt.logging.error(e)
                self.config.wandb.off = True

    def setup_signal_handlers(self):
        """Route SIGINT/SIGTERM/SIGQUIT through our graceful shutdown path."""
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)
        signal.signal(signal.SIGQUIT, self.signal_handler)

    def signal_handler(self, sig, frame):
        """Clean up and exit on a termination signal."""
        signal_name = signal.Signals(sig).name
        bt.logging.info(f"Received {signal_name}, initiating shutdown...")
        self.cleanup()
        sys.exit(0)

    def cleanup(self):
        """Best-effort teardown: cancel the main task, shut down the
        pipeline, and release CUDA memory. Safe to call multiple times
        (also registered via atexit and the signal handlers)."""
        if self.task and not self.task.done():
            self.task.cancel()

        if self.generation_pipeline:
            try:
                bt.logging.trace("Shutting down generator...")
                self.generation_pipeline.shutdown()
                bt.logging.success("Generator shut down gracefully")
            except Exception as e:
                bt.logging.error(f"Error during generator shutdown: {e}")

        # Force cleanup of any GPU memory
        try:
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                bt.logging.trace("CUDA memory cache cleared")
        except Exception:
            # torch may be unavailable or already torn down during exit.
            pass

    async def wait_for_cache(self, timeout: int = 300):
        """Poll until the image cache has at least one item.

        Returns True when images are available, False if `timeout` seconds
        elapse first.
        """
        start = time.time()
        attempts = 0
        while True:
            if time.time() - start > timeout:
                return False

            available_count = self.image_sampler.get_available_count(use_index=False)
            if available_count > 0:
                return True

            await asyncio.sleep(10)
            # Log roughly every third attempt (~30s) to avoid spam.
            if not attempts % 3:
                bt.logging.info("Waiting for images in cache...")
            attempts += 1

    async def sample_images(self, k: int = 1) -> List[Dict[str, Any]]:
        """Sample `k` images from the cache, decoding raw bytes to PIL.

        Raises ValueError if the cache yields nothing.
        """
        result = await self.image_sampler.sample(k, remove_from_cache=False)
        if result["count"] == 0:
            raise ValueError("No images available in cache")

        # Convert bytes to PIL images
        for item in result["items"]:
            if isinstance(item["image"], bytes):
                item["image"] = Image.open(io.BytesIO(item["image"]))

        return result["items"]

    async def run(self):
        """Main generator loop: sample real images, generate synthetic
        media in batches, and periodically rotate the W&B run."""
        try:
            cache_dir = self.config.cache_dir
            batch_size = self.config.batch_size
            device = self.config.device

            Path(cache_dir).mkdir(parents=True, exist_ok=True)

            self.image_sampler = ImageSampler(
                CacheConfig(
                    modality=Modality.IMAGE.value,
                    media_type=MediaType.REAL.value,
                    base_dir=Path(cache_dir),
                )
            )

            # FIX: the timeout result was previously ignored and success was
            # logged unconditionally.
            if await self.wait_for_cache():
                bt.logging.success("Cache populated. Proceeding to generation.")
            else:
                bt.logging.warning(
                    "Timed out waiting for image cache; generation may fail until real images arrive"
                )

            model_registry = initialize_model_registry()
            model_names = model_registry.get_interleaved_model_names(self.config.tasks)
            bt.logging.info("Starting generator")
            bt.logging.info(f"Tasks: {self.config.tasks}")
            bt.logging.info(f"Models: {model_names}")

            self.generation_pipeline = GenerationPipeline(
                output_dir=cache_dir,
                device=device,
            )

            gen_count = 0  # total files generated across all batches
            batch_count = 0  # batches since the current W&B run began
            while not self.exit_context.isExiting:
                if asyncio.current_task().cancelled():
                    break

                try:
                    image_samples = await self.sample_images(batch_size)
                    # FIX: this log previously printed gen_count (file total)
                    # under the "Batch Count" label.
                    bt.logging.info(
                        f"Starting batch generation | Batch Size: {len(image_samples)} | Batch Count: {batch_count}"
                    )

                    start_time = time.time()

                    filepaths = self.generation_pipeline.generate(
                        image_samples, model_names=model_names
                    )
                    await asyncio.sleep(1)

                    duration = time.time() - start_time
                    gen_count += len(filepaths)
                    batch_count += 1
                    bt.logging.info(
                        f"Generated {len(filepaths)} files in batch #{batch_count} in {duration:.2f} seconds"
                    )

                    # Periodically rotate the W&B run to bound run size.
                    if not self.config.wandb.off:
                        if batch_count >= self.config.wandb.num_batches_per_run:
                            batch_count = 0
                            self.wandb_run.finish()
                            clean_wandb_cache(self.wandb_dir)
                            self.wandb_run = init_wandb(
                                self.config.copy(),
                                self.config.wandb.process_name,
                                self.uid,
                                self.wallet.hotkey,
                            )

                except asyncio.CancelledError:
                    bt.logging.info("Task cancelled, exiting loop")
                    break
                except Exception as e:
                    # Keep the loop alive on transient batch failures.
                    bt.logging.error(f"Error in batch processing: {e}")
                    bt.logging.error(traceback.format_exc())
                    await asyncio.sleep(10)
        except Exception as e:
            bt.logging.error(f"Unhandled exception in main task: {e}")
            bt.logging.error(traceback.format_exc())
            raise
        finally:
            self.cleanup()

    def start(self):
        """Run the generator's event loop until completion or interrupt."""
        loop = asyncio.get_event_loop()
        try:
            self.task = asyncio.ensure_future(self.run())
            loop.run_until_complete(self.task)
        except KeyboardInterrupt:
            bt.logging.info("Generator interrupted by KeyboardInterrupt, shutting down")
        except Exception as e:
            bt.logging.error(f"Unhandled exception: {e}")
            bt.logging.error(traceback.format_exc())
        finally:
            self.cleanup()
262 |
263 |
# Script entry point: construct the generator and run it until the loop
# finishes or the process is interrupted, then exit cleanly.
if __name__ == "__main__":
    generator = Generator()
    generator.start()
    sys.exit(0)
268 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=64", "wheel", "pip>=21.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "bitmind"
7 | dynamic = ["version"]
8 | description = "SN34 on bittensor"
9 | authors = [
10 | {name = "BitMind", email = "intern@bitmind.ai"}
11 | ]
12 | readme = "README.md"
13 | requires-python = ">=3.10"
14 | license = {text = ""}
15 | urls = {homepage = "http://bitmind.ai"}
16 |
17 | dependencies = [
18 | "bittensor==9.3.0",
19 | "bittensor-cli==9.4.1",
20 | "pillow==10.4.0",
21 | "substrate-interface==1.7.11",
22 | "numpy==2.0.1",
23 | "pandas==2.2.3",
24 | "torch==2.5.1",
25 | "asyncpg==0.29.0",
26 | "httpcore==1.0.7",
27 | "httpx==0.28.1",
28 | "pyarrow==19.0.1",
29 | "ffmpeg-python==0.2.0",
30 | "bitsandbytes==0.45.4",
31 | "black==25.1.0",
32 | "pre-commit==4.2.0",
33 | "diffusers==0.33.1",
34 | "transformers==4.50.0",
35 | "scikit-learn==1.6.1",
36 | "av==14.2.0",
37 | "opencv-python==4.11.0.86",
38 | "wandb==0.19.9",
39 | "uvicorn==0.27.1",
40 | "python-multipart==0.0.20",
41 | "peft==0.15.0",
42 | "hf_xet==1.1.1"
43 | ]
44 |
45 | [tool.setuptools]
46 | packages = {find = {where = ["."], exclude = ["docs*", "wandb*", "*.egg-info"]}}
47 |
48 | [tool.setuptools.dynamic]
49 | version = {file = "VERSION"}
--------------------------------------------------------------------------------
/requirements-git.txt:
--------------------------------------------------------------------------------
1 | janus @ git+https://github.com/deepseek-ai/Janus.git
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bittensor==9.3.0
2 | bittensor-cli==9.4.1
3 | pillow==10.4.0
4 | substrate-interface==1.7.11
5 | numpy==2.0.1
6 | pandas==2.2.3
7 | torch==2.5.1
8 | asyncpg==0.29.0
9 | httpcore==1.0.7
10 | httpx==0.28.1
11 | pyarrow==19.0.1
12 | ffmpeg-python==0.2.0
13 | bitsandbytes==0.45.4
14 | black==25.1.0
15 | pre-commit==4.2.0
16 | diffusers==0.33.1
17 | transformers==4.50.0
18 | scikit-learn==1.6.1
19 | av==14.2.0
20 | opencv-python==4.11.0.86
21 | wandb==0.19.9
22 | uvicorn==0.27.1
23 | python-multipart==0.0.20
24 | peft==0.15.0
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash
###########################################
# System Updates and Package Installation #
###########################################

# Update system
sudo apt update -y

# Install core dependencies
sudo apt install -y \
    python3-pip \
    nano \
    libgl1 \
    ffmpeg \
    unzip

# FIX: the node/pm2 section below was previously "commented out" with
# Python-style triple quotes ("""), which are not a comment in bash —
# they opened an unbalanced double-quoted string that swallowed the
# whole section into one bogus command word. The section is now
# properly disabled with '#' comments, preserving the prior effective
# behavior (none of these commands ran).
# NOTE(review): start_validator.sh expects pm2 to be installed — confirm
# whether this section should be re-enabled instead.

# Remove old nodejs and npm if present
#sudo apt-get remove --purge -y nodejs npm

# Install Node.js 20.x (LTS) from NodeSource for stability and universal standard
# NOTE: Update the version here when a new LTS is released
#curl -fsSL https://deb.nodesource.com/setup_20.x | sudo -E bash -
#sudo apt-get install -y nodejs

# Install build dependencies
#sudo apt install -y \
#    build-essential \
#    cmake \
#    libopenblas-dev \
#    liblapack-dev \
#    libx11-dev \
#    libgtk-3-dev

# Install process manager (pm2) globally
#sudo npm install -g pm2@latest

############################
# Python Package Installation
############################

pip install --use-pep517 -e . -r requirements-git.txt
--------------------------------------------------------------------------------
/start_miner.sh:
--------------------------------------------------------------------------------
#!/bin/bash

###################################
# LOAD ENV FILE
###################################
# Export every variable defined in .env.miner into the environment.
set -a
source .env.miner
set +a

###################################
# PREPARE CLI ARGS
###################################
# Map the chain endpoint onto the matching subnet netuid / network name.
if [[ "$CHAIN_ENDPOINT" == *"test"* ]]; then
  NETUID=168
  NETWORK="test"
elif [[ "$CHAIN_ENDPOINT" == *"finney"* ]]; then
  NETUID=34
  NETWORK="finney"
fi

# Translate LOGLEVEL into the bittensor logging flag.
case "$LOGLEVEL" in
  "trace")
    LOG_PARAM="--logging.trace"
    ;;
  "debug")
    LOG_PARAM="--logging.debug"
    ;;
  "info")
    LOG_PARAM="--logging.info"
    ;;
  *)
    # Default to info if LOGLEVEL is not set or invalid
    LOG_PARAM="--logging.info"
    ;;
esac

# Set validator-permit parameter based on FORCE_VPERMIT.
# (The original comment wrongly said "auto-update parameter based on
# AUTO_UPDATE" — a copy-paste error; this variable controls whether
# non-permitted validators may query the miner.)
FORCE_VPERMIT_PARAM=""
if [ "$FORCE_VPERMIT" = false ]; then
  FORCE_VPERMIT_PARAM="--no-force-validator-permit"
fi


###################################
# RESTART PROCESSES
###################################
NAME="bitmind-miner"

# Stop any existing processes
if pm2 list | grep -q "$NAME"; then
  echo "'$NAME' is already running. Deleting it..."
  pm2 delete $NAME
fi

echo "Starting $NAME | chain_endpoint: $CHAIN_ENDPOINT | netuid: $NETUID"

# Start the miner under pm2.
# BUG FIX: $LOG_PARAM was computed above but never forwarded, so LOGLEVEL
# had no effect; it is now passed through to the miner process.
pm2 start neurons/miner.py \
  --interpreter python3 \
  --name $NAME \
  -- \
  --wallet.name $WALLET_NAME \
  --wallet.hotkey $WALLET_HOTKEY \
  --netuid $NETUID \
  --subtensor.chain_endpoint $CHAIN_ENDPOINT \
  --axon.port $AXON_PORT \
  --axon.external_ip $AXON_EXTERNAL_IP \
  --device $DEVICE \
  $LOG_PARAM \
  $FORCE_VPERMIT_PARAM


--------------------------------------------------------------------------------
/start_validator.sh:
--------------------------------------------------------------------------------
#!/bin/bash

###################################
# LOAD ENV FILE
###################################
# Export every variable defined in .env.validator into the environment.
set -a
source .env.validator
set +a

###################################
# LOG IN TO THIRD PARTY SERVICES
###################################
# Login to Weights & Biases
if ! wandb login $WANDB_API_KEY; then
    echo "Failed to login to Weights & Biases with the provided API key."
    exit 1
fi
echo "Logged into W&B with API key provided in .env.validator"

# Login to Hugging Face
if ! huggingface-cli login --token $HUGGING_FACE_TOKEN; then
    echo "Failed to login to Hugging Face with the provided token."
    exit 1
fi
# BUG FIX: this success message previously said "Logged into W&B" — a
# copy-paste error from the branch above; this is the Hugging Face login.
echo "Logged into Hugging Face with token provided in .env.validator"

###################################
# PREPARE CLI ARGS
###################################
# Defaults for optional env vars (kept if already set).
: ${PROXY_PORT:=10913}
: ${PROXY_EXTERNAL_PORT:=$PROXY_PORT}
: ${DEVICE:=cuda}

# Map the chain endpoint onto the matching subnet netuid / network name.
if [[ "$CHAIN_ENDPOINT" == *"test"* ]]; then
  NETUID=168
  NETWORK="test"
elif [[ "$CHAIN_ENDPOINT" == *"finney"* ]]; then
  NETUID=34
  NETWORK="finney"
fi

# Translate LOGLEVEL into the bittensor logging flag.
case "$LOGLEVEL" in
  "trace")
    LOG_PARAM="--logging.trace"
    ;;
  "debug")
    LOG_PARAM="--logging.debug"
    ;;
  "info")
    LOG_PARAM="--logging.info"
    ;;
  *)
    # Default to info if LOGLEVEL is not set or invalid
    LOG_PARAM="--logging.info"
    ;;
esac

# Set auto-update parameter based on AUTO_UPDATE
if [ "$AUTO_UPDATE" = true ]; then
    AUTO_UPDATE_PARAM=""
else
    AUTO_UPDATE_PARAM="--autoupdate-off"
fi

# Opt-in heartbeat monitoring flag.
if [ "$HEARTBEAT" = true ]; then
    HEARTBEAT_PARAM="--heartbeat"
else
    HEARTBEAT_PARAM=""
fi

###################################
# STOP AND WAIT FOR CLEANUP
###################################
VALIDATOR="sn34-validator"
GENERATOR="sn34-generator"
PROXY="sn34-proxy"

# Stop any existing processes
if pm2 list | grep -q "$VALIDATOR"; then
  echo "'$VALIDATOR' is already running. Deleting it..."
  pm2 delete $VALIDATOR
  sleep 1
fi

if pm2 list | grep -q "$GENERATOR"; then
  echo "'$GENERATOR' is already running. Deleting it..."
  pm2 delete $GENERATOR
  sleep 2
fi

if pm2 list | grep -q "$PROXY"; then
  echo "'$PROXY' is already running. Deleting it..."
  pm2 delete $PROXY
  sleep 1
fi


###################################
# START PROCESSES
###################################
# Expand any ~ or $VARS embedded in the configured cache path.
SN34_CACHE_DIR=$(eval echo "$SN34_CACHE_DIR")

echo "Starting validator and generator | chain_endpoint: $CHAIN_ENDPOINT | netuid: $NETUID"

# Run data generator
pm2 start neurons/generator.py \
  --interpreter python3 \
  --kill-timeout 2000 \
  --name $GENERATOR \
  -- \
  --wallet.name $WALLET_NAME \
  --wallet.hotkey $WALLET_HOTKEY \
  --netuid $NETUID \
  --subtensor.chain_endpoint $CHAIN_ENDPOINT \
  --cache-dir $SN34_CACHE_DIR \
  --device $DEVICE

# Run validator
pm2 start neurons/validator.py \
  --interpreter python3 \
  --kill-timeout 1000 \
  --name $VALIDATOR \
  -- \
  --wallet.name $WALLET_NAME \
  --wallet.hotkey $WALLET_HOTKEY \
  --netuid $NETUID \
  --subtensor.chain_endpoint $CHAIN_ENDPOINT \
  --epoch-length 360 \
  --cache-dir $SN34_CACHE_DIR \
  --proxy.port $PROXY_PORT \
  $LOG_PARAM \
  $AUTO_UPDATE_PARAM \
  $HEARTBEAT_PARAM

# Run validator proxy
pm2 start neurons/proxy.py \
  --interpreter python3 \
  --kill-timeout 1000 \
  --name $PROXY \
  -- \
  --wallet.name $WALLET_NAME \
  --wallet.hotkey $WALLET_HOTKEY \
  --netuid $NETUID \
  --subtensor.chain_endpoint $CHAIN_ENDPOINT \
  --proxy.port $PROXY_PORT \
  --proxy.external_port $PROXY_EXTERNAL_PORT \
  $LOG_PARAM

--------------------------------------------------------------------------------