├── .env.miner.template ├── .env.validator.template ├── .gitignore ├── LICENSE ├── README.md ├── VERSION ├── bitmind ├── __init__.py ├── autoupdater.py ├── cache │ ├── __init__.py │ ├── cache_fs.py │ ├── cache_system.py │ ├── datasets │ │ ├── __init__.py │ │ ├── dataset_registry.py │ │ └── datasets.py │ ├── sampler │ │ ├── __init__.py │ │ ├── base.py │ │ ├── image_sampler.py │ │ ├── sampler_registry.py │ │ └── video_sampler.py │ ├── updater │ │ ├── __init__.py │ │ ├── base.py │ │ ├── image_updater.py │ │ ├── updater_registry.py │ │ └── video_updater.py │ └── util │ │ ├── __init__.py │ │ ├── download.py │ │ ├── extract.py │ │ ├── filesystem.py │ │ └── video.py ├── config.py ├── encoding.py ├── epistula.py ├── generation │ ├── __init__.py │ ├── generation_pipeline.py │ ├── model_registry.py │ ├── models.py │ ├── prompt_generator.py │ └── util │ │ ├── __init__.py │ │ ├── image.py │ │ ├── model.py │ │ └── prompt.py ├── metagraph.py ├── scoring │ ├── __init__.py │ ├── eval_engine.py │ └── miner_history.py ├── transforms.py ├── types.py ├── utils.py └── wandb_utils.py ├── docs ├── Incentive.md ├── Mining.md ├── Validating.md └── static │ ├── Bitmind-Logo.png │ ├── Join-BitMind-Discord.png │ └── Subnet-Arch.png ├── min_compute.yml ├── neurons ├── __init__.py ├── base.py ├── generator.py ├── miner.py ├── proxy.py └── validator.py ├── pyproject.toml ├── requirements-git.txt ├── requirements.txt ├── setup.sh ├── start_miner.sh └── start_validator.sh /.env.miner.template: -------------------------------------------------------------------------------- 1 | # ======= Miner Configuration (FILL IN) ======= 2 | # Wallet 3 | WALLET_NAME= 4 | WALLET_HOTKEY= 5 | 6 | # Network 7 | CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443 8 | # OTF public finney endpoint: wss://entrypoint-finney.opentensor.ai:443 9 | # OTF public testnet endpoint: wss://test.finney.opentensor.ai:443/ 10 | 11 | # Axon port and (optionally) ip 12 | AXON_PORT=8091 13 | AXON_EXTERNAL_IP=[::] 14 | 15 | 
FORCE_VPERMIT=true 16 | 17 | # Device for detection models 18 | DEVICE=cpu 19 | 20 | # Logging 21 | LOGLEVEL=trace -------------------------------------------------------------------------------- /.env.validator.template: -------------------------------------------------------------------------------- 1 | # ======= Validator Configuration (FILL IN) ======= 2 | # Wallet 3 | WALLET_NAME= 4 | WALLET_HOTKEY= 5 | 6 | # API Keys 7 | WANDB_API_KEY= 8 | HUGGING_FACE_TOKEN= 9 | 10 | # Network 11 | CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443 12 | # OTF public finney endpoint: wss://entrypoint-finney.opentensor.ai:443 13 | # OTF public testnet endpoint: wss://test.finney.opentensor.ai:443/ 14 | 15 | # Validator Proxy 16 | PROXY_PORT=10913 17 | PROXY_EXTERNAL_PORT=10913 18 | 19 | # Cache config 20 | SN34_CACHE_DIR=~/.cache/sn34 21 | HEARTBEAT=true 22 | 23 | # Generator config 24 | GENERATION_BATCH_SIZE=3 25 | DEVICE=cuda 26 | 27 | # Other 28 | LOGLEVEL=info 29 | AUTO_UPDATE=true -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | testing/ 163 | data/ 164 | checkpoints/ 165 | .requirements_installed 166 | base_miner/NPR/weights/* 167 | base_miner/NPR/logs/* 168 | base_miner/DFB/weights/* 169 | base_miner/DFB/logs/* 170 | miner_eval.py 171 | *.env 172 | *~ 173 | wandb/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright © 2023 Yuma Rao 3 | Copyright © 2025 BitMind 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 6 | documentation files (the "Software"), to deal in the Software without restriction, including without limitation 7 | the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 8 | and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of 11 | the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 14 | THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 15 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 16 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 17 | DEALINGS IN THE SOFTWARE. 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | BitMind Logo 3 |

4 |

SN34
Deepfake Detection

5 | 6 |
7 | Applications 8 |
9 | 10 |
11 | Mining Guide · 12 | Validator Guide · 13 | Incentive Mechanism Overview 14 |
15 | 16 |
17 | HuggingFace · 18 | Mainnet 34 W&B · 19 | Testnet 168 W&B 20 |
21 | 22 |
23 | Leaderboard 24 |
25 | 26 | ## Decentralized Detection of AI Generated Content 27 | The explosive growth of generative AI technology has unleashed an unprecedented wave of synthetic media creation. AI-generated audiovisual content has become remarkably sophisticated, oftentimes indistinguishable from authentic media. This development presents a critical challenge to information integrity and societal trust in the digital age, as the line between real and synthetic content continues to blur. 28 | 29 | To address this growing challenge, SN34 aims to create the most accurate fully-generalized detection system. Here, fully-generalized means that the system is capable of detecting both synthetic and semi-synthetic media with high degrees of accuracy regardless of their content or what model generated them. Our incentive mechanism evolves alongside state-of-the-art generative AI, rewarding miners whose detection algorithms best adapt to new forms of synthetic content. 30 | 31 | 32 | ## Core Components 33 | 34 | > This documentation assumes basic familiarity with [Bittensor concepts](https://docs.bittensor.com/learn/bittensor-building-blocks). 35 | 36 | Miners 37 | 38 | - Miners are tasked with running binary classifiers that discern between genuine and AI-generated content, and are rewarded based on their accuracy. 39 | - For each challenge, a miner is presented an image or video and is required to respond with a multiclass prediction [$p_{real}$, $p_{synthetic}$, $p_{semisynthetic}$] indicating whether the media is real, fully generated, or partially modified by AI. 40 | 41 | 42 | Validators 43 | - Validators challenge miners with a balanced mix of real and synthetic media drawn from a diverse pool of sources. 44 | - We continually add new datasets and generative models to our validators in order to evolve the subnet's detection capabilities alongside advances in generative AI. 
45 | 46 | 47 | ## Subnet Architecture 48 | 49 | Overview of the validator neuron, miner neuron, and other components external to the subnet. 50 | 51 | ![Subnet Architecture](docs/static/Subnet-Arch.png) 52 | 53 |
54 | Challenge Generation and Scoring (Peach Arrows) 55 | 62 |
63 | 64 |
65 | Data Generation and Downloads (Blue Arrows) 66 | The blue arrows show how the validator media cache is maintained by two parallel tracks: 67 | 71 |
72 | 73 |
74 | Organic Traffic (Green Arrows) 75 | 76 | Application requests are distributed to validators by an API server and load balancer in BitMind's cloud. A vector database caches subnet responses to avoid uncessary repetitive calls coming from salient images on the internet. 77 |
78 | 79 | 80 | 81 | ## Community 82 | 83 |

84 | 85 | Join us on Discord 86 | 87 |

88 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 3.0.10 2 | -------------------------------------------------------------------------------- /bitmind/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.0.10" 2 | 3 | version_split = __version__.split(".") 4 | __spec_version__ = ( 5 | (100000 * int(version_split[0])) 6 | + (1000 * int(version_split[1])) 7 | + (10 * int(version_split[2])) 8 | ) 9 | -------------------------------------------------------------------------------- /bitmind/autoupdater.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2023 Yuma Rao 3 | # Copyright © 2024 Manifold Labs 4 | # Copyright © 2025 BitMind 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 7 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 9 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 10 | 11 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 12 | # the Software. 13 | 14 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 15 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 17 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import signal 21 | import time 22 | import os 23 | import requests 24 | import bittensor as bt 25 | import bitmind 26 | 27 | 28 | def autoupdate(branch: str = "main", force=False): 29 | """ 30 | Automatically updates the codebase to the latest version available on the specified branch. 31 | 32 | This function checks the remote repository for the latest version by fetching the VERSION file from the specified branch. 33 | If the local version is older than the remote version, it performs a git pull to update the local codebase to the latest version. 34 | After successfully updating, it restarts the application with the updated code. 35 | 36 | Args: 37 | - branch (str): The name of the branch to check for updates. Defaults to "main". 38 | 39 | Note: 40 | - The function assumes that the local codebase is a git repository and has the same structure as the remote repository. 41 | - It requires git to be installed and accessible from the command line. 42 | - The function will restart the application using the same command-line arguments it was originally started with. 43 | - If the update fails, manual intervention is required to resolve the issue and restart the application. 
44 | """ 45 | bt.logging.info("Checking for updates...") 46 | try: 47 | github_url = f"https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/{branch}/VERSION?ts={time.time()}" 48 | response = requests.get( 49 | github_url, 50 | headers={ 51 | "Cache-Control": "no-cache, no-store, must-revalidate", 52 | "Pragma": "no-cache", 53 | "Expires": "0" 54 | }, 55 | ) 56 | response.raise_for_status() 57 | repo_version = response.content.decode() 58 | latest_version = int("".join(repo_version.split("."))) 59 | local_version = int("".join(bitmind.__version__.split("."))) 60 | 61 | bt.logging.info(f"Local version: {bitmind.__version__}") 62 | bt.logging.info(f"Latest version: {repo_version}") 63 | 64 | if latest_version > local_version or force: 65 | bt.logging.info(f"A newer version is available. Updating...") 66 | base_path = os.path.abspath(__file__) 67 | while os.path.basename(base_path) != "bitmind-subnet": 68 | base_path = os.path.dirname(base_path) 69 | 70 | os.system(f"cd {base_path} && git pull && chmod +x setup.sh && ./setup.sh") 71 | 72 | with open(os.path.join(base_path, "VERSION")) as f: 73 | new_version = f.read().strip() 74 | new_version = int("".join(new_version.split("."))) 75 | 76 | if new_version == latest_version: 77 | bt.logging.info("Updated successfully.") 78 | 79 | bt.logging.info("Restarting generator...") 80 | subprocess.run(["pm2", "reload", "sn34-generator"], check=True) 81 | 82 | bt.logging.info("Restarting proxy...") 83 | subprocess.run(["pm2", "reload", "sn34-proxy"], check=True) 84 | 85 | bt.logging.info(f"Restarting validator") 86 | os.kill(os.getpid(), signal.SIGINT) 87 | else: 88 | bt.logging.error("Update failed. 
Manual update required.") 89 | except Exception as e: 90 | bt.logging.error(f"Update check failed: {e}") 91 | -------------------------------------------------------------------------------- /bitmind/cache/__init__.py: -------------------------------------------------------------------------------- 1 | from .cache_system import CacheSystem 2 | -------------------------------------------------------------------------------- /bitmind/cache/cache_system.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Type 2 | import traceback 3 | 4 | import asyncio 5 | import bittensor as bt 6 | 7 | from bitmind.types import CacheUpdaterConfig, CacheConfig, Modality, MediaType 8 | from bitmind.cache.datasets import DatasetRegistry, initialize_dataset_registry 9 | from bitmind.cache.updater import ( 10 | BaseUpdater, 11 | UpdaterRegistry, 12 | ImageUpdater, 13 | VideoUpdater, 14 | ) 15 | from bitmind.cache.sampler import ( 16 | BaseSampler, 17 | SamplerRegistry, 18 | ImageSampler, 19 | VideoSampler, 20 | ) 21 | 22 | 23 | class CacheSystem: 24 | """ 25 | Main facade for the caching system. 
26 | """ 27 | 28 | def __init__(self): 29 | self.dataset_registry = DatasetRegistry() 30 | self.updater_registry = UpdaterRegistry() 31 | self.sampler_registry = SamplerRegistry() 32 | 33 | async def initialize( 34 | self, 35 | base_dir, 36 | max_compressed_gb, 37 | max_media_gb, 38 | media_files_per_source, 39 | ): 40 | try: 41 | dataset_registry = initialize_dataset_registry() 42 | for dataset in dataset_registry.datasets: 43 | self.register_dataset(dataset) 44 | 45 | for modality in Modality: 46 | for media_type in MediaType: 47 | cache_config = CacheConfig( 48 | base_dir=base_dir, 49 | modality=modality.value, 50 | media_type=media_type.value, 51 | max_compressed_gb=max_compressed_gb, 52 | max_media_gb=max_media_gb, 53 | ) 54 | sampler_class = ( 55 | ImageSampler if modality == Modality.IMAGE else VideoSampler 56 | ) 57 | self.create_sampler( 58 | name=f"{media_type.value}_{modality.value}_sampler", 59 | sampler_class=sampler_class, 60 | cache_config=cache_config, 61 | ) 62 | 63 | # synthetic video updater not currently used, only generate locally 64 | if not ( 65 | modality == Modality.VIDEO and media_type == MediaType.SYNTHETIC 66 | ): 67 | updater_config = CacheUpdaterConfig( 68 | num_sources_per_dataset=1, # one compressed source per dataset for initialization 69 | num_items_per_source=media_files_per_source, 70 | ) 71 | updater_class = ( 72 | ImageUpdater if modality == Modality.IMAGE else VideoUpdater 73 | ) 74 | self.create_updater( 75 | name=f"{media_type.value}_{modality.value}_updater", 76 | updater_class=updater_class, 77 | cache_config=cache_config, 78 | updater_config=updater_config, 79 | ) 80 | 81 | # Initialize caches (populate if empty) 82 | bt.logging.info("Starting initial cache population") 83 | await self.initialize_caches() 84 | bt.logging.info("Initial cache population complete") 85 | 86 | except Exception as e: 87 | bt.logging.error(f"Error initializing caches: {e}") 88 | bt.logging.error(traceback.format_exc()) 89 | 90 | def 
register_dataset(self, dataset) -> None: 91 | """ 92 | Register a dataset with the system. 93 | 94 | Args: 95 | dataset: Dataset configuration to register 96 | """ 97 | self.dataset_registry.register(dataset) 98 | 99 | def register_datasets(self, datasets: List[Any]) -> None: 100 | """ 101 | Register multiple datasets with the system. 102 | 103 | Args: 104 | datasets: List of dataset configurations to register 105 | """ 106 | self.dataset_registry.register_all(datasets) 107 | 108 | def create_updater( 109 | self, 110 | name: str, 111 | updater_class: Type[BaseUpdater], 112 | cache_config: CacheConfig, 113 | updater_config: CacheUpdaterConfig, 114 | ) -> BaseUpdater: 115 | """ 116 | Create and register an updater. 117 | 118 | Args: 119 | name: Unique name for the updater 120 | updater_class: Updater class to instantiate 121 | cache_config: Cache configuration 122 | updater_config: Updater configuration 123 | 124 | Returns: 125 | The created updater instance 126 | """ 127 | updater = updater_class( 128 | cache_config=cache_config, 129 | updater_config=updater_config, 130 | data_manager=self.dataset_registry, 131 | ) 132 | self.updater_registry.register(name, updater) 133 | return updater 134 | 135 | def create_sampler( 136 | self, name: str, sampler_class: Type[BaseSampler], cache_config: CacheConfig 137 | ) -> BaseSampler: 138 | """ 139 | Create and register a sampler. 140 | 141 | Args: 142 | name: Unique name for the sampler 143 | sampler_class: Sampler class to instantiate 144 | cache_config: Cache configuration 145 | 146 | Returns: 147 | The created sampler instance 148 | """ 149 | sampler = sampler_class(cache_config=cache_config) 150 | self.sampler_registry.register(name, sampler) 151 | return sampler 152 | 153 | async def initialize_caches(self) -> None: 154 | """ 155 | Initialize all caches to ensure they have content. 156 | This is typically called during system startup. 
157 | """ 158 | updaters = self.updater_registry.get_all() 159 | names = [name for name, _ in updaters.items()] 160 | bt.logging.debug(f"Initializing {len(updaters)} caches: {names}") 161 | 162 | cache_init_tasks = [] 163 | for name, updater in updaters.items(): 164 | cache_init_tasks.append(updater.initialize_cache()) 165 | 166 | if cache_init_tasks: 167 | await asyncio.gather(*cache_init_tasks) 168 | 169 | async def update_compressed_caches(self) -> None: 170 | """ 171 | Update all compressed caches in parallel 172 | This is typically called from a block callback. 173 | """ 174 | updaters = self.updater_registry.get_all() 175 | names = [name for name, _ in updaters.items()] 176 | bt.logging.trace(f"Updating {len(updaters)} compressed caches: {names}") 177 | 178 | tasks = [] 179 | for name, updater in updaters.items(): 180 | tasks.append(updater.update_compressed_cache()) 181 | 182 | if tasks: 183 | await asyncio.gather(*tasks) 184 | 185 | async def update_media_caches(self) -> None: 186 | """ 187 | Update all media caches in parallel. 188 | This is typically called from a block callback. 189 | """ 190 | updaters = self.updater_registry.get_all() 191 | names = [name for name, _ in updaters.items()] 192 | bt.logging.debug(f"Updating {len(updaters)} media caches: {names}") 193 | 194 | tasks = [] 195 | for name, updater in updaters.items(): 196 | tasks.append(updater.update_media_cache()) 197 | 198 | if tasks: 199 | await asyncio.gather(*tasks) 200 | 201 | async def sample(self, name: str, count: int, **kwargs) -> Optional[Dict[str, Any]]: 202 | """ 203 | Sample from a specific sampler. 204 | 205 | Args: 206 | name: Name of the sampler to use 207 | count: Number of items to sample 208 | 209 | Returns: 210 | The sampled items or None if sampler not found 211 | """ 212 | return await self.sampler_registry.sample(name, count, **kwargs) 213 | 214 | async def sample_all(self, count: int = 1) -> Dict[str, Dict[str, Any]]: 215 | """ 216 | Sample from all samplers. 
217 | 218 | Args: 219 | count: Number of items to sample from each sampler 220 | 221 | Returns: 222 | Dictionary mapping sampler names to their samples 223 | """ 224 | return await self.sampler_registry.sample_all(count) 225 | 226 | @property 227 | def samplers(self): 228 | """ 229 | Get all registered samplers. 230 | 231 | Returns: 232 | Dictionary of sampler names to sampler instances 233 | """ 234 | return self.sampler_registry.get_all() 235 | 236 | @property 237 | def updaters(self): 238 | """ 239 | Get all registered updaters. 240 | 241 | Returns: 242 | Dictionary of updater names to updater instances 243 | """ 244 | return self.updater_registry.get_all() 245 | -------------------------------------------------------------------------------- /bitmind/cache/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import initialize_dataset_registry 2 | from .dataset_registry import DatasetRegistry 3 | -------------------------------------------------------------------------------- /bitmind/cache/datasets/dataset_registry.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from bitmind.types import DatasetConfig, MediaType, Modality 4 | 5 | 6 | class DatasetRegistry: 7 | """ 8 | Registry for dataset configurations with filtering capabilities. 9 | """ 10 | 11 | def __init__(self): 12 | self.datasets: List[DatasetConfig] = [] 13 | 14 | def register(self, dataset: DatasetConfig) -> None: 15 | """ 16 | Register a dataset with the system. 17 | 18 | Args: 19 | dataset: Dataset configuration to register 20 | """ 21 | self.datasets.append(dataset) 22 | 23 | def register_all(self, datasets: List[DatasetConfig]) -> None: 24 | """ 25 | Register multiple datasets with the system. 
26 | 27 | Args: 28 | datasets: List of dataset configurations to register 29 | """ 30 | for dataset in datasets: 31 | self.register(dataset) 32 | 33 | def get_datasets( 34 | self, 35 | modality: Optional[Modality] = None, 36 | media_type: Optional[MediaType] = None, 37 | tags: Optional[List[str]] = None, 38 | exclude_tags: Optional[List[str]] = None, 39 | enabled_only: bool = True, 40 | ) -> List[DatasetConfig]: 41 | """ 42 | Get datasets filtered by type, media_type, and/or tags. 43 | 44 | Args: 45 | modality: Filter by dataset type 46 | media_type: Filter by media_type 47 | tags: Filter by tags (dataset must have ALL specified tags) 48 | enabled_only: Only return enabled datasets 49 | 50 | Returns: 51 | List of matching datasets 52 | """ 53 | result = self.datasets 54 | 55 | if enabled_only: 56 | result = [d for d in result if d.enabled] 57 | 58 | if modality: 59 | if isinstance(modality, str): 60 | modality = Modality(modality.lower()) 61 | result = [d for d in result if d.type == modality] 62 | 63 | if media_type: 64 | if isinstance(media_type, str): 65 | media_type = MediaType(media_type.lower()) 66 | result = [d for d in result if d.media_type == media_type] 67 | 68 | if tags: 69 | result = [d for d in result if all(tag in d.tags for tag in tags)] 70 | 71 | if exclude_tags: 72 | result = [ 73 | d for d in result if all(tag not in d.tags for tag in exclude_tags) 74 | ] 75 | 76 | return result 77 | 78 | def enable_dataset(self, path: str, enabled: bool = True) -> bool: 79 | """ 80 | Enable or disable a dataset by path. 81 | 82 | Args: 83 | path: Dataset path to enable/disable 84 | enabled: Whether to enable or disable 85 | 86 | Returns: 87 | True if successful, False if dataset not found 88 | """ 89 | for dataset in self.datasets: 90 | if dataset.path == path: 91 | dataset.enabled = enabled 92 | return True 93 | return False 94 | 95 | def get_dataset_by_path(self, path: str) -> Optional[DatasetConfig]: 96 | """ 97 | Get a dataset by its path. 
98 | 99 | Args: 100 | path: Dataset path to find 101 | 102 | Returns: 103 | Dataset config or None if not found 104 | """ 105 | for dataset in self.datasets: 106 | if dataset.path == path: 107 | return dataset 108 | return None 109 | -------------------------------------------------------------------------------- /bitmind/cache/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset definitions for the validator cache system 3 | """ 4 | 5 | from typing import List 6 | 7 | from bitmind.types import Modality, MediaType, DatasetConfig 8 | 9 | 10 | def get_image_datasets() -> List[DatasetConfig]: 11 | """ 12 | Get the list of image datasets used by the validator. 13 | 14 | Returns: 15 | List of image dataset configurations 16 | """ 17 | return [ 18 | # Real image datasets 19 | DatasetConfig( 20 | path="bitmind/bm-eidon-image", 21 | type=Modality.IMAGE, 22 | media_type=MediaType.REAL, 23 | tags=["frontier"], 24 | ), 25 | DatasetConfig( 26 | path="bitmind/bm-real", 27 | type=Modality.IMAGE, 28 | media_type=MediaType.REAL, 29 | ), 30 | DatasetConfig( 31 | path="bitmind/open-image-v7-256", 32 | type=Modality.IMAGE, 33 | media_type=MediaType.REAL, 34 | tags=["diverse"], 35 | ), 36 | DatasetConfig( 37 | path="bitmind/celeb-a-hq", 38 | type=Modality.IMAGE, 39 | media_type=MediaType.REAL, 40 | tags=["faces", "high-quality"], 41 | ), 42 | DatasetConfig( 43 | path="bitmind/ffhq-256", 44 | type=Modality.IMAGE, 45 | media_type=MediaType.REAL, 46 | tags=["faces", "high-quality"], 47 | ), 48 | DatasetConfig( 49 | path="bitmind/MS-COCO-unique-256", 50 | type=Modality.IMAGE, 51 | media_type=MediaType.REAL, 52 | tags=["diverse"], 53 | ), 54 | DatasetConfig( 55 | path="bitmind/AFHQ", 56 | type=Modality.IMAGE, 57 | media_type=MediaType.REAL, 58 | tags=["animals", "high-quality"], 59 | ), 60 | DatasetConfig( 61 | path="bitmind/lfw", 62 | type=Modality.IMAGE, 63 | media_type=MediaType.REAL, 64 | tags=["faces"], 65 | ), 66 | 
DatasetConfig( 67 | path="bitmind/caltech-256", 68 | type=Modality.IMAGE, 69 | media_type=MediaType.REAL, 70 | tags=["objects", "categorized"], 71 | ), 72 | DatasetConfig( 73 | path="bitmind/caltech-101", 74 | type=Modality.IMAGE, 75 | media_type=MediaType.REAL, 76 | tags=["objects", "categorized"], 77 | ), 78 | DatasetConfig( 79 | path="bitmind/dtd", 80 | type=Modality.IMAGE, 81 | media_type=MediaType.REAL, 82 | tags=["textures"], 83 | ), 84 | DatasetConfig( 85 | path="bitmind/idoc-mugshots-images", 86 | type=Modality.IMAGE, 87 | media_type=MediaType.REAL, 88 | tags=["faces"], 89 | ), 90 | # Synthetic image datasets 91 | DatasetConfig( 92 | path="bitmind/JourneyDB", 93 | type=Modality.IMAGE, 94 | media_type=MediaType.SYNTHETIC, 95 | tags=["midjourney"], 96 | ), 97 | DatasetConfig( 98 | path="bitmind/GenImage_MidJourney", 99 | type=Modality.IMAGE, 100 | media_type=MediaType.SYNTHETIC, 101 | tags=["midjourney"], 102 | ), 103 | DatasetConfig( 104 | path="bitmind/bm-aura-imagegen", 105 | type=Modality.IMAGE, 106 | media_type=MediaType.SYNTHETIC, 107 | tags=["sora"], 108 | ), 109 | # Semisynthetic image datasets 110 | DatasetConfig( 111 | path="bitmind/face-swap", 112 | type=Modality.IMAGE, 113 | media_type=MediaType.SEMISYNTHETIC, 114 | tags=["faces", "manipulated"], 115 | ), 116 | ] 117 | 118 | 119 | def get_video_datasets() -> List[DatasetConfig]: 120 | """ 121 | Get the list of video datasets used by the validator. 
122 | """ 123 | return [ 124 | # Real video datasets 125 | DatasetConfig( 126 | path="bitmind/bm-eidon-video", 127 | type=Modality.VIDEO, 128 | media_type=MediaType.REAL, 129 | tags=["frontier"], 130 | compressed_format="zip", 131 | ), 132 | DatasetConfig( 133 | path="shangxd/imagenet-vidvrd", 134 | type=Modality.VIDEO, 135 | media_type=MediaType.REAL, 136 | tags=["diverse"], 137 | compressed_format="zip", 138 | ), 139 | DatasetConfig( 140 | path="nkp37/OpenVid-1M", 141 | type=Modality.VIDEO, 142 | media_type=MediaType.REAL, 143 | tags=["diverse", "large-zips"], 144 | compressed_format="zip", 145 | ), 146 | # Semisynthetic video datasets 147 | DatasetConfig( 148 | path="bitmind/semisynthetic-video", 149 | type=Modality.VIDEO, 150 | media_type=MediaType.SEMISYNTHETIC, 151 | tags=["faces"], 152 | compressed_format="zip", 153 | ), 154 | ] 155 | 156 | 157 | def initialize_dataset_registry(): 158 | """ 159 | Initialize and populate the dataset registry. 160 | 161 | Returns: 162 | Fully populated DatasetRegistry instance 163 | """ 164 | from bitmind.cache.datasets.dataset_registry import DatasetRegistry 165 | 166 | registry = DatasetRegistry() 167 | 168 | registry.register_all(get_image_datasets()) 169 | registry.register_all(get_video_datasets()) 170 | 171 | return registry 172 | -------------------------------------------------------------------------------- /bitmind/cache/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseSampler 2 | from .image_sampler import ImageSampler 3 | from .video_sampler import VideoSampler 4 | from .sampler_registry import SamplerRegistry 5 | -------------------------------------------------------------------------------- /bitmind/cache/sampler/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import Any, Dict, List 4 | 5 | 6 | from 
bitmind.cache.cache_fs import CacheFS 7 | from bitmind.types import CacheConfig 8 | 9 | 10 | class BaseSampler(ABC): 11 | """ 12 | Base class for samplers that provide access to cached media. 13 | """ 14 | 15 | def __init__(self, cache_config: CacheConfig): 16 | self.cache_fs = CacheFS(cache_config) 17 | 18 | @property 19 | @abstractmethod 20 | def media_file_extensions(self) -> List[str]: 21 | """List of file extensions supported by this sampler""" 22 | pass 23 | 24 | @abstractmethod 25 | async def sample(self, count: int) -> Dict[str, Any]: 26 | """ 27 | Sample items from the media cache. 28 | 29 | Args: 30 | count: Number of items to sample 31 | 32 | Returns: 33 | Dictionary with sampled items information 34 | """ 35 | pass 36 | 37 | def get_available_files(self, use_index=True) -> List[Path]: 38 | """Get list of available files in the media cache""" 39 | return self.cache_fs.get_files( 40 | cache_type="media", 41 | file_extensions=self.media_file_extensions, 42 | use_index=use_index, 43 | ) 44 | 45 | def get_available_count(self, use_index=True) -> int: 46 | """Get count of available files in the media cache""" 47 | return len(self.get_available_files(use_index)) 48 | -------------------------------------------------------------------------------- /bitmind/cache/sampler/image_sampler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import io 4 | from pathlib import Path 5 | from typing import Dict, List, Any 6 | 7 | import cv2 8 | import numpy as np 9 | 10 | from bitmind.cache.sampler.base import BaseSampler 11 | from bitmind.cache.cache_fs import CacheConfig 12 | 13 | 14 | class ImageSampler(BaseSampler): 15 | """ 16 | Sampler for cached image data. 17 | 18 | This class provides access to images in the media cache, 19 | allowing sampling with or without metadata. 
20 |     """
21 | 
22 |     def __init__(self, cache_config: CacheConfig):
23 |         super().__init__(cache_config)
24 | 
25 |     @property
26 |     def media_file_extensions(self) -> List[str]:
27 |         """List of file extensions supported by this sampler"""
28 |         return [".jpg", ".jpeg", ".png", ".webp"]
29 | 
30 |     async def sample(
31 |         self,
32 |         count: int = 1,
33 |         remove_from_cache: bool = False,
34 |         as_float32: bool = False,
35 |         channels_first: bool = False,
36 |         as_rgb: bool = True,
37 |     ) -> Dict[str, Any]:
38 |         """
39 |         Sample random images and their metadata from the cache.
40 | 
41 |         Args:
42 |             count: Number of images to sample
43 |             remove_from_cache: Whether to remove sampled images (and sidecar metadata) from cache
44 |             as_float32 / channels_first / as_rgb: dtype (float32 in [0,1] vs uint8), layout (CHW vs HWC), and channel order (RGB vs BGR) of the returned arrays
45 |         Returns:
46 |             Dictionary containing:
47 |             - count: Number of images successfully sampled
48 |             - items: List of dictionaries containing:
49 |                 - image: numpy array, RGB by default (BGR only when as_rgb=False), shape (H, W, C) unless channels_first
50 |                 - path: Path to the image file
51 |                 - source / index: present only when metadata provides source_parquet / original_index
52 |                 - metadata: Additional metadata loaded from the sidecar .json, if present
53 |         """
54 |         cached_files = self.cache_fs.get_files(
55 |             cache_type="media",
56 |             file_extensions=self.media_file_extensions,
57 |             group_by_source=True,
58 |         )
59 | 
60 |         if not cached_files:
61 |             self.cache_fs._log_warning("No images available in cache")
62 |             return {"count": 0, "items": []}
63 | 
64 |         sampled_items = []
65 | 
66 |         attempts = 0
67 |         max_attempts = count * 3
68 | 
69 |         while len(sampled_items) < count and attempts < max_attempts:
70 |             attempts += 1
71 | 
72 |             source = random.choice(list(cached_files.keys()))
73 |             if not cached_files[source]:
74 |                 del cached_files[source]
75 |                 if not cached_files:
76 |                     break
77 |                 continue
78 | 
79 |             image_path = random.choice(cached_files[source])
80 | 
81 |             try:
82 |                 # Read image directly as numpy array using cv2 (cv2.imread loads BGR)
83 |                 image = cv2.imread(str(image_path))
84 |                 if image is None:
85 |                     raise ValueError(f"Failed to load image {image_path}")
86 | 
87 |                 if as_float32:  # else np.uint8
88 |                     image = 
image.astype(np.float32) / 255.0 89 | 90 | if as_rgb: # else bgr 91 | image = image[:, :, [2, 1, 0]] 92 | 93 | if channels_first: # else channels last 94 | image = np.transpose(image, (2, 0, 1)) 95 | 96 | metadata = {} 97 | metadata_path = image_path.with_suffix(".json") 98 | if metadata_path.exists(): 99 | try: 100 | with open(metadata_path, "r") as f: 101 | metadata = json.load(f) 102 | except Exception as e: 103 | self.cache_fs._log_warning( 104 | f"Error loading metadata for {image_path}: {e}" 105 | ) 106 | 107 | item = { 108 | "image": image, 109 | "path": str(image_path), 110 | "metadata_path": str(metadata_path), 111 | "metadata": metadata, 112 | } 113 | 114 | if "source_parquet" in metadata: 115 | item["source"] = metadata["source_parquet"] 116 | 117 | if "original_index" in metadata: 118 | item["index"] = metadata["original_index"] 119 | 120 | sampled_items.append(item) 121 | 122 | if remove_from_cache: 123 | try: 124 | image_path.unlink(missing_ok=True) 125 | metadata_path.unlink(missing_ok=True) 126 | cached_files[source].remove(image_path) 127 | except Exception as e: 128 | self.cache_fs._log_warning( 129 | f"Failed to remove {image_path}: {e}" 130 | ) 131 | 132 | except Exception as e: 133 | self.cache_fs._log_warning(f"Failed to load image {image_path}: {e}") 134 | cached_files[source].remove(image_path) 135 | continue 136 | 137 | return {"count": len(sampled_items), "items": sampled_items} 138 | -------------------------------------------------------------------------------- /bitmind/cache/sampler/sampler_registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Any 2 | 3 | import bittensor as bt 4 | 5 | from .base import BaseSampler 6 | 7 | 8 | class SamplerRegistry: 9 | """ 10 | Registry for cache samplers. 
11 | """ 12 | 13 | def __init__(self): 14 | self._samplers: Dict[str, BaseSampler] = {} 15 | 16 | def register(self, name: str, sampler: BaseSampler) -> None: 17 | if name in self._samplers: 18 | bt.logging.warning(f"Sampler {name} already registered, will be replaced") 19 | self._samplers[name] = sampler 20 | 21 | def get(self, name: str) -> Optional[BaseSampler]: 22 | return self._samplers.get(name) 23 | 24 | def get_all(self) -> Dict[str, BaseSampler]: 25 | return dict(self._samplers) 26 | 27 | def deregister(self, name: str) -> None: 28 | if name in self._samplers: 29 | del self._samplers[name] 30 | 31 | async def sample(self, name: str, count: int, **kwargs) -> Optional[Dict[str, Any]]: 32 | """ 33 | Sample from a specific sampler. 34 | 35 | Args: 36 | name: Name of the sampler to use 37 | count: Number of items to sample 38 | 39 | Returns: 40 | The sampled items or None if sampler not found 41 | """ 42 | sampler = self.get(name) 43 | if not sampler: 44 | bt.logging.error(f"Sampler {name} not found") 45 | return None 46 | 47 | return await sampler.sample(count, **kwargs) 48 | 49 | async def sample_all(self, count_per_sampler: int = 1) -> Dict[str, Dict[str, Any]]: 50 | """ 51 | Sample from all samplers. 
52 | 53 | Args: 54 | count_per_sampler: Number of items to sample from each sampler 55 | 56 | Returns: 57 | Dictionary mapping sampler names to their samples 58 | """ 59 | results = {} 60 | for name, sampler in self._samplers.items(): 61 | results[name] = await sampler.sample(count_per_sampler) 62 | return results 63 | -------------------------------------------------------------------------------- /bitmind/cache/sampler/video_sampler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import math 4 | import random 5 | import tempfile 6 | 7 | from pathlib import Path 8 | from typing import Dict, List, Any, Optional 9 | from io import BytesIO 10 | 11 | import ffmpeg 12 | import numpy as np 13 | from PIL import Image 14 | 15 | from bitmind.cache.sampler.base import BaseSampler 16 | from bitmind.cache.cache_fs import CacheConfig 17 | from bitmind.cache.util.video import get_video_metadata 18 | 19 | 20 | class VideoSampler(BaseSampler): 21 | """ 22 | Sampler for cached video data. 23 | 24 | This class provides access to videos in the media cache, 25 | allowing sampling of video segments as binary data. 26 | """ 27 | 28 | def __init__(self, cache_config: CacheConfig): 29 | super().__init__(cache_config) 30 | 31 | @property 32 | def media_file_extensions(self) -> List[str]: 33 | """List of file extensions supported by this sampler""" 34 | return [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".webm"] 35 | 36 | async def sample( 37 | self, 38 | count: int = 1, 39 | remove_from_cache: bool = False, 40 | min_duration: float = 1.0, 41 | max_duration: float = 6.0, 42 | max_frames: int = 144, 43 | ) -> Dict[str, Any]: 44 | """ 45 | Sample random video segments from the cache. 
46 | 47 | Args: 48 | count: Number of videos to sample 49 | remove_from_cache: Whether to remove sampled videos from cache 50 | 51 | Returns: 52 | Dictionary containing: 53 | - count: Number of videos successfully sampled 54 | - items: List of dictionaries containing video binary data and metadata 55 | """ 56 | cached_files = self.cache_fs.get_files( 57 | cache_type="media", 58 | file_extensions=self.media_file_extensions, 59 | group_by_source=True, 60 | ) 61 | 62 | if not cached_files: 63 | self.cache_fs._log_warning("No videos available in cache") 64 | return {"count": 0, "items": []} 65 | 66 | sampled_items = [] 67 | for _ in range(count): 68 | if not cached_files: 69 | break 70 | 71 | video_result = await self._sample_frames( 72 | files=cached_files, 73 | min_duration=min_duration, 74 | max_duration=max_duration, 75 | max_frames=max_frames, 76 | remove_from_cache=remove_from_cache, 77 | ) 78 | 79 | if video_result: 80 | sampled_items.append(video_result) 81 | 82 | return {"count": len(sampled_items), "items": sampled_items} 83 | 84 | async def _sample_frames( 85 | self, 86 | files, 87 | min_duration: float = 1.0, 88 | max_duration: float = 6.0, 89 | max_fps: float = 30.0, 90 | max_frames: int = 144, 91 | remove_from_cache: bool = False, 92 | as_float32: bool = False, 93 | channels_first: bool = False, 94 | as_rgb: bool = True, 95 | ) -> Optional[Dict[str, Any]]: 96 | """ 97 | Sample a random video segment and return it as a numpy array. 
98 | 99 | Args: 100 | files: Dict mapping source names to lists of video file paths 101 | min_duration: Minimum duration of video segment to extract in seconds 102 | max_duration: Maximum duration of video segment to extract in seconds 103 | max_fps: Maximum frame rate to use when sampling frames 104 | max_frames: Maximum number of frames to extract 105 | remove_from_cache: Whether to remove the source video from cache 106 | as_float32: Whether to return frames as float32 (0-1) instead of uint8 (0-255) 107 | channels_first: Whether to return frames with channels first (TCHW) instead of channels last (THWC) 108 | as_rgb: Whether to return frames in RGB format (True) or BGR format (False) 109 | 110 | Returns: 111 | Dictionary containing: 112 | - frames: Video frames as numpy array with shape (T,H,W,C) 113 | - metadata: Video metadata 114 | - source: Source information 115 | - segment: Information about the extracted segment 116 | Or None if sampling fails 117 | """ 118 | for _ in range(5): 119 | if not files: 120 | self.cache_fs._log_warning("No more videos available to try") 121 | return None 122 | 123 | source = random.choice(list(files.keys())) 124 | if not files[source]: 125 | del files[source] 126 | continue 127 | 128 | video_path = random.choice(files[source]) 129 | 130 | try: 131 | if not video_path.exists(): 132 | files[source].remove(video_path) 133 | continue 134 | 135 | try: 136 | video_info = get_video_metadata(str(video_path)) 137 | total_duration = video_info.get("duration", 0) 138 | width = int(video_info.get("width", 256)) 139 | height = int(video_info.get("height", 256)) 140 | reported_fps = float(video_info.get("fps", max_fps)) 141 | except Exception as e: 142 | self.cache_fs._log_error( 143 | f"Unable to extract video metadata from {str(video_path)}: {e}" 144 | ) 145 | files[source].remove(video_path) 146 | continue 147 | 148 | if ( 149 | reported_fps > max_fps 150 | or reported_fps <= 0 151 | or not math.isfinite(reported_fps) 152 | ): 153 | 
self.cache_fs._log_warning( 154 | f"Unreasonable FPS ({reported_fps}) detected in {video_path}, capping at {max_fps}" 155 | ) 156 | frame_rate = max_fps 157 | else: 158 | frame_rate = reported_fps 159 | 160 | target_duration = random.uniform(min_duration, max_duration) 161 | target_duration = min(target_duration, total_duration) 162 | 163 | num_frames = int(target_duration * frame_rate) + 1 164 | num_frames = min(num_frames, max_frames) 165 | 166 | actual_duration = (num_frames - 1) / frame_rate 167 | 168 | max_start = max(0, total_duration - actual_duration) 169 | start_time = random.uniform(0, max_start) 170 | 171 | frames = [] 172 | no_data = [] 173 | 174 | for i in range(num_frames): 175 | timestamp = start_time + (i / frame_rate) 176 | 177 | try: 178 | out_bytes, err = ( 179 | ffmpeg.input(str(video_path), ss=str(timestamp)) 180 | .filter("select", "eq(n,0)") 181 | .output( 182 | "pipe:", 183 | vframes=1, 184 | format="image2", 185 | vcodec="png", 186 | loglevel="error", 187 | ) 188 | .run(capture_stdout=True, capture_stderr=True) 189 | ) 190 | 191 | if not out_bytes: 192 | no_data.append(timestamp) 193 | continue 194 | 195 | try: 196 | frame = Image.open(BytesIO(out_bytes)) 197 | frame.load() # Verify image can be loaded 198 | frames.append(np.array(frame)) 199 | except Exception as e: 200 | self.cache_fs._log_error( 201 | f"Failed to process frame at {timestamp}s: {e}" 202 | ) 203 | continue 204 | 205 | except ffmpeg.Error as e: 206 | self.cache_fs._log_error( 207 | f"FFmpeg error at {timestamp}s: {e.stderr.decode()}" 208 | ) 209 | continue 210 | 211 | if len(no_data) > 0: 212 | tmin, tmax = min(no_data), max(no_data) 213 | self.cache_fs._log_warning( 214 | f"No data received for {len(no_data)} frames between {tmin} and {tmax}" 215 | ) 216 | 217 | if not frames: 218 | self.cache_fs._log_warning( 219 | f"No frames successfully extracted from {video_path}" 220 | ) 221 | files[source].remove(video_path) 222 | continue 223 | 224 | frames = np.stack(frames, 
axis=0) 225 | 226 | if as_float32: 227 | frames = frames.astype(np.float32) / 255.0 228 | 229 | if not as_rgb: 230 | frames = frames[:, :, :, [2, 1, 0]] # RGB to BGR 231 | 232 | if channels_first: 233 | frames = np.transpose(frames, (0, 3, 1, 2)) 234 | 235 | metadata = {} 236 | metadata_path = video_path.with_suffix(".json") 237 | if metadata_path.exists(): 238 | try: 239 | with open(metadata_path, "r") as f: 240 | metadata = json.load(f) 241 | except Exception as e: 242 | self.cache_fs._log_warning( 243 | f"Error loading metadata for {video_path}: {e}" 244 | ) 245 | 246 | result = { 247 | "video": frames, 248 | "path": str(video_path), 249 | "metadata_path": str(metadata_path), 250 | "metadata": metadata, 251 | "segment": { 252 | "start_time": start_time, 253 | "duration": actual_duration, 254 | "fps": frame_rate, 255 | "width": width, 256 | "height": height, 257 | "num_frames": len(frames), 258 | }, 259 | } 260 | 261 | if remove_from_cache: 262 | try: 263 | video_path.unlink(missing_ok=True) 264 | metadata_path.unlink(missing_ok=True) 265 | files[source].remove(video_path) 266 | except Exception as e: 267 | self.cache_fs._log_warning( 268 | f"Failed to remove {video_path}: {e}" 269 | ) 270 | 271 | self.cache_fs._log_debug( 272 | f"Successfully sampled {actual_duration}s segment ({len(frames)} frames)" 273 | ) 274 | return result 275 | 276 | except Exception as e: 277 | self.cache_fs._log_error(f"Error sampling from {video_path}: {e}") 278 | files[source].remove(video_path) 279 | 280 | self.cache_fs._log_error("Failed to sample any video after multiple attempts") 281 | return None 282 | -------------------------------------------------------------------------------- /bitmind/cache/updater/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseUpdater 2 | from .image_updater import ImageUpdater 3 | from .video_updater import VideoUpdater 4 | from .updater_registry import UpdaterRegistry 5 | 
-------------------------------------------------------------------------------- /bitmind/cache/updater/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import Any, List, Optional 4 | 5 | import numpy as np 6 | 7 | from bitmind.cache.cache_fs import CacheFS 8 | from bitmind.cache.datasets import DatasetRegistry 9 | from bitmind.types import CacheUpdaterConfig, CacheConfig, CacheType 10 | from bitmind.cache.util.download import list_hf_files, download_files 11 | from bitmind.cache.util.filesystem import ( 12 | filter_ready_files, 13 | wait_for_downloads_to_complete, 14 | is_source_complete, 15 | ) 16 | 17 | 18 | class BaseUpdater(ABC): 19 | """ 20 | Base class for cache updaters that handle downloading and extracting data. 21 | 22 | This version is designed to work with block callbacks rather than having 23 | its own internal timing logic. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | cache_config: CacheConfig, 29 | updater_config: CacheUpdaterConfig, 30 | data_manager: DatasetRegistry, 31 | ): 32 | self.cache_fs = CacheFS(cache_config) 33 | self.updater_config = updater_config 34 | self.dataset_registry = data_manager 35 | self._datasets = self._get_filtered_datasets() 36 | self._recently_downloaded_files = [] 37 | 38 | def _get_filtered_datasets( 39 | self, 40 | modality: Optional[str] = None, 41 | media_type: Optional[str] = None, 42 | tags: Optional[List[str]] = None, 43 | exclude_tags: Optional[List[str]] = None, 44 | ) -> List[Any]: 45 | """Get datasets that match the cache configuration""" 46 | modality = self.cache_fs.config.modality if modality is None else modality 47 | media_type = ( 48 | self.cache_fs.config.media_type if media_type is None else media_type 49 | ) 50 | tags = self.cache_fs.config.tags if tags is None else tags 51 | 52 | return self.dataset_registry.get_datasets( 53 | modality=self.cache_fs.config.modality, 54 | 
media_type=self.cache_fs.config.media_type, 55 | tags=self.cache_fs.config.tags, 56 | exclude_tags=exclude_tags, 57 | ) 58 | 59 | @property 60 | @abstractmethod 61 | def media_file_extensions(self) -> List[str]: 62 | pass 63 | 64 | @property 65 | @abstractmethod 66 | def compressed_file_extension(self) -> str: 67 | pass 68 | 69 | @abstractmethod 70 | async def _extract_items_from_source( 71 | self, source_path: Path, count: int 72 | ) -> List[Path]: 73 | pass 74 | 75 | async def initialize_cache(self) -> None: 76 | """ 77 | This performs a one-time initialization to ensure the cache has 78 | content available, particularly useful during first startup. 79 | """ 80 | self.cache_fs._log_debug("Setting up cache") 81 | 82 | if self.cache_fs.is_empty(CacheType.MEDIA): 83 | if self.cache_fs.is_empty(CacheType.COMPRESSED): 84 | self.cache_fs._log_debug("Compressed cache empty; populating") 85 | await self.update_compressed_cache( 86 | n_sources_per_dataset=1, 87 | n_datasets=1, 88 | exclude_tags=["large-zips"], 89 | maybe_prune=False, 90 | ) 91 | 92 | self.cache_fs._log_debug( 93 | "Waiting for compressed files to finish downloading..." 94 | ) 95 | await wait_for_downloads_to_complete( 96 | self._recently_downloaded_files, 97 | ) 98 | self._recently_downloaded_files = [] 99 | 100 | self.cache_fs._log_debug( 101 | "Compressed files downloaded. Updating media cache." 102 | ) 103 | await self.update_media_cache(maybe_prune=False) 104 | else: 105 | self.cache_fs._log_debug( 106 | "Compressed sources available; Media cache empty; populating" 107 | ) 108 | await self.update_media_cache() 109 | 110 | async def update_compressed_cache( 111 | self, 112 | n_sources_per_dataset: Optional[int] = None, 113 | n_datasets: Optional[int] = None, 114 | exclude_tags: Optional[List[str]] = None, 115 | maybe_prune: bool = True, 116 | ) -> None: 117 | """ 118 | Update the compressed cache by downloading new files. 
119 | 120 | Args: 121 | n_sources_per_dataset: Optional override for number of sources per dataset 122 | n_datasets: Optional limit on number of datasets to process 123 | """ 124 | if n_sources_per_dataset is None: 125 | n_sources_per_dataset = self.updater_config.num_sources_per_dataset 126 | 127 | if maybe_prune: 128 | await self.cache_fs.maybe_prune_cache( 129 | cache_type=CacheType.COMPRESSED, 130 | file_extensions=[self.compressed_file_extension], 131 | ) 132 | 133 | # Reset tracking list before new downloads 134 | self._recently_downloaded_files = [] 135 | 136 | datasets = self._get_filtered_datasets(exclude_tags=exclude_tags) 137 | if n_datasets is not None and n_datasets > 0: 138 | datasets = datasets[:n_datasets] 139 | np.random.shuffle(datasets) 140 | 141 | new_files = [] 142 | for dataset in datasets: 143 | try: 144 | filenames = self._list_remote_dataset_files(dataset.path) 145 | if not filenames: 146 | self.cache_fs._log_warning(f"No files found for {dataset.path}") 147 | continue 148 | 149 | remote_paths = self._get_download_urls(dataset.path, filenames) 150 | to_download = self._select_files_to_download( 151 | remote_paths, n_sources_per_dataset 152 | ) 153 | 154 | output_dir = self.cache_fs.compressed_dir / dataset.path.split("/")[-1] 155 | 156 | self.cache_fs._log_debug( 157 | f"Downloading {len(to_download)} files from {dataset.path}" 158 | ) 159 | batch_files = await self._download_files(to_download, output_dir) 160 | 161 | # Track downloaded files 162 | self._recently_downloaded_files.extend(batch_files) 163 | new_files.extend(batch_files) 164 | except Exception as e: 165 | self.cache_fs._log_error(f"Error downloading from {dataset.path}: {e}") 166 | 167 | if new_files: 168 | self.cache_fs._log_debug(f"Added {len(new_files)} new compressed files") 169 | else: 170 | self.cache_fs._log_warning(f"No new files were added to compressed cache") 171 | 172 | async def update_media_cache( 173 | self, n_items_per_source: Optional[int] = None, maybe_prune: 
bool = True 174 | ) -> None: 175 | """ 176 | Update the media cache by extracting from compressed sources. 177 | 178 | Args: 179 | n_items_per_source: Optional override for number of items per source 180 | """ 181 | if n_items_per_source is None: 182 | n_items_per_source = self.updater_config.num_items_per_source 183 | 184 | if maybe_prune: 185 | await self.cache_fs.maybe_prune_cache( 186 | cache_type=CacheType.MEDIA, file_extensions=self.media_file_extensions 187 | ) 188 | 189 | all_compressed_files = self.cache_fs.get_files( 190 | cache_type=CacheType.COMPRESSED, 191 | file_extensions=[self.compressed_file_extension], 192 | use_index=False, 193 | ) 194 | 195 | if not all_compressed_files: 196 | self.cache_fs._log_warning(f"No compressed sources available") 197 | return 198 | 199 | compressed_files = filter_ready_files(all_compressed_files) 200 | 201 | if not compressed_files: 202 | self.cache_fs._log_warning( 203 | f"No ready compressed sources available. Files may still be downloading." 
204 | ) 205 | return 206 | 207 | valid_compressed_files = [] 208 | for path in compressed_files: 209 | if not is_source_complete(path): 210 | try: 211 | Path(path).unlink() 212 | except Exception as del_err: 213 | self.cache_fs._log_error( 214 | f"Failed to delete corrupted file {path}: {del_err}" 215 | ) 216 | else: 217 | valid_compressed_files.append(path) 218 | 219 | if len(valid_compressed_files) > 10: 220 | valid_compressed_files = np.random.choice( 221 | valid_compressed_files, size=10, replace=False 222 | ).tolist() 223 | 224 | new_files = [] 225 | for source in valid_compressed_files: 226 | try: 227 | items = await self._extract_items_from_source( 228 | source, n_items_per_source 229 | ) 230 | new_files.extend(items) 231 | except Exception as e: 232 | self.cache_fs._log_error(f"Error extracting from {source}: {e}") 233 | 234 | if new_files: 235 | self.cache_fs._log_debug(f"Added {len(new_files)} new items to media cache") 236 | else: 237 | self.cache_fs._log_warning(f"No new items were added to media cache") 238 | 239 | def num_media_files(self) -> int: 240 | count = self.cache_fs.num_files(CacheType.MEDIA, self.media_file_extensions) 241 | return count == 0 242 | 243 | def num_compressed_files(self) -> int: 244 | count = self.cache_fs.num_files( 245 | CacheType.COMPRESSED, [self.compressed_file_extension] 246 | ) 247 | return count == 0 248 | 249 | def _select_files_to_download(self, urls: List[str], count: int) -> List[str]: 250 | """Select random files to download""" 251 | return np.random.choice( 252 | urls, size=min(count, len(urls)), replace=False 253 | ).tolist() 254 | 255 | def _list_remote_dataset_files(self, dataset_path: str) -> List[str]: 256 | """List available files in a dataset with the parquet extension""" 257 | return list_hf_files( 258 | repo_id=dataset_path, extension=self.compressed_file_extension 259 | ) 260 | 261 | def _get_download_urls(self, dataset_path: str, filenames: List[str]) -> List[str]: 262 | """Get Hugging Face download 
URLs for data files""" 263 | return [ 264 | f"https://huggingface.co/datasets/{dataset_path}/resolve/main/{f}" 265 | for f in filenames 266 | ] 267 | 268 | async def _download_files(self, urls: List[str], output_dir: Path) -> List[Path]: 269 | """Download a subset of a remote dataset's compressed files""" 270 | return await download_files(urls, output_dir) 271 | -------------------------------------------------------------------------------- /bitmind/cache/updater/image_updater.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | import traceback 4 | 5 | from bitmind.cache.updater import BaseUpdater 6 | from bitmind.cache.datasets import DatasetRegistry 7 | from bitmind.cache.util.filesystem import is_parquet_complete 8 | from bitmind.types import CacheUpdaterConfig, CacheConfig 9 | 10 | 11 | class ImageUpdater(BaseUpdater): 12 | """ 13 | Updater for image data from parquet files. 14 | 15 | This class handles downloading parquet files from Hugging Face datasets 16 | and extracting images from them into the media cache. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | cache_config: CacheConfig, 22 | updater_config: CacheUpdaterConfig, 23 | data_manager: DatasetRegistry, 24 | ): 25 | super().__init__( 26 | cache_config=cache_config, 27 | updater_config=updater_config, 28 | data_manager=data_manager, 29 | ) 30 | 31 | @property 32 | def media_file_extensions(self) -> List[str]: 33 | """List of file extensions supported by this updater""" 34 | return [".jpg", ".jpeg", ".png", ".webp"] 35 | 36 | @property 37 | def compressed_file_extension(self) -> str: 38 | """File extension for compressed source files""" 39 | return ".parquet" 40 | 41 | async def _extract_items_from_source( 42 | self, source_path: Path, count: int 43 | ) -> List[Path]: 44 | """ 45 | Extract images from a parquet file. 
46 | 47 | Args: 48 | source_path: Path to the parquet file 49 | count: Number of images to extract 50 | 51 | Returns: 52 | List of paths to extracted image files 53 | """ 54 | self.cache_fs._log_trace(f"Extracting up to {count} images from {source_path}") 55 | 56 | dataset_name = source_path.parent.name 57 | if not dataset_name: 58 | dataset_name = source_path.stem 59 | 60 | dest_dir = self.cache_fs.cache_dir / dataset_name 61 | dest_dir.mkdir(parents=True, exist_ok=True) 62 | 63 | try: 64 | from ..util import extract_images_from_parquet 65 | 66 | saved_files = extract_images_from_parquet( 67 | parquet_path=source_path, dest_dir=dest_dir, num_images=count 68 | ) 69 | 70 | self.cache_fs._log_trace( 71 | f"Extracted {len(saved_files)} images from {source_path}" 72 | ) 73 | return [Path(f) for f in saved_files] 74 | 75 | except Exception as e: 76 | self.cache_fs._log_error(f"Error extracting images from {source_path}: {e}") 77 | self.cache_fs._log_error(traceback.format_exc()) 78 | return [] 79 | -------------------------------------------------------------------------------- /bitmind/cache/updater/updater_registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import bittensor as bt 4 | 5 | from bitmind.cache.updater import BaseUpdater 6 | 7 | 8 | class UpdaterRegistry: 9 | """ 10 | Registry for cache updaters. 
11 | """ 12 | 13 | def __init__(self): 14 | self._updaters: Dict[str, BaseUpdater] = {} 15 | 16 | def register(self, name: str, updater: BaseUpdater) -> None: 17 | if name in self._updaters: 18 | bt.logging.warning(f"Updater {name} already registered, will be replaced") 19 | self._updaters[name] = updater 20 | 21 | def get(self, name: str) -> Optional[BaseUpdater]: 22 | return self._updaters.get(name) 23 | 24 | def get_all(self) -> Dict[str, BaseUpdater]: 25 | return dict(self._updaters) 26 | 27 | def deregister(self, name: str) -> None: 28 | if name in self._updaters: 29 | del self._updaters[name] 30 | -------------------------------------------------------------------------------- /bitmind/cache/updater/video_updater.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | from pathlib import Path 3 | from typing import List 4 | 5 | from bitmind.types import CacheUpdaterConfig, CacheConfig 6 | from bitmind.cache.updater import BaseUpdater 7 | from bitmind.cache.datasets import DatasetRegistry 8 | 9 | 10 | class VideoUpdater(BaseUpdater): 11 | """ 12 | Updater for video data from zip files. 13 | 14 | This class handles downloading zip files from Hugging Face datasets 15 | and extracting videos from them into the media cache. 
16 | """ 17 | 18 | def __init__( 19 | self, 20 | cache_config: CacheConfig, 21 | updater_config: CacheUpdaterConfig, 22 | data_manager: DatasetRegistry, 23 | ): 24 | super().__init__( 25 | cache_config=cache_config, 26 | updater_config=updater_config, 27 | data_manager=data_manager, 28 | ) 29 | 30 | @property 31 | def media_file_extensions(self) -> List[str]: 32 | """List of file extensions supported by this updater""" 33 | return [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".webm"] 34 | 35 | @property 36 | def compressed_file_extension(self) -> str: 37 | """File extension for compressed source files""" 38 | return ".zip" 39 | 40 | async def _extract_items_from_source( 41 | self, source_path: Path, count: int 42 | ) -> List[Path]: 43 | """ 44 | Extract videos from a zip file. 45 | 46 | Args: 47 | source_path: Path to the zip file 48 | count: Number of videos to extract 49 | 50 | Returns: 51 | List of paths to extracted video files 52 | """ 53 | self.cache_fs._log_trace(f"Extracting up to {count} videos from {source_path}") 54 | 55 | dataset_name = source_path.parent.name 56 | if not dataset_name: 57 | dataset_name = source_path.stem 58 | 59 | dest_dir = self.cache_fs.cache_dir / dataset_name 60 | dest_dir.mkdir(parents=True, exist_ok=True) 61 | 62 | try: 63 | from ..util import extract_videos_from_zip 64 | 65 | extracted_pairs = extract_videos_from_zip( 66 | zip_path=source_path, 67 | dest_dir=dest_dir, 68 | num_videos=count, 69 | file_extensions=set(self.media_file_extensions), 70 | ) 71 | 72 | # extract_videos_from_zip returns pairs of (video_path, metadata_path) 73 | # We just need the video paths for our return value 74 | video_paths = [Path(pair[0]) for pair in extracted_pairs] 75 | 76 | self.cache_fs._log_trace( 77 | f"Extracted {len(video_paths)} videos from {source_path}" 78 | ) 79 | return video_paths 80 | 81 | except Exception as e: 82 | self.cache_fs._log_trace(f"Error extracting videos from {source_path}: {e}") 83 | return [] 84 | 
-------------------------------------------------------------------------------- /bitmind/cache/util/__init__.py: -------------------------------------------------------------------------------- 1 | from bitmind.cache.util.filesystem import ( 2 | is_source_complete, 3 | is_zip_complete, 4 | is_parquet_complete, 5 | get_most_recent_update_time, 6 | ) 7 | 8 | from bitmind.cache.util.download import ( 9 | download_files, 10 | list_hf_files, 11 | openvid1m_err_handler, 12 | ) 13 | 14 | from bitmind.cache.util.video import ( 15 | get_video_duration, 16 | get_video_metadata, 17 | seconds_to_str, 18 | ) 19 | 20 | from bitmind.cache.util.extract import ( 21 | extract_videos_from_zip, 22 | extract_images_from_parquet, 23 | ) 24 | 25 | __all__ = [ 26 | # Filesystem 27 | "is_source_complete", 28 | "is_zip_complete", 29 | "is_parquet_complete", 30 | "get_most_recent_update_time", 31 | # Download 32 | "download_files", 33 | "list_hf_files", 34 | "openvid1m_err_handler", 35 | # Video 36 | "get_video_duration", 37 | "get_video_metadata", 38 | "seconds_to_str", 39 | # Extraction 40 | "extract_videos_from_zip", 41 | "extract_images_from_parquet", 42 | ] 43 | -------------------------------------------------------------------------------- /bitmind/cache/util/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | from pathlib import Path 4 | from typing import List, Union, Optional 5 | 6 | import asyncio 7 | import aiohttp 8 | import bittensor as bt 9 | import huggingface_hub as hf_hub 10 | from requests.exceptions import RequestException 11 | 12 | 13 | def list_hf_files(repo_id, repo_type="dataset", extension=None): 14 | """List files from a Hugging Face repository. 15 | 16 | Args: 17 | repo_id: Repository ID 18 | repo_type: Type of repository ('dataset', 'model', etc.) 
19 | extension: Filter files by extension 20 | 21 | Returns: 22 | List of files in the repository 23 | """ 24 | files = [] 25 | try: 26 | files = list(hf_hub.list_repo_files(repo_id=repo_id, repo_type=repo_type)) 27 | if extension: 28 | files = [f for f in files if f.endswith(extension)] 29 | except Exception as e: 30 | bt.logging.error(f"Failed to list files of type {extension} in {repo_id}: {e}") 31 | return files 32 | 33 | 34 | async def download_files( 35 | urls: List[str], output_dir: Union[str, Path], chunk_size: int = 8192 36 | ) -> List[Path]: 37 | """Download multiple files asynchronously. 38 | 39 | Args: 40 | urls: List of URLs to download 41 | output_dir: Directory to save the files 42 | chunk_size: Size of chunks to download at a time 43 | 44 | Returns: 45 | List of successfully downloaded file paths 46 | """ 47 | output_dir = Path(output_dir) 48 | output_dir.mkdir(parents=True, exist_ok=True) 49 | 50 | download_tasks = [] 51 | timeout = aiohttp.ClientTimeout( 52 | total=3600, 53 | ) 54 | 55 | async with aiohttp.ClientSession(timeout=timeout) as session: 56 | # Create download tasks for each URL 57 | for url in urls: 58 | download_tasks.append( 59 | download_single_file(session, url, output_dir, chunk_size) 60 | ) 61 | 62 | # Run all downloads concurrently and gather results 63 | downloaded_files = await asyncio.gather(*download_tasks, return_exceptions=True) 64 | 65 | # Filter out exceptions and return only successful downloads 66 | return [f for f in downloaded_files if isinstance(f, Path)] 67 | 68 | 69 | async def download_single_file( 70 | session: aiohttp.ClientSession, url: str, output_dir: Path, chunk_size: int 71 | ) -> Path: 72 | """Download a single file asynchronously. 
73 | 74 | Args: 75 | session: aiohttp ClientSession to use for requests 76 | url: URL to download 77 | output_dir: Directory to save the file 78 | chunk_size: Size of chunks to download at a time 79 | 80 | Returns: 81 | Path to the downloaded file 82 | """ 83 | try: 84 | bt.logging.info(f"Downloading {url}") 85 | 86 | async with session.get(url) as response: 87 | if response.status != 200: 88 | bt.logging.error(f"Failed to download {url}: Status {response.status}") 89 | raise Exception(f"HTTP error {response.status}") 90 | 91 | filename = os.path.basename(url) 92 | filepath = output_dir / filename 93 | 94 | bt.logging.info(f"Writing to {filepath}") 95 | 96 | # Use async file I/O to write the file 97 | with open(filepath, "wb") as f: 98 | # Download and write in chunks 99 | async for chunk in response.content.iter_chunked(chunk_size): 100 | if chunk: # filter out keep-alive chunks 101 | f.write(chunk) 102 | 103 | return filepath 104 | 105 | except Exception as e: 106 | bt.logging.error(f"Error downloading {url}: {str(e)}") 107 | bt.logging.error(traceback.format_exc()) 108 | raise 109 | 110 | 111 | def openvid1m_err_handler( 112 | base_zip_url: str, 113 | output_path: Path, 114 | part_index: int, 115 | chunk_size: int = 8192, 116 | timeout: int = 300, 117 | ) -> Optional[Path]: 118 | """Synchronous error handler for OpenVid1M downloads that handles split files. 
119 | 120 | Args: 121 | base_zip_url: Base URL for the zip parts 122 | output_path: Directory to save files 123 | part_index: Index of the part to download 124 | chunk_size: Size of download chunks 125 | timeout: Download timeout in seconds 126 | 127 | Returns: 128 | Path to combined file if successful, None otherwise 129 | """ 130 | part_urls = [ 131 | f"{base_zip_url}{part_index}_partaa", 132 | f"{base_zip_url}{part_index}_partab", 133 | ] 134 | error_log_path = output_path / "download_log.txt" 135 | downloaded_parts = [] 136 | 137 | # Download each part 138 | for part_url in part_urls: 139 | part_file_path = output_path / Path(part_url).name 140 | 141 | if part_file_path.exists(): 142 | bt.logging.warning(f"File {part_file_path} exists.") 143 | downloaded_parts.append(part_file_path) 144 | continue 145 | 146 | try: 147 | response = requests.get(part_url, stream=True, timeout=timeout) 148 | if response.status_code != 200: 149 | raise RequestException( 150 | f"HTTP {response.status_code}: {response.reason}" 151 | ) 152 | 153 | with open(part_file_path, "wb") as f: 154 | for chunk in response.iter_content(chunk_size=chunk_size): 155 | if chunk: # filter out keep-alive chunks 156 | f.write(chunk) 157 | 158 | bt.logging.info(f"File {part_url} saved to {part_file_path}") 159 | downloaded_parts.append(part_file_path) 160 | 161 | except Exception as e: 162 | error_message = f"File {part_url} download failed: {str(e)}\n" 163 | bt.logging.error(error_message) 164 | with open(error_log_path, "a") as error_log_file: 165 | error_log_file.write(error_message) 166 | return None 167 | 168 | if len(downloaded_parts) == len(part_urls): 169 | try: 170 | combined_file = output_path / f"OpenVid_part{part_index}.zip" 171 | combined_data = bytearray() 172 | for part_path in downloaded_parts: 173 | with open(part_path, "rb") as part_file: 174 | combined_data.extend(part_file.read()) 175 | 176 | with open(combined_file, "wb") as out_file: 177 | out_file.write(combined_data) 178 | 179 | 
for part_path in downloaded_parts: 180 | part_path.unlink() 181 | 182 | bt.logging.info(f"Successfully combined parts into {combined_file}") 183 | return combined_file 184 | 185 | except Exception as e: 186 | error_message = ( 187 | f"Failed to combine parts for index {part_index}: {str(e)}\n" 188 | ) 189 | bt.logging.error(error_message) 190 | with open(error_log_path, "a") as error_log_file: 191 | error_log_file.write(error_message) 192 | return None 193 | 194 | return None 195 | -------------------------------------------------------------------------------- /bitmind/cache/util/extract.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import hashlib 3 | import json 4 | import os 5 | import random 6 | import shutil 7 | from datetime import datetime 8 | from io import BytesIO 9 | from pathlib import Path 10 | from typing import List, Optional, Set, Tuple 11 | from zipfile import ZipFile 12 | 13 | import bittensor as bt 14 | import pyarrow.parquet as pq 15 | from PIL import Image 16 | 17 | 18 | def extract_videos_from_zip( 19 | zip_path: Path, 20 | dest_dir: Path, 21 | num_videos: int, 22 | file_extensions: Set[str] = {".mp4", ".avi", ".mov", ".mkv", ".wmv"}, 23 | include_checksums: bool = True, 24 | ) -> List[Tuple[str, str]]: 25 | """Extract random videos and their metadata from a zip file and save them to disk. 

    Args:
        zip_path: Path to the zip file
        dest_dir: Directory to save videos and metadata
        num_videos: Number of videos to extract
        file_extensions: Set of valid video file extensions
        include_checksums: Whether to calculate and include file checksums in metadata

    Returns:
        List of tuples containing (video_path, metadata_path)
    """
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)

    extracted_files = []
    try:
        with ZipFile(zip_path) as zip_file:
            # Candidate videos: match extensions case-insensitively and skip
            # macOS resource-fork entries ("__MACOSX/...").
            video_files = [
                f
                for f in zip_file.namelist()
                if any(f.lower().endswith(ext) for ext in file_extensions)
                and "MACOSX" not in f
            ]
            if not video_files:
                bt.logging.warning(f"No video files found in {zip_path}")
                return extracted_files

            bt.logging.debug(f"{len(video_files)} video files found in {zip_path}")
            selected_videos = random.sample(
                video_files, min(num_videos, len(video_files))
            )

            bt.logging.debug(
                f"Extracting {len(selected_videos)} randomly sampled video files from {zip_path}"
            )
            for video in selected_videos:
                try:
                    # extract video and get metadata; only the entry's base
                    # name is kept, flattening any directory structure.
                    video_path = dest_dir / Path(video).name
                    with zip_file.open(video) as source:
                        with open(video_path, "wb") as target:
                            shutil.copyfileobj(source, target)

                    video_info = zip_file.getinfo(video)
                    metadata = {
                        "dataset": Path(zip_path).parent.name,
                        "source_zip": str(zip_path),
                        "path_in_zip": video,
                        "extraction_date": datetime.now().isoformat(),
                        "file_size": os.path.getsize(video_path),
                        "zip_metadata": {
                            "compress_size": video_info.compress_size,
                            "file_size": video_info.file_size,
                            "compress_type": video_info.compress_type,
                            "date_time": datetime.strftime(
                                datetime(*video_info.date_time), "%Y-%m-%d %H:%M:%S"
                            ),
                        },
                    }

                    if include_checksums:
                        # Read the extracted file back to hash it; both md5
                        # (legacy) and sha256 checksums are recorded.
                        with open(video_path, "rb") as f:
                            file_data = f.read()
                        metadata["checksums"] = {
                            "md5": hashlib.md5(file_data).hexdigest(),
                            "sha256": hashlib.sha256(file_data).hexdigest(),
                        }

                    # Sidecar JSON metadata next to the video, same stem.
                    metadata_filename = f"{video_path.stem}.json"
                    metadata_path = dest_dir / metadata_filename

                    with open(metadata_path, "w", encoding="utf-8") as f:
                        json.dump(metadata, f, indent=2, ensure_ascii=False)

                    extracted_files.append((str(video_path), str(metadata_path)))

                except Exception as e:
                    # One bad member should not abort the whole batch.
                    bt.logging.warning(f"Error extracting {video}: {e}")
                    continue

    except Exception as e:
        bt.logging.warning(f"Error processing zip file {zip_path}: {e}")

    return extracted_files


def extract_images_from_parquet(
    parquet_path: Path, dest_dir: Path, num_images: int, seed: Optional[int] = None
) -> List[str]:
    """Extract random images and their metadata from a parquet file and save them to disk.

    Args:
        parquet_path: Path to the parquet file
        dest_dir: Directory to save images and metadata
        num_images: Number of images to extract
        seed: Random seed for sampling

    Returns:
        List of image file paths
    """
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)

    # read parquet file, sample random image rows
    table = pq.read_table(parquet_path)
    df = table.to_pandas()
    sample_df = df.sample(n=min(num_images, len(df)), random_state=seed)
    # Heuristic: the image payload lives in the first column whose name
    # contains "image"; all remaining columns are treated as metadata.
    image_col = next((col for col in sample_df.columns if "image" in col.lower()), None)
    metadata_cols = [c for c in sample_df.columns if c != image_col]

    saved_files = []
    parquet_prefix = parquet_path.stem
    for idx, row in sample_df.iterrows():
        try:
            img_data = row[image_col]
            if isinstance(img_data, dict):
                # Some datasets store images as dicts (e.g. bytes + path);
                # pick the first key that looks like raw image data.
                key = next(
                    (
                        k
                        for k in img_data
                        if "bytes" in k.lower() or "image" in k.lower()
                    ),
None, 149 | ) 150 | img_data = img_data[key] 151 | 152 | try: 153 | img = Image.open(BytesIO(img_data)) 154 | except Exception: 155 | img_data = base64.b64decode(img_data) 156 | img = Image.open(BytesIO(img_data)) 157 | 158 | base_filename = f"{parquet_prefix}__image_{idx}" 159 | image_format = img.format.lower() if img.format else "png" 160 | img_filename = f"{base_filename}.{image_format}" 161 | img_path = dest_dir / img_filename 162 | img.save(img_path) 163 | 164 | metadata = { 165 | "dataset": Path(parquet_path).parent.name, 166 | "source_parquet": str(parquet_path), 167 | "original_index": str(idx), 168 | "image_format": image_format, 169 | "image_size": img.size, 170 | "image_mode": img.mode, 171 | } 172 | 173 | for col in metadata_cols: 174 | # Convert any non-serializable types to strings 175 | try: 176 | json.dumps({col: row[col]}) 177 | metadata[col] = row[col] 178 | except (TypeError, OverflowError): 179 | metadata[col] = str(row[col]) 180 | 181 | metadata_filename = f"{base_filename}.json" 182 | metadata_path = dest_dir / metadata_filename 183 | with open(metadata_path, "w", encoding="utf-8") as f: 184 | json.dump(metadata, f, indent=2, ensure_ascii=False) 185 | 186 | saved_files.append(str(img_path)) 187 | 188 | except Exception as e: 189 | bt.logging.warning(f"Failed to extract/save image {idx}: {e}") 190 | continue 191 | 192 | return saved_files 193 | -------------------------------------------------------------------------------- /bitmind/cache/util/filesystem.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Dict, List, Optional, Tuple, Union, Any 3 | import bittensor as bt 4 | import pyarrow.parquet as pq 5 | from zipfile import ZipFile, BadZipFile 6 | import asyncio 7 | import sys 8 | import time 9 | 10 | from bitmind.types import FileType 11 | 12 | 13 | def get_most_recent_update_time(directory: Path) -> float: 14 | """Get the most recent 
modification time of any file in directory.""" 15 | try: 16 | mtimes = [f.stat().st_mtime for f in directory.iterdir()] 17 | return max(mtimes) if mtimes else 0 18 | except Exception as e: 19 | bt.logging.error(f"Error getting modification times: {e}") 20 | return 0 21 | 22 | 23 | def is_source_complete(path: Union[str, Path]) -> Callable[[Path], bool]: 24 | """Checks integrity of parquet or zip file""" 25 | 26 | path = Path(path) 27 | if path.suffix.lower() == ".parquet": 28 | return is_parquet_complete(path) 29 | elif path.suffix.lower() == ".zip": 30 | return is_zip_complete(path) 31 | else: 32 | return None 33 | 34 | 35 | def is_zip_complete(zip_path: Union[str, Path], testzip=False) -> bool: 36 | try: 37 | with ZipFile(zip_path) as zf: 38 | if testzip: 39 | zf.testzip() 40 | else: 41 | zf.namelist() 42 | return True 43 | except (BadZipFile, Exception) as e: 44 | bt.logging.error(f"Zip file {zip_path} is invalid: {e}") 45 | return False 46 | 47 | 48 | def is_parquet_complete(path: Path) -> bool: 49 | try: 50 | with open(path, "rb") as f: 51 | pq.read_metadata(f) 52 | return True 53 | except Exception as e: 54 | bt.logging.error(f"Parquet file {path} is incomplete or corrupted: {e}") 55 | return False 56 | 57 | 58 | def get_dir_size( 59 | path: Union[str, Path], exclude_dirs: Optional[List[str]] = None 60 | ) -> Tuple[int, int]: 61 | if exclude_dirs is None: 62 | exclude_dirs = [] 63 | 64 | total_size = 0 65 | file_count = 0 66 | path_obj = Path(path) 67 | 68 | try: 69 | for item in path_obj.iterdir(): 70 | if item.is_dir() and item.name in exclude_dirs: 71 | continue 72 | elif item.is_file(): 73 | try: 74 | total_size += item.stat().st_size 75 | file_count += 1 76 | except (OSError, PermissionError): 77 | pass 78 | elif item.is_dir(): 79 | subdir_size, subdir_count = get_dir_size(item, exclude_dirs) 80 | total_size += subdir_size 81 | file_count += subdir_count 82 | except (PermissionError, OSError) as e: 83 | print(f"Error accessing {path}: {e}", 
file=sys.stderr) 84 | 85 | return total_size, file_count 86 | 87 | 88 | def scale_size(size: float, from_unit: str = "B", to_unit: str = "GB") -> float: 89 | if size == 0: 90 | return 0.0 91 | 92 | units = ["B", "KB", "MB", "GB", "TB", "PB"] 93 | from_unit, to_unit = from_unit.upper(), to_unit.upper() 94 | if from_unit not in units or to_unit not in units: 95 | raise ValueError(f"Units must be one of: {', '.join(units)}") 96 | 97 | from_index = units.index(from_unit) 98 | to_index = units.index(to_unit) 99 | scale_factor = from_index - to_index 100 | 101 | if scale_factor > 0: 102 | return size * (1024**scale_factor) 103 | elif scale_factor < 0: 104 | return size / (1024 ** abs(scale_factor)) 105 | return size 106 | 107 | 108 | def format_size( 109 | size: float, from_unit: str = "B", to_unit: Optional[str] = None 110 | ) -> str: 111 | if size == 0: 112 | return "0 B" 113 | 114 | units = ["B", "KB", "MB", "GB", "TB", "PB"] 115 | from_unit = from_unit.upper() 116 | 117 | if from_unit not in units: 118 | raise ValueError(f"From unit must be one of: {', '.join(units)}") 119 | 120 | if to_unit is None: 121 | current_size = scale_size(size, from_unit, "B") 122 | unit_index = 0 123 | 124 | while current_size >= 1024 and unit_index < len(units) - 1: 125 | current_size /= 1024 126 | unit_index += 1 127 | 128 | return f"{current_size:.2f} {units[unit_index]}" 129 | else: 130 | to_unit = to_unit.upper() 131 | if to_unit not in units: 132 | raise ValueError(f"To unit must be one of: {', '.join(units)}") 133 | scaled_size = scale_size(size, from_unit, to_unit) 134 | return f"{scaled_size:.2f} {to_unit}" 135 | 136 | 137 | def analyze_directory( 138 | root_path: Union[str, Path], 139 | exclude_dirs: Optional[List[str]] = None, 140 | min_file_count: int = 1, 141 | log_func=None, 142 | ) -> Dict[str, Any]: 143 | if exclude_dirs is None: 144 | exclude_dirs = [] 145 | 146 | path_obj = Path(root_path) 147 | result = { 148 | "name": path_obj.name or str(path_obj), 149 | "path": 
str(path_obj), 150 | "subdirs": [], 151 | "excluded_dirs": [], 152 | } 153 | 154 | size, count = get_dir_size(path_obj, exclude_dirs) 155 | result["size"] = size 156 | result["count"] = count 157 | 158 | try: 159 | subdirs = [d for d in path_obj.iterdir() if d.is_dir()] 160 | 161 | for subdir in sorted(subdirs): 162 | if subdir.name in exclude_dirs: 163 | _, excluded_count = get_dir_size(subdir, []) 164 | if excluded_count < min_file_count: 165 | continue 166 | 167 | excluded_data = analyze_directory(subdir, [], min_file_count, log_func) 168 | excluded_data["excluded"] = True 169 | result["excluded_dirs"].append(excluded_data) 170 | else: 171 | subdir_data = analyze_directory( 172 | subdir, exclude_dirs, min_file_count, log_func 173 | ) 174 | if subdir_data["count"] < min_file_count: 175 | continue 176 | 177 | result["subdirs"].append(subdir_data) 178 | except (PermissionError, OSError) as e: 179 | error_msg = f"Error accessing {path_obj}: {e}" 180 | if log_func: 181 | log_func(error_msg) 182 | else: 183 | print(error_msg, file=sys.stderr) 184 | 185 | return result 186 | 187 | 188 | def print_directory_tree( 189 | tree_data: Dict[str, Any], 190 | indent: str = "", 191 | is_last: bool = True, 192 | prefix: str = "", 193 | log_func=None, 194 | ) -> None: 195 | if ( 196 | tree_data["count"] == 0 197 | and not tree_data["subdirs"] 198 | and not tree_data["excluded_dirs"] 199 | ): 200 | return 201 | 202 | if is_last: 203 | branch = "└── " 204 | next_indent = indent + " " 205 | else: 206 | branch = "├── " 207 | next_indent = indent + "│ " 208 | 209 | name = tree_data["name"] 210 | count = tree_data["count"] 211 | size = scale_size(tree_data["size"]) 212 | 213 | tree_line = f"{indent}{prefix}{branch}[{name}] - {count} files, {size}" 214 | if log_func: 215 | log_func(tree_line) 216 | else: 217 | print(tree_line) 218 | 219 | num_subdirs = len(tree_data["subdirs"]) 220 | 221 | for i, subdir in enumerate(tree_data["subdirs"]): 222 | is_subdir_last = (i == num_subdirs - 1) and 
not tree_data["excluded_dirs"] 223 | print_directory_tree(subdir, next_indent, is_subdir_last, "", log_func) 224 | 225 | for i, excluded in enumerate(tree_data["excluded_dirs"]): 226 | is_excluded_last = i == len(tree_data["excluded_dirs"]) - 1 227 | print_directory_tree( 228 | excluded, next_indent, is_excluded_last, "(SOURCE) ", log_func 229 | ) 230 | 231 | 232 | def is_file_older_than(file_path: Union[str, Path], seconds: float = 1.0) -> bool: 233 | """Check if a file's last modification time is older than specified seconds.""" 234 | try: 235 | mtime = Path(file_path).stat().st_mtime 236 | return (time.time() - mtime) >= seconds 237 | except (FileNotFoundError, PermissionError): 238 | return False 239 | 240 | 241 | def has_stable_size(file_path: Union[str, Path], wait_time: float = 0.1) -> bool: 242 | """Check if a file's size is stable (not changing).""" 243 | path = Path(file_path) 244 | try: 245 | size1 = path.stat().st_size 246 | time.sleep(wait_time) 247 | size2 = path.stat().st_size 248 | return size1 == size2 249 | except (FileNotFoundError, PermissionError): 250 | return False 251 | 252 | 253 | def is_file_locked(file_path: Union[str, Path]) -> bool: 254 | """Check if a file is locked (being written to by another process).""" 255 | try: 256 | with open(file_path, "rb+") as _: 257 | pass 258 | return False 259 | except (PermissionError, OSError): 260 | return True 261 | 262 | 263 | def is_file_ready( 264 | file_path: Union[str, Path], 265 | min_age_seconds: float = 1.0, 266 | check_size_stability: bool = False, 267 | check_file_lock: bool = True, 268 | stability_wait_time: float = 0.1, 269 | ) -> bool: 270 | """ 271 | Determine if a file is ready for processing (not being downloaded/written to). 
272 | 273 | Args: 274 | file_path: Path to the file to check 275 | min_age_seconds: Minimum age in seconds since last modification 276 | check_size_stability: Whether to check if file size is stable 277 | check_file_lock: Whether to check if file is locked by another process 278 | stability_wait_time: Time to wait when checking size stability 279 | 280 | Returns: 281 | bool: True if the file appears ready for processing 282 | """ 283 | file_path = Path(file_path) if isinstance(file_path, str) else file_path 284 | 285 | if not file_path.exists() or not file_path.is_file(): 286 | return False 287 | 288 | if not is_file_older_than(file_path, min_age_seconds): 289 | return False 290 | 291 | if check_size_stability and not has_stable_size(file_path, stability_wait_time): 292 | return False 293 | 294 | if check_file_lock and is_file_locked(file_path): 295 | return False 296 | 297 | return True 298 | 299 | 300 | def filter_ready_files( 301 | file_list: List[Union[str, Path]], **kwargs 302 | ) -> List[Union[str, Path]]: 303 | """ 304 | Filter a list of files to only include those that are ready for processing. 
305 | 306 | Args: 307 | file_list: List of file paths 308 | **kwargs: Additional arguments to pass to is_file_ready() 309 | 310 | Returns: 311 | list: Filtered list containing only ready files 312 | """ 313 | return [f for f in file_list if is_file_ready(f, **kwargs)] 314 | 315 | 316 | async def wait_for_downloads_to_complete( 317 | files: List[Path], min_age_seconds: float = 2.0, timeout_seconds: int = 180 318 | ) -> bool: 319 | if not files: 320 | return True 321 | 322 | start_time = time.time() 323 | 324 | while time.time() - start_time < timeout_seconds: 325 | ready_files = filter_ready_files( 326 | file_list=files, min_age_seconds=min_age_seconds 327 | ) 328 | if len(ready_files) == len(files): 329 | return True 330 | # yield to event loop 331 | await asyncio.sleep(5) 332 | 333 | bt.logging.error(f"Timeout waiting for {files} after {timeout_seconds} seconds") 334 | return False 335 | -------------------------------------------------------------------------------- /bitmind/cache/util/video.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import math 4 | import ffmpeg 5 | from typing import Dict, Any, Optional, Union, Tuple 6 | 7 | 8 | def get_video_duration(video_path: str) -> float: 9 | """Get the duration of a video file in seconds. 

    Args:
        video_path: Path to the video file

    Returns:
        Duration in seconds

    Raises:
        Exception: If the duration cannot be determined
    """
    try:
        probe = ffmpeg.probe(video_path)
        duration = float(probe["format"]["duration"])
        return duration
    except Exception as e:
        # Fallback: invoke ffprobe directly in case the ffmpeg-python
        # bindings fail.
        try:
            result = subprocess.run(
                [
                    "ffprobe",
                    "-v",
                    "error",
                    "-show_entries",
                    "format=duration",
                    "-of",
                    "json",
                    video_path,
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
            data = json.loads(result.stdout)
            duration = float(data["format"]["duration"])
            return duration
        except Exception as sub_e:
            # Surface both failures so callers can see why each path failed.
            raise Exception(f"Failed to get video duration: {e}, {sub_e}")


def get_video_metadata(video_path: str, max_fps: float = 30.0) -> Dict[str, Any]:
    """Get comprehensive metadata from a video file with sanity checks.

    Args:
        video_path: Path to the video file
        max_fps: Maximum reasonable FPS value (default: 30.0)

    Returns:
        Dictionary containing metadata with sanity-checked values; on
        failure, a stub dict with an "error" key (see _create_error_metadata).
    """
    try:
        ffprobe_fields = (
            "format=duration,size,bit_rate,format_name:"
            "stream=width,height,codec_name,codec_type,"
            "r_frame_rate,avg_frame_rate,pix_fmt,sample_rate,channels"
        )
        result = subprocess.run(
            [
                "ffprobe",
                "-v",
                "error",
                "-show_entries",
                ffprobe_fields,
                "-of",
                "json",
                video_path,
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,  # This will raise CalledProcessError if ffprobe fails
        )

        data = json.loads(result.stdout)

        # Extract basic format information
        format_info = data.get("format", {})
        streams = data.get("streams", [])

        # Find video and audio streams (first of each kind)
        video_stream = next(
            (s for s in streams if s.get("codec_type") == "video"), None
        )
        audio_stream = next(
            (s for s in streams if s.get("codec_type") == "audio"), None
        )

        # Build base metadata
        metadata = {
            "duration": float(format_info.get("duration", 0)),
            "size_bytes": int(format_info.get("size", 0)),
            "bit_rate": (
                int(format_info.get("bit_rate", 0))
                if "bit_rate" in format_info
                else None
            ),
            "format": format_info.get("format_name"),
            "has_video": video_stream is not None,
            "has_audio": audio_stream is not None,
        }

        # Add video stream details if present
        if video_stream:
            fps, fps_corrected, original_fps = _get_sanitized_fps(video_stream, max_fps)

            metadata.update(
                {
                    "fps": fps,
                    "width": int(video_stream.get("width", 0)),
                    "height": int(video_stream.get("height", 0)),
                    "codec": video_stream.get("codec_name"),
                    "pix_fmt": video_stream.get("pix_fmt"),
                }
            )

            # Record the raw value when the fps was clamped/defaulted.
            if fps_corrected:
                metadata["original_fps"] = original_fps
                metadata["fps_corrected"] = True

        # Add audio stream details if present
        if audio_stream:
            metadata.update(
                {
                    "audio_codec": audio_stream.get("codec_name"),
                    "sample_rate": audio_stream.get("sample_rate"),
                    "channels": int(audio_stream.get("channels", 0)),
                }
            )

        return metadata

    except subprocess.CalledProcessError as e:
        return _create_error_metadata(f"ffprobe process failed: {e.stderr.strip()}")
    except json.JSONDecodeError:
        return _create_error_metadata("Failed to parse ffprobe output as JSON")
    except Exception as e:
        return _create_error_metadata(f"Unexpected error: {str(e)}")


def _get_sanitized_fps(
    # NOTE(review): this local default (60.0) differs from
    # get_video_metadata's default max_fps of 30.0 — confirm intended.
    video_stream: Dict[str, Any], max_fps: float = 60.0
) -> Tuple[float, bool, Optional[float]]:
    """Parse and sanitize frame rate from video stream information.

    Returns:
        Tuple of (sanitized_fps, was_corrected, original_fps_if_corrected)
    """
    original_fps = None
    fps_corrected = False

    # Try r_frame_rate first (usually more accurate)
    fps = _parse_frame_rate_string(video_stream.get("r_frame_rate"))

    # Fall back to avg_frame_rate if needed
    if fps is None:
        fps = _parse_frame_rate_string(video_stream.get("avg_frame_rate"))

    # Save original before correction
    if fps is not None:
        original_fps = fps

    # Sanity check and correct if needed (None, non-finite, non-positive,
    # or implausibly high values all fall back to 30 fps).
    if fps is None or not (0 < fps <= max_fps) or not math.isfinite(fps):
        fps_corrected = True
        fps = 30.0  # Default to a standard frame rate

    return fps, fps_corrected, original_fps if fps_corrected else None


def _parse_frame_rate_string(frame_rate_str: Optional[str]) -> Optional[float]:
    """Safely parse a frame rate string in format 'num/den'."""
    if not frame_rate_str:
        return None

    try:
        if "/" in frame_rate_str:
            num, den = frame_rate_str.split("/")
            num, den = float(num), float(den)
            if den <= 0:  # Avoid division by zero
                return None
            return num / den
        else:
            # Handle case where frame rate is just a number
            return float(frame_rate_str)
    except (ValueError, ZeroDivisionError):
        return None


def _create_error_metadata(error_message: str) -> Dict[str, Any]:
    """Create a metadata dictionary for error cases."""
    return {
        "duration": 0,
        "has_video": False,
        "has_audio": False,
        "error": error_message,
    }


def seconds_to_str(seconds):
    """Convert seconds to formatted time string (HH:MM:SS)."""
    seconds = int(float(seconds))
    hours = seconds // 3600
    minutes = (seconds % 3600) // 60
    seconds = seconds % 60
    return f"{hours:02}:{minutes:02}:{seconds:02}"
--------------------------------------------------------------------------------
/bitmind/config.py:
--------------------------------------------------------------------------------
import os
import bittensor as bt

# Netuid of the BitMind subnet on the finney mainnet.
MAINNET_UID = 34


def validate_config_and_neuron_path(config):
    r"""Checks/validates the config namespace object.

    Builds the neuron's full logging/state directory path from wallet and
    netuid settings, stores it on config.neuron.full_path, and creates the
    directory if missing. Returns the (mutated) config.
    """
    full_path = os.path.expanduser(
        "{}/{}/{}/netuid{}/{}".format(
            config.logging.logging_dir,
            config.wallet.name,
            config.wallet.hotkey,
            config.netuid,
            config.neuron.name,
        )
    )
    bt.logging.info(f"Logging path: {full_path}")
    config.neuron.full_path = os.path.expanduser(full_path)
    if not os.path.exists(config.neuron.full_path):
        os.makedirs(config.neuron.full_path, exist_ok=True)
    return config


def add_args(parser):
    """
    Adds relevant arguments to the parser for operation.
    """
    parser.add_argument("--netuid", type=int, help="Subnet netuid", default=34)

    parser.add_argument(
        "--neuron.name",
        type=str,
        help="Neuron Name",
        default="bitmind",
    )

    parser.add_argument(
        "--epoch-length",
        type=int,
        help="The default epoch length (how often we set weights, measured in 12 second blocks).",
        default=360,
    )

    parser.add_argument(
        "--mock",
        action="store_true",
        help="Run in mock mode",
        default=False,
    )

    # store_false + default=True: passing the flag turns autoupdate OFF.
    parser.add_argument(
        "--autoupdate-off",
        action="store_false",
        dest="autoupdate",
        help="Disable automatic updates on latest version on Main.",
        default=True,
    )

    parser.add_argument("--wandb.entity", type=str, default="bitmindai")

    parser.add_argument("--wandb.off", action="store_true", default=False)


def add_miner_args(parser):
    """Add miner specific arguments to the parser."""

    parser.add_argument(
        "--no-force-validator-permit",
def add_validator_args(parser):
    """Add validator specific arguments to the parser.

    Covers vpermit limits, cache update cadence, neuron timeouts/heartbeat,
    scoring weights, and challenge sampling probabilities. All *-interval
    arguments are measured in 12 second blocks.
    """

    parser.add_argument(
        "--vpermit-tao-limit",
        type=int,
        help="The maximum number of TAO allowed to query a validator with a vpermit.",
        default=20000,
    )

    # --- cache / scheduling cadence (12 second blocks) ---

    parser.add_argument(
        "--compressed-cache-update-interval",
        type=int,
        help="How often to download new zip/parquet files, measured in 12 second blocks",
        default=720,
    )

    parser.add_argument(
        "--media-cache-update-interval",
        type=int,
        help="How often to unpack random media files, measured in 12 second blocks",
        default=300,
    )

    parser.add_argument(
        "--challenge-interval",
        type=int,
        help="How often we set challenge miners, measured in 12 second blocks.",
        default=5,
    )

    parser.add_argument(
        "--wandb-restart-interval",
        type=int,
        help="How often we restart wandb run to avoid log truncation",
        default=2000,
    )

    # --- cache sizing ---

    parser.add_argument(
        "--cache.base-dir",
        type=str,
        default=os.path.expanduser("~/.cache/sn34"),
        help="Base directory for cache storage",
    )

    parser.add_argument(
        "--cache.max-compressed-gb",
        type=float,
        default=50.0,
        help="Maximum size in GB for compressed cache",
    )

    parser.add_argument(
        "--cache.max-media-gb",
        type=float,
        default=5.0,
        help="Maximum size in GB for media cache",
    )

    parser.add_argument(
        "--cache.media-files-per-source",
        type=int,
        default=50,
        help="Number of media files to keep per source",
    )

    # --- neuron state / miner request timeouts ---

    parser.add_argument(
        "--neuron.max-state-backup-hours",
        type=float,
        help="The oldest backup of validator state to load in the case of a failure to load most recent",
        default=1,
    )

    parser.add_argument(
        "--neuron.miner-total-timeout",
        type=float,
        help="Total timeout for miner requests in seconds",
        default=11.0,
    )

    parser.add_argument(
        "--neuron.miner-connect-timeout",
        type=float,
        help="TCP connection timeout for miner requests in seconds",
        default=4.0,
    )

    parser.add_argument(
        "--neuron.miner-sock-connect-timeout",
        type=float,
        help="Socket connection timeout for miner requests in seconds",
        default=3.0,
    )

    # --- heartbeat watchdog ---

    parser.add_argument(
        "--neuron.heartbeat",
        action="store_true",
        help="Run validator heartbeat thread",
        default=False,
    )

    parser.add_argument(
        "--neuron.heartbeat-interval-seconds",
        type=float,
        help="Interval between heartbeat checks in seconds",
        default=60.0,
    )

    parser.add_argument(
        "--neuron.lock-sleep-seconds",
        type=float,
        help="Sleep duration when lock is held in seconds",
        default=5.0,
    )

    parser.add_argument(
        "--neuron.max-stuck-count",
        type=int,
        help="Number of consecutive heartbeats with no progress before restart",
        default=5,
    )

    parser.add_argument(
        "--neuron.sample-size",
        type=int,
        help="Number of miners to query per challenge",
        default=50,
    )

    # --- scoring weights ---

    parser.add_argument(
        "--scoring.moving-average-alpha",
        type=float,
        help="Alpha for miner score EMA",
        default=0.05,
    )

    parser.add_argument(
        "--scoring.image-weight",
        type=float,
        help="Weight for image modality scoring",
        default=0.6,
    )

    parser.add_argument(
        "--scoring.video-weight",
        type=float,
        help="Weight for video modality scoring",
        default=0.4,
    )

    parser.add_argument(
        "--scoring.binary-weight",
        type=float,
        help="Weight for binary classification scoring",
        default=0.75,
    )

    parser.add_argument(
        "--scoring.multiclass-weight",
        type=float,
        help="Weight for multiclass classification scoring",
        default=0.25,
    )

    # --- challenge sampling probabilities ---
    # image-prob + video-prob and real + synthetic + semisynthetic each sum to 1.0

    parser.add_argument(
        "--challenge.image-prob",
        type=float,
        help="Probability of selecting image modality for challenges",
        default=0.5,
    )

    parser.add_argument(
        "--challenge.video-prob",
        type=float,
        help="Probability of selecting video modality for challenges",
        default=0.5,
    )

    parser.add_argument(
        "--challenge.real-prob",
        type=float,
        help="Probability of selecting real media for challenges",
        default=0.5,
    )

    parser.add_argument(
        "--challenge.synthetic-prob",
        type=float,
        help="Probability of selecting synthetic media for challenges",
        default=0.3,
    )

    parser.add_argument(
        "--challenge.semisynthetic-prob",
        type=float,
        help="Probability of selecting semisynthetic media for challenges",
        default=0.2,
    )

    parser.add_argument(
        "--challenge.multi-video-prob",
        type=float,
        help="Probability of stitching together two videos of the same media type",
        default=0.2,
    )

    parser.add_argument(
        "--challenge.min-clip-duration",
        type=float,
        help="Minimum video clip duration in seconds",
        default=1.0,
    )

    parser.add_argument(
        "--challenge.max-clip-duration",
        type=float,
        help="Maximum video clip duration in seconds",
        default=6.0,
    )

    parser.add_argument(
        "--challenge.max-frames",
        type=int,
        help="Maximum number of video frames to sample for a challenge",
        default=144,
    )


def add_data_generator_args(parser):
    """Add synthetic data generator specific arguments to the parser."""

    parser.add_argument(
        "--cache-dir",
        type=str,
        default=os.path.expanduser("~/.cache/sn34"),
        help="Directory for caching data",
    )

    parser.add_argument(
        "--batch-size", type=int, default=3, help="Batch size for generation"
    )

    parser.add_argument(
        "--tasks",
        nargs="+",
        choices=["t2v", "t2i", "i2i", "i2v"],
        default=["t2v", "t2i", "i2i", "i2v"],
        help="List of tasks to run (t2v, t2i, i2i, i2v). Defaults to all.",
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for generation (cuda/cpu)",
    )

    parser.add_argument(
        "--wandb.num-batches-per-run",
        type=int,
        default=50,
        help="Number of batches to generate before starting new W&B run (avoids log truncation)",
    )

    parser.add_argument("--wandb.process-name", type=str, default="generator")


def add_proxy_args(parser):
    """Add validator proxy specific arguments to the parser."""

    parser.add_argument(
        # Both spellings map to the same dest (proxy.sample_size). The
        # underscore form is kept as an alias for backward compatibility;
        # previously it was registered a second time as a duplicate argument.
        "--proxy.sample-size",
        "--proxy.sample_size",
        type=int,
        default=50,
        help="Number of miners to query for organics",
    )

    parser.add_argument(
        "--proxy.client-url",
        type=str,
        default="https://subnet-api.bitmindlabs.ai",
        help="URL for the proxy client authentication service",
    )

    parser.add_argument(
        "--proxy.host",
        type=str,
        default="0.0.0.0",
        help="Network interface to listen on",
    )

    parser.add_argument(
        "--proxy.port",
        type=int,
        default=10913,
        help="Port for the proxy server",
    )

    parser.add_argument(
        "--proxy.external_port",
        type=int,
        default=10913,
        help="Externally advertised port for the proxy server",
    )

    parser.add_argument(
        "--miner-health-sync-interval",
        type=int,
        default=2,
        help="How frequently to check miner health (in blocks)",
    )
48 | buffer.seek(0) 49 | 50 | return buffer.getvalue(), "image/jpeg" 51 | 52 | 53 | def video_to_bytes(video: np.ndarray, fps: int | None = None) -> tuple[bytes, str]: 54 | """ 55 | Convert a (T, H, W, C) uint8/float32 video to MP4, but *first* pass each frame 56 | through Pillow JPEG → adds normal JPEG artefacts, then encodes losslessly. 57 | 58 | Returns: 59 | bytes: In‑memory MP4 file. 60 | str: MIME‑type ("video/mp4"). 61 | """ 62 | # ------------- 0. validation / normalisation ------------------------------- 63 | if video.dtype == np.float32: 64 | assert video.max() <= 1.0, video.max() 65 | video = (video * 255).clip(0, 255).astype(np.uint8) 66 | elif video.dtype != np.uint8: 67 | raise ValueError(f"Unsupported dtype: {video.dtype}") 68 | 69 | fps = fps or 30 70 | 71 | # TCHW → THWC 72 | if video.shape[1] <= 4 and video.shape[3] > 4: 73 | video = np.transpose(video, (0, 2, 3, 1)) 74 | 75 | if video.ndim != 4 or video.shape[3] not in (1, 3): 76 | raise ValueError(f"Expected shape (T, H, W, C), got {video.shape}") 77 | 78 | T, H, W, C = video.shape 79 | 80 | # ------------- 1. apply Pillow JPEG to every frame ------------------------- 81 | jpeg_degraded_frames: List[np.ndarray] = [] 82 | for idx, frame in enumerate(video): 83 | buf = BytesIO() 84 | Image.fromarray(frame).save( 85 | buf, 86 | format="JPEG", 87 | quality=75, 88 | subsampling=2, # 0=4:4:4, 1=4:2:2, 2=4:2:0 (Pillow default = 2) 89 | optimize=False, 90 | progressive=False, 91 | ) 92 | buf.seek(0) 93 | # decode back to RGB so FFmpeg sees the artefact‑laden pixels 94 | degraded = np.array(Image.open(buf).convert("RGB"), dtype=np.uint8) 95 | if degraded.shape != (H, W, 3): 96 | raise ValueError(f"Decoded shape mismatch at frame {idx}: {degraded.shape}") 97 | jpeg_degraded_frames.append(degraded) 98 | 99 | degraded_video = np.stack(jpeg_degraded_frames, axis=0) # (T,H,W,3) 100 | 101 | # ------------- 2. 
write raw RGB + encode losslessly ------------------------ 102 | with tempfile.TemporaryDirectory() as tmpdir: 103 | raw_path = os.path.join(tmpdir, "input.raw") 104 | video_path = os.path.join(tmpdir, "output.mp4") 105 | 106 | degraded_video.tofile(raw_path) # write as one big rawvideo blob 107 | 108 | try: 109 | ( 110 | ffmpeg.input( 111 | raw_path, 112 | format="rawvideo", 113 | pix_fmt="rgb24", 114 | s=f"{W}x{H}", 115 | r=fps, 116 | ) 117 | .output( 118 | video_path, 119 | vcodec="libx264rgb", 120 | crf=0, # mathematically lossless 121 | preset="veryfast", 122 | pix_fmt="rgb24", 123 | movflags="+faststart", 124 | ) 125 | .global_args("-y", "-hide_banner", "-loglevel", "error") 126 | .run() 127 | ) 128 | except ffmpeg.Error as e: 129 | raise RuntimeError( 130 | f"FFmpeg encoding failed:\n{e.stderr.decode(errors='ignore')}" 131 | ) from e 132 | 133 | with open(video_path, "rb") as f: 134 | video_bytes = f.read() 135 | 136 | return video_bytes, "video/mp4" 137 | 138 | 139 | def media_to_bytes(media, fps=30): 140 | """Convert image or video array to bytes, using PNG encoding for both. 141 | 142 | Args: 143 | media (np.ndarray): Either: 144 | - Image array of shape (C, H, W) 145 | - Video array of shape (T, C, H, W) 146 | Can be float32 [0,1] or uint8 [0,255] 147 | fps (int): Frames per second for video (default: 30) 148 | 149 | Returns: 150 | bytes: Encoded media bytes 151 | str: Content type (either 'image/png' or 'video/avi') 152 | """ 153 | if len(media.shape) == 3: # Image 154 | return image_to_bytes(media) 155 | elif len(media.shape) == 4: # Video 156 | return video_to_bytes(media, fps) 157 | else: 158 | raise ValueError( 159 | f"Invalid media shape: {media.shape}. Expected (C,H,W) for image or (T,C,H,W) for video." 
160 | ) 161 | -------------------------------------------------------------------------------- /bitmind/epistula.py: -------------------------------------------------------------------------------- 1 | import json 2 | from hashlib import sha256 3 | from uuid import uuid4 4 | from math import ceil 5 | from typing import Annotated, Any, Dict, Optional 6 | 7 | import bittensor as bt 8 | import numpy as np 9 | import asyncio 10 | import ast 11 | import time 12 | import httpx 13 | import aiohttp 14 | from substrateinterface import Keypair 15 | 16 | from bitmind.types import Modality 17 | 18 | 19 | EPISTULA_VERSION = str(2) 20 | 21 | 22 | def generate_header( 23 | hotkey: Keypair, 24 | body: Any, 25 | signed_for: Optional[str] = None, 26 | ) -> Dict[str, Any]: 27 | timestamp = round(time.time() * 1000) 28 | timestampInterval = ceil(timestamp / 1e4) * 1e4 29 | uuid = str(uuid4()) 30 | req_hash = None 31 | if isinstance(body, bytes): 32 | req_hash = sha256(body).hexdigest() 33 | else: 34 | req_hash = sha256(json.dumps(body).encode("utf-8")).hexdigest() 35 | 36 | headers = { 37 | "Epistula-Version": EPISTULA_VERSION, 38 | "Epistula-Timestamp": str(timestamp), 39 | "Epistula-Uuid": uuid, 40 | "Epistula-Signed-By": hotkey.ss58_address, 41 | "Epistula-Request-Signature": "0x" 42 | + hotkey.sign(f"{req_hash}.{uuid}.{timestamp}.{signed_for or ''}").hex(), 43 | } 44 | if signed_for: 45 | headers["Epistula-Signed-For"] = signed_for 46 | headers["Epistula-Secret-Signature-0"] = ( 47 | "0x" + hotkey.sign(str(timestampInterval - 1) + "." + signed_for).hex() 48 | ) 49 | headers["Epistula-Secret-Signature-1"] = ( 50 | "0x" + hotkey.sign(str(timestampInterval) + "." + signed_for).hex() 51 | ) 52 | headers["Epistula-Secret-Signature-2"] = ( 53 | "0x" + hotkey.sign(str(timestampInterval + 1) + "." 
def verify_signature(
    signature, body: bytes, timestamp, uuid, signed_for, signed_by, now
) -> Optional[Annotated[str, "Error Message"]]:
    """Verify an Epistula request signature.

    Args:
        signature: Hex signature string from the request headers.
        body: Raw request body bytes.
        timestamp: Request timestamp in milliseconds (string or int).
        uuid: Request uuid string.
        signed_for: ss58 address of the intended receiver.
        signed_by: ss58 address of the sender.
        now: Current time in milliseconds.

    Returns:
        None on success, otherwise a short error message string.
    """
    if not isinstance(signature, str):
        return "Invalid Signature"
    # Headers arrive as strings; coerce defensively. Previously a malformed
    # value made int() raise instead of returning an error message, and the
    # isinstance(timestamp, int) check that followed could never fail.
    try:
        timestamp = int(timestamp)
    except (TypeError, ValueError):
        return "Invalid Timestamp"
    if not isinstance(signed_by, str):
        return "Invalid Sender key"
    if not isinstance(signed_for, str):
        return "Invalid receiver key"
    if not isinstance(uuid, str):
        return "Invalid uuid"
    if not isinstance(body, bytes):
        return "Body is not of type bytes"
    ALLOWED_DELTA_MS = 8000
    # Reject stale requests before constructing the keypair: cheaper, and
    # avoids doing address parsing for requests we will drop anyway.
    if timestamp + ALLOWED_DELTA_MS < now:
        return "Request is too stale"
    keypair = Keypair(ss58_address=signed_by)
    message = f"{sha256(body).hexdigest()}.{uuid}.{timestamp}.{signed_for}"
    verified = keypair.verify(message, signature)
    if not verified:
        return "Signature Mismatch"
    return None


def create_header_hook(hotkey, axon_hotkey, model):
    """Build an httpx request hook that adds Epistula signature headers.

    Note: `model` is accepted for interface compatibility but is not used
    by the hook itself.
    """

    async def add_headers(request: "httpx.Request"):
        for key, header in generate_header(hotkey, request.read(), axon_hotkey).items():
            request.headers[key] = header

    return add_headers


async def query_miner(
    uid: int,
    media: bytes,
    content_type: str,
    modality: "Modality",
    axon_info: "bt.AxonInfo",
    session: "aiohttp.ClientSession",
    hotkey: "bt.Keypair",
    total_timeout: float,
    connect_timeout: Optional[float] = None,
    sock_connect_timeout: Optional[float] = None,
    testnet_metadata: Optional[dict] = None,
) -> Dict[str, Any]:
    """
    Query a miner with media data.

    Args:
        uid: miner uid
        media: encoded media
        content_type: determined by media_to_bytes
        modality: Type of media ('image' or 'video')
        axon_info: miner AxonInfo
        session: aiohttp client session
        hotkey: validator hotkey Keypair for signing the request
        total_timeout: Total timeout for the request
        connect_timeout: Connection timeout
        sock_connect_timeout: Socket connection timeout
        testnet_metadata: Optional metadata forwarded as X-Testnet-* headers

    Returns:
        Dictionary containing the response.
        prediction field will be None if any error is encountered, including
        if the response contains a prediction that doesn't sum to ~1.
    """
    response = {
        "uid": uid,
        "hotkey": axon_info.hotkey,
        "status": 500,
        "prediction": None,
        "error": "",
    }

    try:
        url = f"http://{axon_info.ip}:{axon_info.port}/detect_{modality}"
        headers = generate_header(hotkey, media, axon_info.hotkey)

        headers = {
            "Content-Type": content_type,
            "X-Media-Type": modality,
            **headers,
        }

        if testnet_metadata:
            testnet_headers = {
                f"X-Testnet-{k}": str(v) for k, v in testnet_metadata.items()
            }
            headers.update(testnet_headers)

        async with session.post(
            url,
            headers=headers,
            data=media,
            timeout=aiohttp.ClientTimeout(
                total=total_timeout,
                connect=connect_timeout,
                sock_connect=sock_connect_timeout,
            ),
        ) as res:
            response["status"] = res.status
            if res.status != 200:
                response["error"] = f"HTTP {res.status} error"
                return response
            try:
                data = await res.json()
                if "prediction" not in data:
                    response["error"] = "Missing prediction in response"
                    return response

                pred = [float(p) for p in data["prediction"]]

                # handle binary predictions, assume [real, fake]
                if len(pred) == 2:
                    pred = pred + [0.0]

                pred = np.array(pred)

                # error on predictions that don't sum to ~1 or contain
                # values outside of [0., 1.]
                if abs(sum(pred) - 1.0) > 1e-6 or np.any((pred < 0.0) | (pred > 1.0)):
                    raise ValueError

                response["prediction"] = pred
                return response

            except json.JSONDecodeError:
                response["error"] = "Failed to decode JSON response"
                return response

            except (TypeError, ValueError):
                response["error"] = (
                    f"Invalid prediction value {data.get('prediction')}"
                )
                return response

    except asyncio.TimeoutError:
        response["status"] = 408
        response["error"] = "Request timed out"
    except aiohttp.ClientConnectorError as e:
        response["status"] = 503
        response["error"] = f"Connection error: {str(e)}"
    except aiohttp.ClientError as e:
        response["status"] = 400
        response["error"] = f"Network error: {str(e)}"
    except Exception as e:
        response["error"] = f"Unknown error: {str(e)}"

    return response
10 | """ 11 | 12 | def __init__(self): 13 | self.models: Dict[str, ModelConfig] = {} 14 | 15 | def register(self, model_config: ModelConfig) -> None: 16 | self.models[model_config.path] = model_config 17 | 18 | def register_all(self, model_configs: List[ModelConfig]) -> None: 19 | for config in model_configs: 20 | self.register(config) 21 | 22 | def get_model(self, path: str) -> Optional[ModelConfig]: 23 | return self.models.get(path) 24 | 25 | def get_all_models(self) -> Dict[str, ModelConfig]: 26 | return self.models.copy() 27 | 28 | def get_models_by_task(self, task: ModelTask) -> Dict[str, ModelConfig]: 29 | return { 30 | path: config for path, config in self.models.items() if config.task == task 31 | } 32 | 33 | def get_model_names_by_task(self, task: ModelTask) -> Dict[str, ModelConfig]: 34 | return [path for path, config in self.models.items() if config.task == task] 35 | 36 | def get_models_by_tag(self, tag: str) -> Dict[str, ModelConfig]: 37 | return { 38 | path: config for path, config in self.models.items() if tag in config.tags 39 | } 40 | 41 | def get_model_names_by_task(self, task: ModelTask) -> List[str]: 42 | return list(self.get_models_by_task(task).keys()) 43 | 44 | @property 45 | def t2i_models(self) -> Dict[str, ModelConfig]: 46 | return self.get_models_by_task(ModelTask.TEXT_TO_IMAGE) 47 | 48 | @property 49 | def t2v_models(self) -> Dict[str, ModelConfig]: 50 | return self.get_models_by_task(ModelTask.TEXT_TO_VIDEO) 51 | 52 | @property 53 | def i2i_models(self) -> Dict[str, ModelConfig]: 54 | return self.get_models_by_task(ModelTask.IMAGE_TO_IMAGE) 55 | 56 | @property 57 | def i2v_models(self) -> List[str]: 58 | return self.get_models_by_task(ModelTask.IMAGE_TO_VIDEO) 59 | 60 | @property 61 | def t2i_model_names(self) -> List[str]: 62 | return list(self.t2i_models.keys()) 63 | 64 | @property 65 | def t2v_model_names(self) -> List[str]: 66 | return list(self.t2v_models.keys()) 67 | 68 | @property 69 | def i2i_model_names(self) -> List[str]: 70 | 
return list(self.i2i_models.keys()) 71 | 72 | @property 73 | def i2v_model_names(self) -> List[str]: 74 | return list(self.i2v_models.keys()) 75 | 76 | @property 77 | def model_names(self) -> List[str]: 78 | return list(self.models.keys()) 79 | 80 | def select_random_model(self, task: Optional[Union[ModelTask, str]] = None) -> str: 81 | if isinstance(task, str): 82 | task = ModelTask(task.lower()) 83 | 84 | if task is None: 85 | task = random.choice(list(ModelTask)) 86 | 87 | model_names = self.get_model_names_by_task(task) 88 | if not model_names: 89 | raise ValueError(f"No models available for task: {task}") 90 | 91 | return random.choice(model_names) 92 | 93 | def get_model_dict(self, model_name: str) -> Dict[str, Any]: 94 | model = self.get_model(model_name) 95 | if model is None: 96 | raise ValueError(f"Model not found: {model_name}") 97 | 98 | return model.to_dict() 99 | 100 | def get_interleaved_model_names(self, tasks=None) -> List[str]: 101 | from itertools import zip_longest 102 | 103 | model_names = [] 104 | if tasks is None: 105 | model_names = [ 106 | self.t2i_model_names, 107 | self.t2v_model_names, 108 | self.i2i_model_names, 109 | self.i2v_model_names, 110 | ] 111 | else: 112 | for task in tasks: 113 | model_names.append(self.get_model_names_by_task(task)) 114 | 115 | shuffled_model_names = ( 116 | random.sample(names, len(names)) for names in model_names 117 | ) 118 | return [ 119 | m 120 | for quad in zip_longest(*shuffled_model_names) 121 | for m in quad 122 | if m is not None 123 | ] 124 | 125 | def get_modality(self, model_name: str) -> str: 126 | model = self.get_model(model_name) 127 | if model is None: 128 | raise ValueError(f"Model not found: {model_name}") 129 | 130 | return "video" if model.task == ModelTask.TEXT_TO_VIDEO else "image" 131 | 132 | def get_task(self, model_name: str) -> str: 133 | model = self.get_model(model_name) 134 | if model is None: 135 | raise ValueError(f"Model not found: {model_name}") 136 | 137 | return 
import re
import gc
import bittensor as bt
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    pipeline,
)


class PromptGenerator:
    """
    A class for generating and moderating image annotations using transformer models.

    This class provides functionality to generate descriptive captions for images
    using BLIP2 models and optionally moderate the generated text using a separate
    language model. Models are loaded lazily on first use and can be released
    via `clear_gpu`.
    """

    def __init__(
        self,
        vlm_name: str,
        llm_name: str,
        device: str = "cuda",
    ) -> None:
        """
        Initialize the PromptGenerator with specific models and device settings.

        Args:
            vlm_name: The name of the BLIP2 model for generating image captions.
            llm_name: The name of the language model used for moderating,
                enhancing, and sanitizing generated text.
            device: The device to use (e.g. "cuda", "cuda:1", "cpu").
        """
        self.vlm_name = vlm_name
        self.llm_name = llm_name
        # Models load lazily (see load_vlm / load_llm); None means unloaded.
        self.vlm_processor = None
        self.vlm = None
        self.llm = None
        self.device = device

    def load_vlm(self) -> None:
        """
        Load the vision-language model (BLIP2) and its processor for image
        annotation, and move the model to the configured device.
        """
        bt.logging.debug(f"Loading caption generation model {self.vlm_name}")
        self.vlm_processor = Blip2Processor.from_pretrained(
            self.vlm_name, torch_dtype=torch.float32
        )
        self.vlm = Blip2ForConditionalGeneration.from_pretrained(
            self.vlm_name, torch_dtype=torch.float32
        )
        self.vlm.to(self.device)
        bt.logging.info(f"Loaded image annotation model {self.vlm_name}")

    def load_llm(self) -> None:
        """
        Load the language model used for text moderation/enhancement as a
        text-generation pipeline.
        """
        bt.logging.debug(f"Loading caption moderation model {self.llm_name}")
        # Map "cuda:N" to GPU index N; any other device string falls back to 0.
        # NOTE(review): a device of "cpu" also resolves to GPU 0 here — confirm
        # CPU-only operation is not expected for the LLM.
        m = re.match(r"cuda:(\d+)", self.device)
        gpu_id = int(m.group(1)) if m else 0
        llm = AutoModelForCausalLM.from_pretrained(
            self.llm_name,
            torch_dtype=torch.bfloat16,
            device_map={"": gpu_id}
        )
        tokenizer = AutoTokenizer.from_pretrained(self.llm_name)
        self.llm = pipeline("text-generation", model=llm, tokenizer=tokenizer)
        bt.logging.info(f"Loaded caption moderation model {self.llm_name}")

    def load_models(self) -> None:
        """
        Load the necessary models for image annotation and text moderation onto
        the specified device. Models that are already loaded are left untouched.
        """
        if self.vlm is None:
            self.load_vlm()
        else:
            bt.logging.warning(f"vlm already loaded")

        if self.llm is None:
            self.load_llm()
        else:
            bt.logging.warning(f"llm already loaded")

    def clear_gpu(self) -> None:
        """
        Release GPU memory by deleting the loaded models, collecting garbage,
        and emptying the CUDA cache. The (CPU-side) vlm_processor is kept so
        a later reload is cheaper.
        """
        bt.logging.debug("Clearing GPU memory after prompt generation")
        if self.vlm:
            del self.vlm
            self.vlm = None

        if self.llm:
            del self.llm
            self.llm = None

        gc.collect()
        torch.cuda.empty_cache()

    def generate(
        self, image: Image.Image, downstream_task: str = None, max_new_tokens: int = 20
    ) -> str:
        """
        Generate a string description for a given image using prompt-based
        captioning and building conversational context.

        Args:
            image: The image for which the description is to be generated.
            downstream_task: The generation task ('t2i', 't2v', 'i2i', 'i2v').
                For video tasks ('t2v', 'i2v'), motion descriptions are added
                via `enhance`.
            max_new_tokens: The maximum number of tokens to generate for each
                prompt.

        Returns:
            A generated (and moderated) description of the image.
        """
        if self.vlm is None or self.vlm_processor is None:
            self.load_vlm()

        description = ""
        # Each prompt continues the accumulated description, building context.
        prompts = [
            "An image of",
            "The setting is",
            "The background is",
            "The image type/style is",
        ]

        for i, prompt in enumerate(prompts):
            description += prompt + " "
            inputs = self.vlm_processor(
                image, text=description, return_tensors="pt"
            ).to(self.device, torch.float32)

            generated_ids = self.vlm.generate(**inputs, max_new_tokens=max_new_tokens)
            answer = self.vlm_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0].strip()

            bt.logging.trace(f"{i}. Prompt: {prompt}")
            bt.logging.trace(f"{i}. Answer: {answer}")

            if answer:
                # Normalize the answer into a sentence before appending.
                answer = answer.rstrip(" ,;!?")
                if not answer.endswith("."):
                    answer += "."
                description += answer + " "
            else:
                # No answer: drop the prompt (and trailing space) just appended.
                description = description[: -len(prompt) - 1]

        # Strip the leading "An image of" scaffold if it survived.
        if description.startswith(prompts[0]):
            description = description[len(prompts[0]) :]

        description = description.strip()
        if not description.endswith("."):
            description += "."

        moderated_description = self.moderate(description)

        # Video tasks additionally get motion/camera language.
        if downstream_task in ["t2v", "i2v"]:
            return self.enhance(moderated_description)
        return moderated_description

    def moderate(self, description: str, max_new_tokens: int = 80) -> str:
        """
        Use the text moderation pipeline to make the description more concise
        and neutral.

        Args:
            description: The text description to be moderated.
            max_new_tokens: Maximum number of new tokens to generate in the
                moderated text.

        Returns:
            The moderated description text, or the original description if
            moderation fails.
        """
        if self.llm is None:
            self.load_llm()

        messages = [
            {
                "role": "system",
                "content": (
                    "[INST]You always concisely rephrase given descriptions, "
                    "eliminate redundancy, and remove all specific references to "
                    "individuals by name. You do not respond with anything other "
                    "than the revised description.[/INST]"
                ),
            },
            {"role": "user", "content": description},
        ]
        try:
            moderated_text = self.llm(
                messages,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.llm.tokenizer.eos_token_id,
                return_full_text=False,
            )
            return moderated_text[0]["generated_text"]

        except Exception as e:
            # Best-effort: moderation failure falls back to the raw description.
            bt.logging.error(f"An error occurred during moderation: {e}", exc_info=True)
            return description

    def enhance(self, description: str, max_new_tokens: int = 80) -> str:
        """
        Enhance a static image description to make it suitable for video generation
        by adding dynamic elements and motion.

        Args:
            description: The static image description to enhance.
            max_new_tokens: Maximum number of new tokens to generate in the enhanced text.

        Returns:
            An enhanced description suitable for video generation, or the original
            description if enhancement fails.
        """
        if self.llm is None:
            self.load_llm()

        messages = [
            {
                "role": "system",
                "content": (
                    "[INST]You are an expert at converting image descriptions into video prompts. "
                    "Analyze the existing motion in the scene and enhance it naturally:\n"
                    "1. If motion exists in the image (falling, throwing, running, etc.):\n"
                    "   - Maintain and emphasize that existing motion\n"
                    "   - Add smooth continuation of the movement\n"
                    "2. If the subject is static (sitting, standing, placed):\n"
                    "   - Keep it stable\n"
                    "   - Add minimal environmental motion if appropriate\n"
                    "3. Add ONE subtle camera motion that complements the scene\n"
                    "4. Keep the description concise and natural\n"
                    "Only respond with the enhanced description.[/INST]"
                ),
            },
            {"role": "user", "content": description},
        ]

        try:
            enhanced_text = self.llm(
                messages,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.llm.tokenizer.eos_token_id,
                return_full_text=False,
            )
            return enhanced_text[0]["generated_text"]

        except Exception as e:
            # Best-effort: enhancement failure falls back to the input text.
            bt.logging.error(f"An error occurred during motion enhancement: {e}")
            return description

    def sanitize(self, prompt: str, max_new_tokens: int = 80) -> str:
        """
        Use the LLM to make the prompt more SFW (less NSFW).

        Args:
            prompt: The prompt to sanitize.
            max_new_tokens: Maximum number of new tokens for the rewrite.

        Returns:
            The sanitized prompt, or the original prompt if sanitization fails.
        """

        if self.llm is None:
            self.load_llm()

        messages = [
            {
                "role": "system",
                "content": (
                    "[INST]You are an expert at making prompts safe for work (SFW). "
                    "Rephrase the following prompt to remove or neutralize any NSFW, sexual, or explicit content. "
                    "Keep the prompt as close as possible to the original intent, but ensure it is SFW. "
                    "Only respond with the sanitized prompt.[/INST]"
                ),
            },
            {"role": "user", "content": prompt},
        ]
        try:
            sanitized = self.llm(
                messages,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.llm.tokenizer.eos_token_id,
                return_full_text=False,
            )
            return sanitized[0]["generated_text"]
        except Exception as e:
            # Best-effort: sanitization failure falls back to the input prompt.
            bt.logging.error(f"An error occurred during prompt sanitization: {e}")
            return prompt
" 279 | "Only respond with the sanitized prompt.[/INST]" 280 | ), 281 | }, 282 | {"role": "user", "content": prompt}, 283 | ] 284 | try: 285 | sanitized = self.llm( 286 | messages, 287 | max_new_tokens=max_new_tokens, 288 | pad_token_id=self.llm.tokenizer.eos_token_id, 289 | return_full_text=False, 290 | ) 291 | return sanitized[0]["generated_text"] 292 | except Exception as e: 293 | bt.logging.error(f"An error occurred during prompt sanitization: {e}") 294 | return prompt 295 | -------------------------------------------------------------------------------- /bitmind/generation/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/bitmind/generation/util/__init__.py -------------------------------------------------------------------------------- /bitmind/generation/util/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PIL 3 | import os 4 | from PIL import Image, ImageDraw 5 | from typing import Tuple, Union, List 6 | 7 | 8 | def resize_image( 9 | image: PIL.Image.Image, max_width: int, max_height: int 10 | ) -> PIL.Image.Image: 11 | """Resize the image to fit within specified dimensions while maintaining aspect ratio.""" 12 | original_width, original_height = image.size 13 | 14 | # Calculate the aspect ratio and determine new dimensions 15 | aspect_ratio = original_width / original_height 16 | new_width = min(max_width, original_width) 17 | new_height = int(new_width / aspect_ratio) 18 | 19 | if new_height > max_height: 20 | new_height = max_height 21 | new_width = int(new_height * aspect_ratio) 22 | 23 | # Resize the image using the high-quality LANCZOS filter 24 | resized_image = image.resize((new_width, new_height), PIL.Image.LANCZOS) 25 | return resized_image 26 | 27 | 28 | def resize_images_in_directory(directory, target_width, 
def resize_images_in_directory(directory, target_width, target_height):
    """
    Resize all images in the specified directory in place so each fits within
    the target dimensions (aspect ratio preserved).

    Args:
        directory (str): Path to the directory containing images.
        target_width (int): Maximum width for the resized images.
        target_height (int): Maximum height for the resized images.
    """
    for filename in os.listdir(directory):
        # Case-insensitive match so files like "photo.JPG" are not skipped
        # (the previous check was case-sensitive).
        if filename.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".gif")):
            filepath = os.path.join(directory, filename)
            with PIL.Image.open(filepath) as img:
                resized_img = resize_image(
                    img, max_width=target_width, max_height=target_height
                )
                # Overwrite the original file with the resized version.
                resized_img.save(filepath)


def save_images_to_disk(
    image_dataset, start_index, num_images, save_directory, resize=True
):
    """
    Save ``num_images`` images from ``image_dataset`` (starting at
    ``start_index``) to ``save_directory`` as JPEG files named by image id.

    Args:
        image_dataset: Indexable dataset whose items are dicts with
            "image" (a PIL image) and "id" keys.
        start_index (int): First dataset index to save.
        num_images (int): Number of consecutive images to save.
        save_directory (str): Output directory; created if missing.
        resize (bool): Currently unused; kept for backward compatibility.
    """
    os.makedirs(save_directory, exist_ok=True)

    for i in range(start_index, start_index + num_images):
        try:
            image_data = image_dataset[i]
            image = image_data["image"]
            image_id = image_data["id"]
            file_path = os.path.join(save_directory, f"{image_id}.jpg")
            image.save(file_path, "JPEG")
            print(f"Saved: {file_path}")
        except Exception as e:
            # Best-effort batch save: report and continue with the next image.
            print(f"Failed to save image {i}: {e}")
76 | """ 77 | w, h = size 78 | mask = Image.new("RGB", size, "black") 79 | 80 | if np.random.rand() < 0.5: 81 | # Rectangular mask with smoother edges 82 | width = np.random.randint(w // 4, w // 2) 83 | height = np.random.randint(h // 4, h // 2) 84 | 85 | # Center the rectangle with some random offset 86 | x = (w - width) // 2 + np.random.randint(-width // 4, width // 4) 87 | y = (h - height) // 2 + np.random.randint(-height // 4, height // 4) 88 | 89 | # Create mask with PIL draw for smoother edges 90 | draw = ImageDraw.Draw(mask) 91 | draw.rounded_rectangle( 92 | [x, y, x + width, y + height], 93 | radius=min(width, height) // 10, # Smooth corners 94 | fill="white", 95 | ) 96 | else: 97 | # Circular mask with feathered edges 98 | draw = ImageDraw.Draw(mask) 99 | x = w // 2 100 | y = h // 2 101 | 102 | # Make radius proportional to image size 103 | radius = min(w, h) // 4 104 | 105 | # Add small random offset to center 106 | x += np.random.randint(-radius // 4, radius // 4) 107 | y += np.random.randint(-radius // 4, radius // 4) 108 | 109 | # Draw multiple circles with decreasing opacity for feathered edge 110 | for r in range(radius, radius - 10, -1): 111 | opacity = int(255 * (r - (radius - 10)) / 10) 112 | draw.ellipse([x - r, y - r, x + r, y + r], fill=(255, 255, 255, opacity)) 113 | 114 | return mask, (x, y) 115 | 116 | 117 | def is_black_output( 118 | modality: str, output: Union[List[Image.Image], Image.Image], threshold: int = 10 119 | ) -> bool: 120 | """ 121 | Returns True if the image or frames are (almost) completely black. 
122 | """ 123 | if modality == "image": 124 | arr = np.array(output[modality].images[0]) 125 | return np.mean(arr) < threshold 126 | elif modality == "video": 127 | return np.all([np.mean(np.array(arr)) < threshold for arr in output[modality].frames[0]]) 128 | -------------------------------------------------------------------------------- /bitmind/generation/util/prompt.py: -------------------------------------------------------------------------------- 1 | def get_tokenizer_with_min_len(model): 2 | """ 3 | Returns the tokenizer with the smallest maximum token length. 4 | 5 | Args: 6 | model: Single pipeline or dict of pipeline stages. 7 | 8 | Returns: 9 | tuple: (tokenizer, max_token_length) 10 | """ 11 | # Get the model to check for tokenizers 12 | pipeline = model["stage1"] if isinstance(model, dict) else model 13 | 14 | # If model has two tokenizers, return the one with smaller max length 15 | if hasattr(pipeline, "tokenizer_2"): 16 | len_1 = pipeline.tokenizer.model_max_length 17 | len_2 = pipeline.tokenizer_2.model_max_length 18 | return ( 19 | (pipeline.tokenizer_2, len_2) 20 | if len_2 < len_1 21 | else (pipeline.tokenizer, len_1) 22 | ) 23 | 24 | return pipeline.tokenizer, pipeline.tokenizer.model_max_length 25 | 26 | 27 | def truncate_prompt_if_too_long(prompt: str, model): 28 | """ 29 | Truncates the input string if it exceeds the maximum token length when tokenized. 30 | 31 | Args: 32 | prompt (str): The text prompt that may need to be truncated. 33 | 34 | Returns: 35 | str: The original prompt if within the token limit; otherwise, a truncated version of the prompt. 
36 | """ 37 | tokenizer, max_token_len = get_tokenizer_with_min_len(model) 38 | tokens = tokenizer(prompt, verbose=False) # Suppress token max exceeded warnings 39 | if len(tokens["input_ids"]) < max_token_len: 40 | return prompt 41 | 42 | # Truncate tokens if they exceed the maximum token length, decode the tokens back to a string 43 | truncated_prompt = tokenizer.decode( 44 | token_ids=tokens["input_ids"][: max_token_len - 1], skip_special_tokens=True 45 | ) 46 | tokens = tokenizer(truncated_prompt) 47 | return truncated_prompt 48 | -------------------------------------------------------------------------------- /bitmind/metagraph.py: -------------------------------------------------------------------------------- 1 | import time 2 | import asyncio 3 | from typing import Callable, List, Tuple 4 | import numpy as np 5 | import bittensor as bt 6 | from bittensor.utils.weight_utils import process_weights_for_netuid 7 | 8 | from bitmind.utils import fail_with_none 9 | 10 | import threading 11 | 12 | 13 | def get_miner_uids( 14 | metagraph: "bt.metagraph", self_uid: int, vpermit_tao_limit: int 15 | ) -> List[int]: 16 | available_uids = [] 17 | for uid in range(int(metagraph.n.item())): 18 | if uid == self_uid: 19 | continue 20 | 21 | # Filter non serving axons. 22 | if not metagraph.axons[uid].is_serving: 23 | continue 24 | # Filter validator permit > 1024 stake. 
def get_miner_uids(
    metagraph: "bt.metagraph", self_uid: int, vpermit_tao_limit: int
) -> List[int]:
    """
    Return the UIDs of queryable miner axons on the metagraph.

    Excludes the caller's own UID, axons that are not serving, and
    validator-permit holders whose stake exceeds ``vpermit_tao_limit``.

    Args:
        metagraph: The subnet metagraph.
        self_uid: UID of the calling neuron (always excluded).
        vpermit_tao_limit: Stake threshold above which permit holders are
            treated as validators and excluded.

    Returns:
        List of miner UIDs available for querying.
    """
    available_uids = []
    for uid in range(int(metagraph.n.item())):
        if uid == self_uid:
            continue
        # Filter non serving axons.
        if not metagraph.axons[uid].is_serving:
            continue
        # Filter validators: permit holders with stake above the limit.
        # (Flattened the nested-if and removed the redundant trailing
        # `continue` from the original.)
        if (
            metagraph.validator_permit[uid]
            and metagraph.S[uid] > vpermit_tao_limit
        ):
            continue
        available_uids.append(uid)
    return available_uids
def run_block_callback_thread(substrate, callback: Callable):
    """
    Start a daemon thread that subscribes to new block headers and invokes
    ``callback`` with each new block number.

    Args:
        substrate: Substrate interface used for the block-header subscription.
        callback: Async callable invoked with each new block number.

    Returns:
        The started ``threading.Thread``, or None if startup failed.
    """
    try:
        subscription_thread = threading.Thread(
            target=start_subscription, args=[substrate, callback], daemon=True
        )
        subscription_thread.start()
        bt.logging.info("Block subscription started in background thread.")
        return subscription_thread
    except Exception as e:
        # Fixed the garbled "faaailuuure" log message from the original.
        bt.logging.error(f"Failed to start block subscription for {callback} - {e}")
class MinerHistory:
    """Tracks all recent miner performance to facilitate reward computation.

    Keeps a bounded deque of recent predictions and labels per miner and per
    modality, plus a per-miner health flag.
    Will be replaced with Redis in a future release.
    """

    VERSION = 2

    def __init__(self, store_last_n_predictions: int = 100):
        # uid -> {modality -> deque of recent prediction vectors}
        self.predictions: Dict[int, Dict[Modality, deque]] = {}
        # uid -> {modality -> deque of labels matching the predictions}
        self.labels: Dict[int, Dict[Modality, deque]] = {}
        # uid -> hotkey last seen at that uid (used to detect re-registration)
        self.miner_hotkeys: Dict[int, str] = {}
        # uid -> 1 if the last challenge completed without error, else 0.
        # Fixed annotation: was the invalid `Dict[int: int]`.
        self.health: Dict[int, int] = {}
        self.store_last_n_predictions = store_last_n_predictions
        self.version = self.VERSION

    def update(
        self,
        uid: int,
        prediction: np.ndarray,
        error: str,
        label: int,
        modality: Modality,
        miner_hotkey: str,
    ):
        """Update the miner prediction history.

        Args:
            uid: Miner UID.
            prediction: numpy array of shape (3,) containing probabilities for
                [real, synthetic, semi-synthetic]
            error: Error string; when non-empty the miner is marked unhealthy
                and no prediction is recorded.
            label: integer label (0 for real, 1 for synthetic, 2 for semi-synthetic)
            modality: Media modality of the challenge.
            miner_hotkey: Hotkey currently registered at this UID; a change
                indicates re-registration and resets the UID's history.
        """
        if uid not in self.miner_hotkeys or self.miner_hotkeys[uid] != miner_hotkey:
            self.reset_miner_history(uid, miner_hotkey)
            bt.logging.info(f"Reset history for {uid} {miner_hotkey}")

        if not error:
            self.predictions[uid][modality].append(np.array(prediction))
            self.labels[uid][modality].append(label)
            self.health[uid] = 1
        else:
            self.health[uid] = 0

    def _reset_predictions(self, uid: int):
        """Replace the prediction deques for ``uid`` with empty bounded deques."""
        self.predictions[uid] = {
            Modality.IMAGE: deque(maxlen=self.store_last_n_predictions),
            Modality.VIDEO: deque(maxlen=self.store_last_n_predictions),
        }

    def _reset_labels(self, uid: int):
        """Replace the label deques for ``uid`` with empty bounded deques."""
        self.labels[uid] = {
            Modality.IMAGE: deque(maxlen=self.store_last_n_predictions),
            Modality.VIDEO: deque(maxlen=self.store_last_n_predictions),
        }

    def reset_miner_history(self, uid: int, miner_hotkey: str):
        """Clear all recorded history for ``uid`` and record its current hotkey."""
        self._reset_predictions(uid)
        self._reset_labels(uid)
        self.miner_hotkeys[uid] = miner_hotkey

    def get_prediction_count(self, uid: int) -> Dict[Modality, int]:
        """Return the number of valid stored predictions per modality.

        (Annotation fixed: this has always returned a dict keyed by
        modality, not a single int.)
        """
        counts = {}
        for modality in [Modality.IMAGE, Modality.VIDEO]:
            counts[modality] = len(
                self.get_recent_predictions_and_labels(uid, modality)[0]
            )
        return counts

    def get_recent_predictions_and_labels(self, uid, modality):
        """Return (predictions, labels) arrays for entries whose prediction is
        a non-None list/ndarray containing no None elements."""
        if uid not in self.predictions or modality not in self.predictions[uid]:
            return [], []

        # Pair each prediction with its label and filter invalid entries in a
        # single pass; the original rebuilt indices and then did an O(n^2)
        # `i in valid_indices` list scan.
        valid_pairs = [
            (pred, lbl)
            for pred, lbl in zip(
                self.predictions[uid][modality], self.labels[uid][modality]
            )
            if pred is not None
            and isinstance(pred, (list, np.ndarray))
            and not np.any(pred == None)  # noqa: E711 -- elementwise None check
        ]
        valid_preds = np.array([pred for pred, _ in valid_pairs])
        labels_with_valid_preds = np.array([lbl for _, lbl in valid_pairs])
        return valid_preds, labels_with_valid_preds

    def get_healthy_miner_uids(self) -> list:
        """UIDs whose most recent challenge completed without error."""
        return [uid for uid, healthy in self.health.items() if healthy]

    def get_unhealthy_miner_uids(self) -> list:
        """UIDs whose most recent challenge errored."""
        return [uid for uid, healthy in self.health.items() if not healthy]

    def save_state(self, save_dir):
        """Serialize the full history to ``save_dir``/history.pkl via joblib."""
        path = os.path.join(save_dir, "history.pkl")
        state = {
            "version": self.version,
            "store_last_n_predictions": self.store_last_n_predictions,
            "miner_hotkeys": self.miner_hotkeys,
            "predictions": self.predictions,
            "labels": self.labels,
            "health": self.health,
        }
        joblib.dump(state, path)

    def load_state(self, save_dir):
        """Restore history from ``save_dir``/history.pkl.

        Returns:
            True on success, False if the file is missing or unreadable.
        """
        path = os.path.join(save_dir, "history.pkl")
        if not os.path.isfile(path):
            bt.logging.warning(f"No saved state found at {path}")
            return False

        try:
            state = joblib.load(path)
            if state["version"] != self.VERSION:
                bt.logging.warning(
                    f"Loading state from different version: {state['version']} != {self.VERSION}"
                )

            # Fall back to current values for any missing keys.
            self.version = state.get("version", self.VERSION)
            self.store_last_n_predictions = state.get(
                "store_last_n_predictions", self.store_last_n_predictions
            )
            self.miner_hotkeys = state.get("miner_hotkeys", self.miner_hotkeys)
            self.predictions = state.get("predictions", self.predictions)
            self.labels = state.get("labels", self.labels)
            self.health = state.get("health", self.health)

            if len(self.miner_hotkeys) == 0:
                bt.logging.warning("Loaded state has no miner hotkeys")
            if len(self.predictions) == 0:
                bt.logging.warning("Loaded state has no predictions")
            if len(self.labels) == 0:
                bt.logging.warning("Loaded state has no labels")
            if len(self.health) == 0:
                bt.logging.warning("Loaded state has no health records")

            bt.logging.debug(
                f"Successfully loaded history for {len(self.miner_hotkeys)} miners"
            )
            return True

        except Exception as e:
            bt.logging.error(f"Error deserializing MinerHistory state: {str(e)}")
            bt.logging.error(traceback.format_exc())
            return False
139 | f"Successfully loaded history for {len(self.miner_hotkeys)} miners" 140 | ) 141 | return True 142 | 143 | except Exception as e: 144 | bt.logging.error(f"Error deserializing MinerHistory state: {str(e)}") 145 | bt.logging.error(traceback.format_exc()) 146 | return False 147 | -------------------------------------------------------------------------------- /bitmind/types.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from dataclasses import dataclass, field 3 | from enum import Enum, auto 4 | from pydantic import BaseModel 5 | from typing import Dict, List, Any, Optional, Union 6 | 7 | 8 | class NeuronType(Enum): 9 | VALIDATOR = "VALIDATOR" 10 | VALIDATOR_PROXY = "VALIDATOR_PROXY" 11 | MINER = "MINER" 12 | 13 | 14 | class FileType(Enum): 15 | PARQUET = auto() 16 | ZIP = auto() 17 | VIDEO = auto() 18 | IMAGE = auto() 19 | 20 | 21 | class CacheType(str, Enum): 22 | MEDIA = "media" 23 | COMPRESSED = "compressed" 24 | 25 | 26 | class Modality(str, Enum): 27 | IMAGE = "image" 28 | VIDEO = "video" 29 | 30 | 31 | class MediaType(str, Enum): 32 | REAL = "real", 0 33 | SYNTHETIC = "synthetic", 1 34 | SEMISYNTHETIC = "semisynthetic", 2 35 | 36 | def __new__(cls, str_value, int_value): 37 | obj = str.__new__(cls, str_value) 38 | obj._value_ = str_value 39 | obj.int_value = int_value 40 | return obj 41 | 42 | 43 | @dataclass 44 | class CacheUpdaterConfig: 45 | num_sources_per_dataset: int = 1 46 | num_items_per_source: int = 100 47 | 48 | 49 | @dataclass 50 | class CacheConfig: 51 | """Configuration for a cache at base_dir / {modality} / {media_type}""" 52 | 53 | modality: str 54 | media_type: str 55 | base_dir: Path = Path("~/.cache/sn34").expanduser() 56 | tags: Optional[List[str]] = None 57 | max_compressed_gb: float = 100.0 58 | max_media_gb: float = 10.0 59 | 60 | def get_path(self): 61 | media_cache_path = Path(self.base_dir) / self.modality / self.media_type 62 | 
@dataclass
class DatasetConfig:
    """Describes a source dataset on Hugging Face and how it is cached."""

    path: str  # HuggingFace path
    type: Modality
    media_type: MediaType
    tags: List[str] = field(default_factory=list)
    file_format: str = ""
    compressed_format: str = ""
    priority: int = 1  # Optional: priority for sampling (higher is more frequent)
    enabled: bool = True

    def __post_init__(self):
        """Validate and set defaults"""
        # Default archive format depends on modality:
        # parquet for image datasets, zip for video datasets.
        if not self.compressed_format:
            if self.type == Modality.IMAGE:
                self.compressed_format = "parquet"
            elif self.type == Modality.VIDEO:
                self.compressed_format = "zip"

        # Allow tags to be supplied as a comma-separated string.
        if isinstance(self.tags, str):
            self.tags = [t.strip() for t in self.tags.split(",")]

        # Coerce string values into their enum types (case-insensitive).
        if isinstance(self.type, str):
            self.type = Modality(self.type.lower())

        if isinstance(self.media_type, str):
            self.media_type = MediaType(self.media_type.lower())


class ModelTask(str, Enum):
    """Type of task the model is designed for"""

    TEXT_TO_IMAGE = "t2i"
    TEXT_TO_VIDEO = "t2v"
    IMAGE_TO_IMAGE = "i2i"
    IMAGE_TO_VIDEO = "i2v"
class ModelConfig:
    """
    Configuration for a generative AI model.

    Attributes:
        path: The Hugging Face model path or identifier
        task: The primary task of the model (T2I, T2V, I2I)
        media_type: Type of output (synthetic or semisynthetic)
        pipeline_cls: Pipeline class used to load the model
        pretrained_args: Arguments for the from_pretrained method
        generate_args: Default arguments for generation
        tags: List of tags for categorizing the model
        use_autocast: Whether to use autocast during generation
        enable_model_cpu_offload: Whether to offload model weights to CPU
        enable_sequential_cpu_offload: Whether to offload submodules sequentially
        vae_enable_slicing: Whether to enable VAE slicing
        vae_enable_tiling: Whether to enable VAE tiling
        scheduler: Optional scheduler configuration
        save_args: Arguments used when saving generated output
        pipeline_stages: Optional multi-stage pipeline configuration
        clear_memory_on_stage_end: Whether to free memory between stages
        lora_model_id: Optional LoRA weights identifier
        lora_loading_args: Optional arguments for loading LoRA weights
    """

    def __init__(
        self,
        path: str,
        task: ModelTask,
        pipeline_cls: Union[Any, Dict[str, Any]],
        media_type: Optional[MediaType] = None,
        pretrained_args: Dict[str, Any] = None,
        generate_args: Dict[str, Any] = None,
        tags: List[str] = None,
        use_autocast: bool = True,
        enable_model_cpu_offload: bool = False,
        enable_sequential_cpu_offload: bool = False,
        vae_enable_slicing: bool = False,
        vae_enable_tiling: bool = False,
        scheduler: Dict[str, Any] = None,
        save_args: Dict[str, Any] = None,
        pipeline_stages: List[Dict[str, Any]] = None,
        clear_memory_on_stage_end: bool = False,
        lora_model_id: str = None,
        lora_loading_args: Dict[str, Any] = None,
    ):
        self.path = path
        self.task = task
        self.pipeline_cls = pipeline_cls
        self.media_type = media_type

        # i2i models modify an existing image, so their output defaults to
        # semisynthetic; all other tasks produce fully synthetic output.
        if self.media_type is None:
            self.media_type = (
                MediaType.SEMISYNTHETIC
                if task == ModelTask.IMAGE_TO_IMAGE
                else MediaType.SYNTHETIC
            )

        self.pretrained_args = pretrained_args or {}
        self.generate_args = generate_args or {}
        self.tags = tags or []
        self.use_autocast = use_autocast
        self.enable_model_cpu_offload = enable_model_cpu_offload
        self.enable_sequential_cpu_offload = enable_sequential_cpu_offload
        self.vae_enable_slicing = vae_enable_slicing
        self.vae_enable_tiling = vae_enable_tiling
        self.scheduler = scheduler
        self.save_args = save_args or {}
        self.pipeline_stages = pipeline_stages
        self.clear_memory_on_stage_end = clear_memory_on_stage_end
        self.lora_model_id = lora_model_id
        self.lora_loading_args = lora_loading_args

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary format"""
        # NOTE(review): path, task, media_type, tags, and the lora_* fields
        # are not included here -- confirm this omission is intentional.
        return {
            "pipeline_cls": self.pipeline_cls,
            "from_pretrained_args": self.pretrained_args,
            "generate_args": self.generate_args,
            "use_autocast": self.use_autocast,
            "enable_model_cpu_offload": self.enable_model_cpu_offload,
            "enable_sequential_cpu_offload": self.enable_sequential_cpu_offload,
            "vae_enable_slicing": self.vae_enable_slicing,
            "vae_enable_tiling": self.vae_enable_tiling,
            "scheduler": self.scheduler,
            "save_args": self.save_args,
            "pipeline_stages": self.pipeline_stages,
            "clear_memory_on_stage_end": self.clear_memory_on_stage_end,
        }


class ValidatorConfig(BaseModel):
    """Runtime toggles for the validator neuron."""

    skip_weight_set: Optional[bool] = False
    set_weights_on_start: Optional[bool] = False
    max_concurrent_organics: Optional[int] = 2
def fail_with_none(message: str = ""):
    """
    Decorator factory: wrap a function so that any exception it raises is
    logged (with ``message`` and a traceback) and swallowed, returning None
    in place of the result.
    """

    def decorator(func):
        def wrapper(*args, **kwargs):
            try:
                result = func(*args, **kwargs)
            except Exception as exc:
                bt.logging.error(message)
                bt.logging.error(str(exc))
                bt.logging.error(traceback.format_exc())
                return None
            return result

        return wrapper

    return decorator
def get_file_modality(filepath: str) -> str:
    """
    Determine the type of media file based on its extension.

    Args:
        filepath: Path to the media file

    Returns:
        "image", "video", or "file" based on the file extension
    """
    image_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"}
    video_exts = {".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"}

    # Compare case-insensitively so e.g. ".JPG" is still recognized.
    extension = os.path.splitext(filepath)[1].lower()
    if extension in image_exts:
        return "image"
    if extension in video_exts:
        return "video"
    return "file"
23 | 24 | 25 | ## Scores 26 | 27 | >Validators set weights based on historical miner performances, tracked by their score vector. 28 | 29 | For each challenge *t*, a validator will randomly sample 50 miners, send them an image/video, and compute their rewards *C* as described above. These reward values are then used to update the validator's score vector *V* using an exponential moving average (EMA) with *α* = 0.02. 30 | 31 | $$ 32 | V_t = 0.02 \cdot C_t + 0.98 \cdot V_{t-1} 33 | $$ 34 | 35 | A low *α* value places emphasis on a miner's historical performance, adding additional smoothing to avoid having a single prediction cause significant score fluctuations. 36 | 37 | 38 | ## Weights 39 | 40 | > Validators set weights around once per tempo (360 blocks) by sending a normalized score vector to the Bittensor blockchain (in `UINT16` representation). 41 | 42 | Weight normalization by L1 norm: 43 | 44 | $$w = \frac{\text{V}}{\lVert\text{V}\rVert_1}$$ 45 | 46 | 47 | ## Incentives 48 | > The [Yuma Consensus algorithm](https://docs.bittensor.com/yuma-consensus) translates the weight matrix *W* into incentives for the subnet miners and dividends for the subnet validators 49 | 50 | Specifically, for each miner *j*, incentive is a function of rank *R*: 51 | 52 | $$I_j = \frac{R_j}{\sum_k R_k}$$ 53 | 54 | where rank *R* is *W* (a matrix of validator weight vectors) weighted by validator stake vector *S*. 55 | 56 | $$R_k = \sum_i S_i \cdot W_{ik}$$ 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /docs/Mining.md: -------------------------------------------------------------------------------- 1 | # Miner Setup Guide 2 | 3 | ## Before you proceed ⚠️ 4 | 5 | If you are new to Bittensor, we recommend familiarizing yourself with the basics in the [Bittensor Docs](https://docs.bittensor.com/) before proceeding. 6 | 7 | **Run your own local subtensor** to avoid rate limits set on public endpoints. 
See [Run a Subtensor Node Locally](https://github.com/opentensor/subtensor/blob/main/docs/running-subtensor-locally.md#compiling-your-own-binary) for setup instructions.

**Understand your minimum compute requirements** for model training and miner deployment, which vary depending on your choice of model. You will likely need at least a consumer-grade GPU for training. Many models can be deployed in CPU-only environments for mining.
```
cp .env.miner.template .env.miner
```

Your task as a SN34 miner is to classify images and videos as real, synthetic, or semisynthetic.
- **Real**: Authentic media, not touched in any way by AI
- **Synthetic**: Fully AI-generated media
- **Semisynthetic**: AI-modified (spatially, not temporally) media. E.g. faceswaps, inpainting, etc.
Before registering your hotkey on mainnet, make sure your axon port is accepting incoming traffic by running `curl your_ip:your_port` 82 | 83 | 84 | #### Mainnet 85 | 86 | ```bash 87 | btcli s register --netuid 34 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network finney 88 | ``` 89 | 90 | #### Testnet 91 | 92 | > For testnet tao, you can make requests in the [Bittensor Discord's "Requests for Testnet Tao" channel](https://discord.com/channels/799672011265015819/1190048018184011867) 93 | 94 | 95 | ```bash 96 | btcli s register --netuid 168 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network test 97 | ``` 98 | 99 | #### Mining 100 | 101 | You can now launch your miner with `start_miner.sh`, which will use the configuration you provided in `.env.miner` (see the last step of the [Installation](#installation) section). 102 | -------------------------------------------------------------------------------- /docs/Validating.md: -------------------------------------------------------------------------------- 1 | # Validator Guide 2 | 3 | ## Before you proceed ⚠️ 4 | 5 | If you are new to Bittensor (you're probably not if you're reading the validator guide 😎), we recommend familiarizing yourself with the basics in the [Bittensor Docs](https://docs.bittensor.com/) before proceeding. 6 | 7 | **Run your own local subtensor** to avoid rate limits set on public endpoints. See [Run a Subtensor Node Locally](https://github.com/opentensor/subtensor/blob/main/docs/running-subtensor-locally.md#compiling-your-own-binary) for setup instructions. 8 | 9 | **Understand the minimum compute requirements to run a validator**. Validator neurons on SN34 run a suite of generative (text-to-image, text-to-video, etc.) models that require an **80GB VRAM GPU**. They also maintain a large cache of real and synthetic media to ensure diverse, locally available data for challenging miners. We recommend **1TB of storage**. 
For more details, please see our [minimum compute documentation](../min_compute.yml) 10 | 11 | ## Required Hugging Face Model Access 12 | 13 | To properly validate, you must gain access to several Hugging Face models used by the subnet. This requires logging in to your Hugging Face account and accepting the terms for each model below: 14 | 15 | - [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) 16 | - [DeepFloyd IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0) 17 | - [DeepFloyd IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) 18 | 19 | > **Note:** Accepting the terms for any one of the DeepFloyd IF models (e.g., IF-II-L or IF-I-XL) will grant you access to all DeepFloyd IF models. 20 | > 21 | > **If you've been validating with us for a while (prior to V3), you've likely already gotten access to these models and can disregard this step.** 22 | 23 | To do this: 24 | 1. Log in to your Hugging Face account. 25 | 2. Visit each model page above. 26 | 3. Click the "Access repository" or "Agree and access repository" button to accept the terms. 27 | 28 | ## Installation 29 | 30 | Download the repository and navigate to the folder. 31 | ```bash 32 | git clone https://github.com/bitmind-ai/bitmind-subnet.git && cd bitmind-subnet 33 | ``` 34 | 35 | We recommend using a Conda virtual environment to install the necessary Python packages. 36 | - You can set up Conda with this [quick command-line install](https://www.anaconda.com/docs/getting-started/miniconda/install#linux). 37 | - Note that after you run the last commands in the miniconda setup process, you'll be prompted to start a new shell session to complete the initialization. 
38 | 39 | With miniconda installed, you can create your virtual environment with this command: 40 | 41 | ```bash 42 | conda create -y -n bitmind python=3.10 43 | ``` 44 | 45 | - Activating your virtual environment: `conda activate bitmind` 46 | - Deactivating your virtual environment `conda deactivate` 47 | 48 | Install the remaining necessary requirements with the following chained command. 49 | ```bash 50 | conda activate bitmind 51 | export PIP_NO_CACHE_DIR=1 52 | chmod +x setup.sh 53 | ./setup.sh 54 | ``` 55 | 56 | Before you register, you should first fill out all the necessary fields in `.env.validator`. Make a copy of the template, and fill in your wallet information. 57 | 58 | ``` 59 | cp .env.validator.template .env.validator 60 | ``` 61 | 62 | ## Registration 63 | 64 | To validate on our subnet, you must have a registered hotkey. 65 | 66 | #### Mainnet 67 | 68 | ```bash 69 | btcli s register --netuid 34 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network finney 70 | ``` 71 | 72 | #### Testnet 73 | 74 | ```bash 75 | btcli s register --netuid 168 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network test 76 | ``` 77 | 78 | 79 | ## Validating 80 | 81 | Before starting your validator, please ensure you've populated the empty fields in `.env.validator`, including `WANDB_API_KEY` and `HUGGING_FACE_TOKEN`. 82 | 83 | If you haven't already, you can start by copying the template, 84 | ``` 85 | cp .env.validator.template .env.validator 86 | ``` 87 | 88 | If you don't have a W&B API key, please reach out to the BitMind team via Discord and we can provide one. 89 | 90 | Now you're ready to run your validator! 91 | 92 | ```bash 93 | conda activate bitmind 94 | ./start_validator.sh 95 | ``` 96 | 97 | - Auto updates are enabled by default. To disable, run with `--no-auto-updates`. 98 | - Self-healing restarts are enabled by default (every 6 hours). To disable, run with `--no-self-heal`. 
99 | 100 | 101 | The above command will kick off 3 `pm2` processes 102 | ``` 103 | ┌────┬───────────────────┬─────────────┬─────────┬─────────┬──────────┬────────┬──────┬───────────┬──────────┬──────────┬──────────┬──────────┐ 104 | │ id │ name │ namespace │ version │ mode │ pid │ uptime │ ↺ │ status │ cpu │ mem │ user │ watching │ 105 | ├────┼───────────────────┼─────────────┼─────────┼─────────┼──────────┼────────┼──────┼───────────┼──────────┼──────────┼──────────┼──────────┤ 106 | │ 0 │ sn34-generator │ default │ N/A │ fork │ 2397505 │ 38m │ 2 │ online │ 100% │ 3.0gb │ user │ disabled │ 107 | │ 2 │ sn34-proxy │ default │ N/A │ fork │ 2398000 │ 27m │ 1 │ online │ 0% │ 695.2mb │ user │ disabled │ 108 | │ 1 │ sn34-validator │ default │ N/A │ fork │ 2394939 │ 108m │ 0 │ online │ 0% │ 5.8gb │ user │ disabled │ 109 | └────┴───────────────────┴─────────────┴─────────┴─────────┴──────────┴────────┴──────┴───────────┴──────────┴──────────┴──────────┴──────────┘ 110 | ``` 111 | - `sn34-validator` is the validator process 112 | - `sn34-generator` runs our data generation pipeline to produce **synthetic images and videos** (stored in `~/.cache/sn34`) 113 | - `sn34-proxy`routes organic traffic from our applications to miners. 
114 | -------------------------------------------------------------------------------- /docs/static/Bitmind-Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/docs/static/Bitmind-Logo.png -------------------------------------------------------------------------------- /docs/static/Join-BitMind-Discord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/docs/static/Join-BitMind-Discord.png -------------------------------------------------------------------------------- /docs/static/Subnet-Arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/docs/static/Subnet-Arch.png -------------------------------------------------------------------------------- /min_compute.yml: -------------------------------------------------------------------------------- 1 | # NOTE FOR MINERS: 2 | # Miner min compute varies based on selected model architecture. 3 | # For model training, you will most likely need a GPU. For miner deployment, depending 4 | # on your model, you may be able to get away with CPU. 5 | 6 | version: '3.0.0' 7 | 8 | compute_spec: 9 | 10 | validator: 11 | 12 | cpu: 13 | min_cores: 4 # Minimum number of CPU cores 14 | min_speed: 2.5 # Minimum speed per core (GHz) 15 | recommended_cores: 8 # Recommended number of CPU cores 16 | recommended_speed: 3.5 # Recommended speed per core (GHz) 17 | architecture: "x86_64" # Architecture type (e.g., x86_64, arm64) 18 | 19 | gpu: 20 | required: True # Does the application require a GPU? 
21 | min_vram: 80 # Minimum GPU VRAM (GB) 22 | recommended_vram: 80 # Recommended GPU VRAM (GB) 23 | min_compute_capability: 8.0 # Minimum CUDA compute capability 24 | recommended_compute_capability: 8.0 # Recommended CUDA compute capability 25 | recommended_gpu: "NVIDIA A100 80GB PCIE" # Recommended GPU to purchase/rent 26 | fp64: 9.7 # TFLOPS 27 | fp64_tensor_core: 19.5 # TFLOPS 28 | fp32: 19.5 # TFLOPS 29 | tf32: 156 # TFLOPS* 30 | bfloat16_tensor_core: 312 # TFLOPS* 31 | int8_tensor_core: 624 # TOPS* 32 | 33 | # See NVIDIA A100 datasheet for details: 34 | # https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/ 35 | # nvidia-a100-datasheet-nvidia-us-2188504-web.pdf 36 | 37 | # *double with sparsity 38 | 39 | memory: 40 | min_ram: 32 # Minimum RAM (GB) 41 | min_swap: 4 # Minimum swap space (GB) 42 | recommended_swap: 8 # Recommended swap space (GB) 43 | ram_type: "DDR6" # RAM type (e.g., DDR4, DDR3, etc.) 44 | 45 | storage: 46 | min_space: 1000 # Minimum free storage space (GB) 47 | recommended_space: 1000 # Recommended free storage space (GB) 48 | type: "SSD" # Preferred storage type (e.g., SSD, HDD) 49 | min_iops: 1000 # Minimum I/O operations per second (if applicable) 50 | recommended_iops: 5000 # Recommended I/O operations per second 51 | 52 | os: 53 | name: "Ubuntu" # Name of the preferred operating system(s) 54 | version: 22.04 # Version of the preferred operating system(s) 55 | 56 | network_spec: 57 | bandwidth: 58 | download: 100 # Minimum download bandwidth (Mbps) 59 | upload: 20 # Minimum upload bandwidth (Mbps) 60 | -------------------------------------------------------------------------------- /neurons/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/1a108251c409c1015cde394ddcd73660f881a6a0/neurons/__init__.py -------------------------------------------------------------------------------- /neurons/base.py: 
import argparse
import copy
import inspect
import signal
import sys
import traceback
from threading import Thread
from typing import Callable, List

import bittensor as bt
from bittensor.core.settings import SS58_FORMAT, TYPE_REGISTRY
from nest_asyncio import asyncio
from substrateinterface import SubstrateInterface

from bitmind import (
    __spec_version__ as spec_version,
)
from bitmind.metagraph import run_block_callback_thread
from bitmind.types import NeuronType
from bitmind.utils import ExitContext, on_block_interval
from bitmind.config import (
    add_args,
    add_validator_args,
    add_miner_args,
    add_proxy_args,
    validate_config_and_neuron_path,
)


class BaseNeuron:
    """Common bootstrap shared by SN34 neuron processes (validator, miner, proxy).

    Responsibilities, in order of ``__init__``:
      * build the CLI/config from bittensor's and bitmind's argument groups,
        merging an optionally supplied ``config``;
      * install SIGINT/SIGTERM handlers that flip ``exit_context``;
      * configure logging, wallet, subtensor, and metagraph;
      * verify the hotkey is registered and resolve its UID;
      * open a ``SubstrateInterface`` and start a background thread that
        invokes every entry of ``block_callbacks`` on each new block.
    """

    config: "bt.config"
    neuron_type: NeuronType
    exit_context = ExitContext()
    next_sync_block = None
    # NOTE(review): class-level mutable list — instances append to it via
    # ``self.block_callbacks.append``, so callbacks would be shared across
    # multiple neurons in one process. Fine for the one-neuron-per-process
    # deployment model used here; confirm before instantiating several.
    block_callbacks: List[Callable] = []
    substrate_thread: Thread

    def check_registered(self):
        """Terminate the process if this wallet's hotkey is not registered on the netuid."""
        if not self.subtensor.is_hotkey_registered(
            netuid=self.config.netuid,
            hotkey_ss58=self.wallet.hotkey.ss58_address,
        ):
            bt.logging.error(
                f"Wallet: {self.wallet} is not registered on netuid {self.config.netuid}."
                f" Please register the hotkey using `btcli subnets register` before trying again"
            )
            # sys.exit(1) rather than the site-injected exit(): it does not
            # depend on the `site` module and reports failure with a nonzero
            # exit code (bare exit() would exit 0 after logging an error).
            sys.exit(1)

    @on_block_interval("epoch_length")
    async def maybe_sync_metagraph(self, block):
        """Resync the metagraph every ``epoch_length`` blocks.

        Validators additionally re-sync miner hotkeys / moving averages via
        their ``eval_engine`` (attribute set by the validator subclass).
        """
        self.check_registered()
        bt.logging.info("Resyncing Metagraph")
        self.metagraph.sync(subtensor=self.subtensor)

        if self.neuron_type == NeuronType.VALIDATOR:
            bt.logging.info("Metagraph updated, re-syncing hotkeys and moving averages")
            self.eval_engine.sync_to_metagraph()

    async def run_callbacks(self, block):
        """Invoke each registered block callback, awaiting coroutine results.

        Skipped entirely while ``initialization_complete`` (if defined by a
        subclass) is still False. A failing callback is logged and does not
        prevent the remaining callbacks from running.
        """
        if (
            hasattr(self, "initialization_complete")
            and not self.initialization_complete
        ):
            bt.logging.debug(
                f"Skipping callbacks at block {block} during initialization"
            )
            return

        for callback in self.block_callbacks:
            try:
                res = callback(block)
                if inspect.isawaitable(res):
                    await res
            except Exception as e:
                bt.logging.error(
                    f"Failed running callback {callback.__name__}: {str(e)}"
                )
                bt.logging.error(traceback.format_exc())

    def __init__(self, config=None):
        """Build config, wire signals/logging, and connect to the chain.

        Args:
            config: optional pre-built ``bt.config`` merged over the parsed
                CLI arguments.
        """
        bt.logging.info(
            f"Bittensor Version: {bt.__version__} | SN34 Version {spec_version}"
        )

        parser = argparse.ArgumentParser()
        bt.wallet.add_args(parser)
        bt.subtensor.add_args(parser)
        bt.logging.add_args(parser)
        add_args(parser)

        # Each neuron flavor contributes its own argument group; the proxy
        # reuses the validator args plus its own.
        if self.neuron_type == NeuronType.VALIDATOR:
            bt.axon.add_args(parser)
            add_validator_args(parser)
        if self.neuron_type == NeuronType.VALIDATOR_PROXY:
            add_validator_args(parser)
            add_proxy_args(parser)
        if self.neuron_type == NeuronType.MINER:
            bt.axon.add_args(parser)
            add_miner_args(parser)

        self.config = bt.config(parser)
        if config:
            base_config = copy.deepcopy(config)
            self.config.merge(base_config)

        validate_config_and_neuron_path(self.config)

        ## Add kill signals
        signal.signal(signal.SIGINT, self.exit_context.startExit)
        signal.signal(signal.SIGTERM, self.exit_context.startExit)

        ## LOGGING
        bt.logging(config=self.config, logging_dir=self.config.neuron.full_path)
        bt.logging.set_info()
        if self.config.logging.debug:
            bt.logging.set_debug(True)
        if self.config.logging.trace:
            bt.logging.set_trace(True)

        ## BITTENSOR INITIALIZATION
        bt.logging.success(self.config)
        self.wallet = bt.wallet(config=self.config)
        self.subtensor = bt.subtensor(
            config=self.config, network=self.config.subtensor.chain_endpoint
        )
        self.metagraph = self.subtensor.metagraph(self.config.netuid)

        self.loop = asyncio.get_event_loop()
        bt.logging.debug(f"Wallet: {self.wallet}")
        bt.logging.debug(f"Subtensor: {self.subtensor}")
        bt.logging.debug(f"Metagraph: {self.metagraph}")

        ## CHECK IF REGG'D
        self.check_registered()
        self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address)

        ## Substrate, Subtensor and Metagraph
        self.substrate = SubstrateInterface(
            ss58_format=SS58_FORMAT,
            use_remote_preset=True,
            url=self.config.subtensor.chain_endpoint,
            type_registry=TYPE_REGISTRY,
        )

        self.block_callbacks.append(self.maybe_sync_metagraph)
        self.substrate_thread = run_block_callback_thread(
            self.substrate, self.run_callbacks
        )
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

import warnings

# Silence known-noisy FutureWarnings from the generative-model stack.
for module in ["diffusers", "transformers.tokenization_utils_base"]:
    warnings.filterwarnings("ignore", category=FutureWarning, module=module)

import logging

logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("datasets").setLevel(logging.ERROR)

import transformers

transformers.logging.set_verbosity_error()

import bittensor as bt
from bitmind.config import add_args, add_data_generator_args
from bitmind.utils import ExitContext, get_metadata
from bitmind.wandb_utils import init_wandb, clean_wandb_cache
from bitmind.types import CacheConfig, MediaType, Modality
from bitmind.cache.sampler import ImageSampler
from bitmind.generation import (
    GenerationPipeline,
    initialize_model_registry,
)


class Generator:
    """Standalone synthetic-media generation service.

    Samples real images from the local cache, feeds them through the
    ``GenerationPipeline`` to produce synthetic media, and (when a wallet is
    configured) logs batches to W&B, rotating the run every
    ``wandb.num_batches_per_run`` batches.
    """

    def __init__(self):
        self.exit_context = ExitContext()
        self.task = None
        self.generation_pipeline = None
        self.image_sampler = None

        self.setup_signal_handlers()
        atexit.register(self.cleanup)

        parser = argparse.ArgumentParser()
        bt.subtensor.add_args(parser)
        bt.wallet.add_args(parser)
        bt.logging.add_args(parser)
        add_data_generator_args(parser)
        add_args(parser)

        self.config = bt.config(parser)

        bt.logging(config=self.config, logging_dir=self.config.neuron.full_path)
        # Default to info and honor the configured level; previously this
        # called set_trace() unconditionally (debugging leftover), which
        # forced trace logging regardless of LOGLEVEL and made the two
        # conditionals below no-ops. Now consistent with BaseNeuron.
        bt.logging.set_info()
        if self.config.logging.debug:
            bt.logging.set_debug(True)
        if self.config.logging.trace:
            bt.logging.set_trace(True)

        bt.logging.success(self.config)
        wallet_configured = (
            self.config.wallet.name is not None
            and self.config.wallet.hotkey is not None
        )
        # Fixed: was `self.config.wandb_off`, but everywhere else in this file
        # (the except handler below and the run() loop) the flag lives at
        # `self.config.wandb.off` — the flat name never reflected the config.
        if wallet_configured and not self.config.wandb.off:
            try:
                self.wallet = bt.wallet(config=self.config)
                # Resolve our UID from the metagraph so the W&B run can be
                # attributed (and signed) by this hotkey.
                self.uid = (
                    bt.subtensor(
                        config=self.config, network=self.config.subtensor.chain_endpoint
                    )
                    .metagraph(self.config.netuid)
                    .hotkeys.index(self.wallet.hotkey.ss58_address)
                )
                self.wandb_dir = str(Path(__file__).parent.parent)
                clean_wandb_cache(self.wandb_dir)
                self.wandb_run = init_wandb(
                    self.config.copy(),
                    self.config.wandb.process_name,
                    self.uid,
                    self.wallet.hotkey,
                )

            except Exception as e:
                # Unregistered hotkey (or chain hiccup): run without W&B
                # rather than crashing the generator.
                bt.logging.error("Not registered, can't sign W&B run")
                bt.logging.error(e)
                self.config.wandb.off = True

    def setup_signal_handlers(self):
        """Route SIGINT/SIGTERM/SIGQUIT through a graceful shutdown path."""
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)
        signal.signal(signal.SIGQUIT, self.signal_handler)

    def signal_handler(self, sig, frame):
        """Log the received signal, clean up, and exit the process."""
        signal_name = signal.Signals(sig).name
        bt.logging.info(f"Received {signal_name}, initiating shutdown...")
        self.cleanup()
        sys.exit(0)

    def cleanup(self):
        """Best-effort teardown: cancel the task, stop the pipeline, free CUDA memory.

        Registered with atexit and invoked from signal handlers, so it must
        never raise.
        """
        if self.task and not self.task.done():
            self.task.cancel()

        if self.generation_pipeline:
            try:
                bt.logging.trace("Shutting down generator...")
                self.generation_pipeline.shutdown()
                bt.logging.success("Generator shut down gracefully")
            except Exception as e:
                bt.logging.error(f"Error during generator shutdown: {e}")

        # Force cleanup of any GPU memory
        try:
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                bt.logging.trace("CUDA memory cache cleared")
        except Exception:
            # torch may be absent or CUDA unavailable; nothing to clean up.
            pass

    async def wait_for_cache(self, timeout: int = 300):
        """Wait for the cache to be populated with images for prompt generation.

        Polls every 10s; returns True once at least one image is available,
        False if ``timeout`` seconds elapse first.
        """
        start = time.time()
        attempts = 0
        while True:
            if time.time() - start > timeout:
                return False

            available_count = self.image_sampler.get_available_count(use_index=False)
            if available_count > 0:
                return True

            await asyncio.sleep(10)
            if not attempts % 3:  # log roughly every 30s, not every poll
                bt.logging.info("Waiting for images in cache...")
            attempts += 1

    async def sample_images(self, k: int = 1) -> List[Dict[str, Any]]:
        """Sample ``k`` images from the cache, decoding raw bytes into PIL images.

        Raises:
            ValueError: if the cache currently holds no images.
        """
        result = await self.image_sampler.sample(k, remove_from_cache=False)
        if result["count"] == 0:
            raise ValueError("No images available in cache")

        # Convert bytes to PIL images
        for item in result["items"]:
            if isinstance(item["image"], bytes):
                item["image"] = Image.open(io.BytesIO(item["image"]))

        return result["items"]

    async def run(self):
        """Main generator loop: sample real images, generate synthetic media, repeat."""
        try:
            cache_dir = self.config.cache_dir
            batch_size = self.config.batch_size
            device = self.config.device

            Path(cache_dir).mkdir(parents=True, exist_ok=True)

            self.image_sampler = ImageSampler(
                CacheConfig(
                    modality=Modality.IMAGE.value,
                    media_type=MediaType.REAL.value,
                    base_dir=Path(cache_dir),
                )
            )

            await self.wait_for_cache()
            bt.logging.success("Cache populated. Proceeding to generation.")

            model_registry = initialize_model_registry()
            model_names = model_registry.get_interleaved_model_names(self.config.tasks)
            bt.logging.info("Starting generator")
            bt.logging.info(f"Tasks: {self.config.tasks}")
            bt.logging.info(f"Models: {model_names}")

            self.generation_pipeline = GenerationPipeline(
                output_dir=cache_dir,
                device=device,
            )

            gen_count = 0  # total files generated
            batch_count = 0  # batches completed in the current W&B run
            while not self.exit_context.isExiting:
                if asyncio.current_task().cancelled():
                    break

                try:
                    image_samples = await self.sample_images(batch_size)
                    # Fixed: this message previously interpolated gen_count
                    # (file total) while labeling it "Batch Count".
                    bt.logging.info(
                        f"Starting batch generation | Batch Size: {len(image_samples)} | Batch Count: {batch_count}"
                    )

                    start_time = time.time()

                    filepaths = self.generation_pipeline.generate(
                        image_samples, model_names=model_names
                    )
                    await asyncio.sleep(1)

                    duration = time.time() - start_time
                    gen_count += len(filepaths)
                    batch_count += 1
                    bt.logging.info(
                        f"Generated {len(filepaths)} files in batch #{batch_count} in {duration:.2f} seconds"
                    )

                    # Rotate the W&B run periodically to keep runs bounded.
                    if not self.config.wandb.off:
                        if batch_count >= self.config.wandb.num_batches_per_run:
                            batch_count = 0
                            self.wandb_run.finish()
                            clean_wandb_cache(self.wandb_dir)
                            self.wandb_run = init_wandb(
                                self.config.copy(),
                                self.config.wandb.process_name,
                                self.uid,
                                self.wallet.hotkey,
                            )

                except asyncio.CancelledError:
                    bt.logging.info("Task cancelled, exiting loop")
                    break
                except Exception as e:
                    bt.logging.error(f"Error in batch processing: {e}")
                    bt.logging.error(traceback.format_exc())
                    await asyncio.sleep(10)
        except Exception as e:
            bt.logging.error(f"Unhandled exception in main task: {e}")
            bt.logging.error(traceback.format_exc())
            raise
        finally:
            self.cleanup()

    def start(self):
        """Start the generator's event loop and block until it finishes."""
        loop = asyncio.get_event_loop()
        try:
            self.task = asyncio.ensure_future(self.run())
            loop.run_until_complete(self.task)
        except KeyboardInterrupt:
            bt.logging.info("Generator interrupted by KeyboardInterrupt, shutting down")
        except Exception as e:
            bt.logging.error(f"Unhandled exception: {e}")
            bt.logging.error(traceback.format_exc())
        finally:
            self.cleanup()


if __name__ == "__main__":
    generator = Generator()
    generator.start()
    sys.exit(0)
#!/bin/bash
# ===================================================================
# setup.sh — system package setup + python package installation
# ===================================================================
###########################################
# System Updates and Package Installation #
###########################################

# Update system
sudo apt update -y

# Install core dependencies
sudo apt install -y \
  python3-pip \
  nano \
  libgl1 \
  ffmpeg \
  unzip

# NOTE: the node.js / pm2 section below was previously "disabled" with
# Python-style triple quotes ("""), which are NOT a comment in bash — the
# shell parsed the whole section as one giant quoted word and printed a
# "command not found" error on every run. It is kept disabled here with
# real comments; uncomment if you want this script to install pm2.
#
# # Remove old nodejs and npm if present
# #sudo apt-get remove --purge -y nodejs npm
#
# # Install Node.js 20.x (LTS) from NodeSource for stability and universal standard
# # NOTE: Update the version here when a new LTS is released
# curl -fsSL https://deb.nodesource.com/setup_20.x | sudo -E bash -
# sudo apt-get install -y nodejs
#
# # Install build dependencies
# sudo apt install -y \
#   build-essential \
#   cmake \
#   libopenblas-dev \
#   liblapack-dev \
#   libx11-dev \
#   libgtk-3-dev
#
# # Install process manager (pm2) globally
# sudo npm install -g pm2@latest

############################
# Python Package Installation
############################

pip install --use-pep517 -e . -r requirements-git.txt

# ===================================================================
# start_miner.sh — launch the miner under pm2 using .env.miner
# ===================================================================
#!/bin/bash

###################################
# LOAD ENV FILE
###################################
set -a
source .env.miner
set +a

###################################
# PREPARE CLI ARGS
###################################
# Derive netuid/network from the chain endpoint (testnet vs finney).
if [[ "$CHAIN_ENDPOINT" == *"test"* ]]; then
    NETUID=168
    NETWORK="test"
elif [[ "$CHAIN_ENDPOINT" == *"finney"* ]]; then
    NETUID=34
    NETWORK="finney"
fi

case "$LOGLEVEL" in
  "trace")
    LOG_PARAM="--logging.trace"
    ;;
  "debug")
    LOG_PARAM="--logging.debug"
    ;;
  "info")
    LOG_PARAM="--logging.info"
    ;;
  *)
    # Default to info if LOGLEVEL is not set or invalid
    LOG_PARAM="--logging.info"
    ;;
esac

# Set validator-permit parameter based on FORCE_VPERMIT
# (comment previously said "auto-update", copy-pasted from the validator script)
FORCE_VPERMIT_PARAM=""
if [ "$FORCE_VPERMIT" = false ]; then
  FORCE_VPERMIT_PARAM="--no-force-validator-permit"
fi


###################################
# RESTART PROCESSES
###################################
NAME="bitmind-miner"

# Stop any existing processes
if pm2 list | grep -q "$NAME"; then
  echo "'$NAME' is already running. Deleting it..."
  pm2 delete $NAME
fi

echo "Starting $NAME | chain_endpoint: $CHAIN_ENDPOINT | netuid: $NETUID"

# Run miner. $LOG_PARAM was previously computed but never passed, so the
# LOGLEVEL setting in .env.miner had no effect.
pm2 start neurons/miner.py \
  --interpreter python3 \
  --name $NAME \
  -- \
  --wallet.name $WALLET_NAME \
  --wallet.hotkey $WALLET_HOTKEY \
  --netuid $NETUID \
  --subtensor.chain_endpoint $CHAIN_ENDPOINT \
  --axon.port $AXON_PORT \
  --axon.external_ip $AXON_EXTERNAL_IP \
  --device $DEVICE \
  $LOG_PARAM \
  $FORCE_VPERMIT_PARAM
#!/bin/bash
# start_validator.sh — log in to third-party services and launch the
# validator, generator, and proxy processes under pm2 using .env.validator

###################################
# LOAD ENV FILE
###################################
set -a
source .env.validator
set +a

###################################
# LOG IN TO THIRD PARTY SERVICES
###################################
# Login to Weights & Biases
if ! wandb login $WANDB_API_KEY; then
  echo "Failed to login to Weights & Biases with the provided API key."
  exit 1
fi
echo "Logged into W&B with API key provided in .env.validator"

# Login to Hugging Face
if ! huggingface-cli login --token $HUGGING_FACE_TOKEN; then
  echo "Failed to login to Hugging Face with the provided token."
  exit 1
fi
# Fixed copy-paste: this success message previously said "Logged into W&B".
echo "Logged into Hugging Face with token provided in .env.validator"

###################################
# PREPARE CLI ARGS
###################################
# Fall back to sane defaults when these are unset in .env.validator.
: ${PROXY_PORT:=10913}
: ${PROXY_EXTERNAL_PORT:=$PROXY_PORT}
: ${DEVICE:=cuda}

# Derive netuid/network from the chain endpoint (testnet vs finney).
if [[ "$CHAIN_ENDPOINT" == *"test"* ]]; then
    NETUID=168
    NETWORK="test"
elif [[ "$CHAIN_ENDPOINT" == *"finney"* ]]; then
    NETUID=34
    NETWORK="finney"
fi

case "$LOGLEVEL" in
  "trace")
    LOG_PARAM="--logging.trace"
    ;;
  "debug")
    LOG_PARAM="--logging.debug"
    ;;
  "info")
    LOG_PARAM="--logging.info"
    ;;
  *)
    # Default to info if LOGLEVEL is not set or invalid
    LOG_PARAM="--logging.info"
    ;;
esac

# Set auto-update parameter based on AUTO_UPDATE
if [ "$AUTO_UPDATE" = true ]; then
  AUTO_UPDATE_PARAM=""
else
  AUTO_UPDATE_PARAM="--autoupdate-off"
fi

if [ "$HEARTBEAT" = true ]; then
  HEARTBEAT_PARAM="--heartbeat"
else
  HEARTBEAT_PARAM=""
fi

###################################
# STOP AND WAIT FOR CLEANUP
###################################
VALIDATOR="sn34-validator"
GENERATOR="sn34-generator"
PROXY="sn34-proxy"

# Stop any existing processes
if pm2 list | grep -q "$VALIDATOR"; then
  echo "'$VALIDATOR' is already running. Deleting it..."
  pm2 delete $VALIDATOR
  sleep 1
fi

if pm2 list | grep -q "$GENERATOR"; then
  echo "'$GENERATOR' is already running. Deleting it..."
  pm2 delete $GENERATOR
  sleep 2
fi

if pm2 list | grep -q "$PROXY"; then
  echo "'$PROXY' is already running. Deleting it..."
  pm2 delete $PROXY
  sleep 1
fi


###################################
# START PROCESSES
###################################
# Expand ~ in the cache dir before handing it to pm2.
SN34_CACHE_DIR=$(eval echo "$SN34_CACHE_DIR")

echo "Starting validator and generator | chain_endpoint: $CHAIN_ENDPOINT | netuid: $NETUID"

# Run data generator ($LOG_PARAM added for consistency with the validator
# and proxy processes; previously the generator ignored LOGLEVEL)
pm2 start neurons/generator.py \
  --interpreter python3 \
  --kill-timeout 2000 \
  --name $GENERATOR \
  -- \
  --wallet.name $WALLET_NAME \
  --wallet.hotkey $WALLET_HOTKEY \
  --netuid $NETUID \
  --subtensor.chain_endpoint $CHAIN_ENDPOINT \
  --cache-dir $SN34_CACHE_DIR \
  --device $DEVICE \
  $LOG_PARAM

# Run validator
pm2 start neurons/validator.py \
  --interpreter python3 \
  --kill-timeout 1000 \
  --name $VALIDATOR \
  -- \
  --wallet.name $WALLET_NAME \
  --wallet.hotkey $WALLET_HOTKEY \
  --netuid $NETUID \
  --subtensor.chain_endpoint $CHAIN_ENDPOINT \
  --epoch-length 360 \
  --cache-dir $SN34_CACHE_DIR \
  --proxy.port $PROXY_PORT \
  $LOG_PARAM \
  $AUTO_UPDATE_PARAM \
  $HEARTBEAT_PARAM

# Run validator proxy
pm2 start neurons/proxy.py \
  --interpreter python3 \
  --kill-timeout 1000 \
  --name $PROXY \
  -- \
  --wallet.name $WALLET_NAME \
  --wallet.hotkey $WALLET_HOTKEY \
  --netuid $NETUID \
  --subtensor.chain_endpoint $CHAIN_ENDPOINT \
  --proxy.port $PROXY_PORT \
  --proxy.external_port $PROXY_EXTERNAL_PORT \
  $LOG_PARAM