├── .DS_Store ├── LICENSE ├── README.md ├── docs ├── bittaudio.jpg ├── miner.md └── validator.md ├── lib ├── __init__.py └── hashing.py ├── min_compute.yml ├── neurons ├── __init__.py ├── miner.py └── validator.py ├── requirements.txt ├── scripts ├── check_compatibility.sh ├── check_requirements_changes.sh └── start_valid.py ├── setup.py └── ttm ├── aimodel.py ├── protocol.py ├── ttm.py └── ttm_score.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UncleTensor/BittAudio/d964362e50733ce185bcbacdd1c172aa26121f1b/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Muhammad Farhan Aslam 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BittAudio (SN 50) | Audio Generation Subnet on Bittensor 2 | ![bitaudio](docs/bittaudio.jpg) 3 | The main goal of BittAudio is to establish a decentralized platform that incentivizes the creation, distribution, and monetization of AI audio content, such as: 4 | - Text-to-Music (TTM)
5 | 6 | Validators and miners work together to ensure high-quality outputs, fostering innovation and rewarding contributions in the audio domain.
7 | By introducing audio generation services such as Text-to-Music, this subnetwork expands the range of available services within the Bittensor ecosystem. This diversification enhances the utility and appeal of the Bittensor platform to a broader audience, including creators, influencers, developers, and end-users interested in audio content.

8 | 9 | ## Validators & Miners Interaction 10 | - Validators initiate requests filled with the required data and encrypt them with a symmetric key. 11 | - Requests are signed with the validator’s private key to certify authenticity. 12 | - Miners decrypt the requests, verify the signatures to ensure authenticity, process the requests, and then send back the results, encrypted and signed for security. 13 | 14 | **Validators** are responsible for initiating the generation process by providing prompts to the Miners on the network. These prompts serve as the input for the TTM service. The Validators then evaluate the quality of the generated audio and reward the Miners based on the output quality.
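This excerpt does not include the transport-layer code itself, so the following is only a schematic sketch of the sign-and-verify handshake described in the bullets above. It uses PyNaCl (Ed25519) signatures and a Fernet symmetric cipher as stand-ins for whatever primitives the subnet actually uses; `make_request` and `verify_request` are hypothetical names, not functions from this repository.

```python
# Schematic sketch only: the real validator/miner exchange runs over
# Bittensor's dendrite/axon transport, not these exact primitives.
from cryptography.fernet import Fernet
from nacl.signing import SigningKey, VerifyKey

SYMMETRIC_KEY = Fernet.generate_key()   # shared symmetric key
validator_key = SigningKey.generate()   # validator's private signing key

def make_request(payload: bytes) -> tuple[bytes, bytes]:
    """Validator side: encrypt the payload, then sign the ciphertext."""
    ciphertext = Fernet(SYMMETRIC_KEY).encrypt(payload)
    signature = validator_key.sign(ciphertext).signature
    return ciphertext, signature

def verify_request(ciphertext: bytes, signature: bytes, verify_key: VerifyKey) -> bytes:
    """Miner side: verify authenticity first, then decrypt.

    Raises nacl.exceptions.BadSignatureError if the request was tampered with.
    """
    verify_key.verify(ciphertext, signature)
    return Fernet(SYMMETRIC_KEY).decrypt(ciphertext)

ct, sig = make_request(b'{"text_input": "lo-fi piano", "duration": 30}')
plain = verify_request(ct, sig, validator_key.verify_key)
```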
15 | Please refer to the [Validator Documentation](docs/validator.md) 16 | 17 | **Miners** in the Audio Subnetwork are tasked with generating audio from the text prompts received from the Validators. Leveraging advanced TTM models, miners aim to produce high-fidelity musical melodies. The quality of the generated audio is crucial, as it directly influences the miners' rewards.
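For a feel of what the miner side runs, here is a stripped-down version of the generation step from `neurons/miner.py` (shown in full later in this dump). It uses the same Hugging Face MusicGen models the miner loads; the prompt string is just an example.

```python
import torch
from scipy.io.wavfile import write as write_wav
from transformers import AutoProcessor, MusicgenForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "facebook/musicgen-medium"  # the repo's default miner model
processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device)

prompt = "an upbeat synthwave track with heavy drums"  # example prompt
inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(device)
audio = model.generate(**inputs, max_new_tokens=1500)  # MusicGen emits ~50 tokens/s, so ~30 s

# Output is 32 kHz audio shaped (batch, channels, samples), as in neurons/miner.py
write_wav("sample.wav", rate=32000, data=audio[0, 0].cpu().numpy())
```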
18 | Please refer to the [Miner Documentation](docs/miner.md) 19 | 20 | ## Workflow 21 | 22 | 1. **Prompt Generation:** The Validator generates TTM prompts and distributes them to the Miners on the network. 23 | 24 | 2. **Audio Processing:** Miners receive the prompts and utilize TTM models to convert the text into audio (music). 25 | 26 | 3. **Quality Evaluation:** The Validator assesses the quality of the generated audio, considering factors such as clarity, naturalness, and adherence to the prompt. 27 | 28 | 4. **Reward Distribution:** Based on the quality assessment, the Validator rewards Miners accordingly. Miners with consistently higher-quality outputs receive a larger share of rewards. 29 | 30 | ## Benefits 31 | 32 | - **Decentralized Text-to-Audio:** The subnetwork decentralizes the Text-to-Music process, distributing the workload among participating Miners. 33 | 34 | - **Quality Incentives:** The incentive mechanism encourages Miners to continually improve the quality of their generated audio. 35 | 36 | - **Bittensor Network Integration:** Leveraging the Bittensor network ensures secure and transparent interactions between Validators and Miners. 37 | 38 | Join BittAudio and contribute to the advancement of decentralized Text-to-Music technology within the Bittensor ecosystem. 39 | 40 | ## SOTA Benchmarking for Audio Evaluation 41 | 42 | To ensure the quality of audio generated by miners, we use a SOTA (State-of-the-Art) benchmarking process. Validators download sound files from the following link: 43 | https://huggingface.co/datasets/etechgrid/ttm-validation-dataset 44 | 45 | The output generated by miners is evaluated using three different metrics: 46 | 47 | 1. **Kullback-Leibler Divergence (KLD):** Measures the divergence between two probability distributions, allowing us to assess the distribution similarity between generated music and the original distribution of audio data. 48 | 49 | 2. **Frechet Audio Distance (FAD):** Calculates the difference between the statistical distribution of generated audio and real-world audio. This metric provides a robust evaluation of the overall quality, including the structure and timbre of generated music. 50 | 51 | 3. **CLAP Metric:** Evaluates text consistency by determining how well the generated audio adheres to the original text prompt. It measures the semantic alignment between the input text and the generated output. 52 | 53 | These metrics provide a comprehensive evaluation of the audio quality, and this new system is more robust compared to the previous metrics, which were a combination of CLAP, SNR (Signal-to-Noise Ratio), and HNR (Harmonic-to-Noise Ratio). 54 | 55 | ### Benchmark Milestone: 56 | | Model | Frechet Audio Distance | KLD | Text Consistency | Chroma Cosine Similarity | 57 | |--------------------------|------------------------|------|------------------|--------------------------| 58 | | facebook/musicgen-small | 4.88 | 1.42 | 0.27 | - | 59 | | facebook/musicgen-medium | 5.14 | 1.38 | 0.28 | - | 60 | | facebook/musicgen-large | 5.48 | 1.37 | 0.28 | - | 61 | | facebook/musicgen-melody | 4.93 | 1.41 | 0.27 | 0.44 | 62 | 63 | For more information on these models, you can visit: 64 | https://huggingface.co/facebook/musicgen-small 65 | 66 | ## Installation 67 | ```bash 68 | git clone https://github.com/UncleTensor/BittAudio.git 69 | cd BittAudio 70 | pip install -e .
71 | pip install -r requirements.txt 72 | wandb login 73 | ``` 74 | 75 | ## Recommended GPU Configuration 76 | 77 | It is recommended to use NVIDIA RTX A6000 GPUs at minimum for both Validators and Miners. 78 | 79 | 80 | **Evaluation Mechanism:**
81 | The evaluation mechanism involves the Validators querying miners on the network with random prompts and receiving TTM responses. These responses are scored based on correctness, and the weights on the Bittensor network are updated accordingly. The scoring is conducted using a reward function from the lib module. 82 | 83 | **Miner/Validator Hardware Specs:**
84 | The hardware requirements for miners and validators vary depending on the complexity and resource demands of the selected TTM models. Typically, a machine equipped with a capable CPU and GPU, along with sufficient VRAM and RAM, is necessary. The amount of disk space required will depend on the size of the models and any additional data. 85 | 86 | **How to Run a Validator:**
87 | To operate a validator, you need to run the validator.py script with the required command-line arguments. This script initiates the setup of Bittensor objects, establishes a connection to the network, queries miners, scores their responses, and updates weights accordingly. 88 | 89 | **How to Run a Miner:**
90 | To operate a miner, run the miner.py script with the necessary configuration. This process involves initializing Bittensor objects, establishing a connection to the network, and processing incoming TTM requests. 91 | 92 | **TTM Models Supported:**
93 | The code incorporates various Text-to-Music models. The specific requirements for each model, including CPU, GPU VRAM, RAM, and disk space, are not explicitly stated in the provided code. For these types of requirements, it may be necessary to consult the documentation or delve into the implementation details of these models. 94 | 95 | In general, the resource demands of TTM models can vary significantly. Larger models often necessitate more powerful GPUs and additional system resources. It is advisable to consult the documentation or model repository for the specific requirements of each model. Additionally, if GPU acceleration is employed, having a compatible GPU with enough VRAM is typically advantageous for faster processing. 96 | 97 | ## License 98 | This repository is licensed under the MIT License. 99 | 100 | ```text 101 | MIT License 102 | 103 | Copyright (c) 2024 Opentensor 104 | 105 | Permission is hereby granted, free of charge, to any person obtaining a copy 106 | of this software and associated documentation files (the "Software"), to deal 107 | in the Software without restriction, including without limitation the rights 108 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 109 | copies of the Software, and to permit persons to whom the Software is 110 | furnished to do so, subject to the following conditions: 111 | 112 | The above copyright notice and this permission notice shall be included in all 113 | copies or substantial portions of the Software. 114 | 115 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 116 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 117 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 118 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 119 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 120 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 121 | SOFTWARE. 122 | 123 | ``` 124 | -------------------------------------------------------------------------------- /docs/bittaudio.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UncleTensor/BittAudio/d964362e50733ce185bcbacdd1c172aa26121f1b/docs/bittaudio.jpg -------------------------------------------------------------------------------- /docs/miner.md: -------------------------------------------------------------------------------- 1 | # Audio Generation Subnetwork Miner Guide 2 | Welcome to the Miner's guide for the Audio Generation Subnetwork within the Bittensor network. This document provides instructions for setting up and running a Miner node in the network. 3 | 4 | ## Overview 5 | Miners in the Audio Subnetwork are responsible for generating audio from text prompts received from Validators. Utilizing advanced text-to-music models, miners aim to produce high-fidelity, natural-sounding music. The quality of the generated audio directly influences the rewards miners receive.
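Because rewards hinge on the quality metrics described in the main README (KLD, FAD, and CLAP text consistency), it helps to see the shape of the scoring step. The real implementation lives in `ttm/ttm_score.py` (`MusicQualityEvaluator`), which is not shown in this excerpt, so the weighting below is a hypothetical illustration only — a rough sketch of how the three metrics could be folded into one reward.

```python
def combine_scores(kld: float, fad: float, clap: float) -> float:
    """Fold the three README metrics into a single reward.

    KLD and FAD are distances (lower is better), so they are inverted;
    CLAP text consistency is a similarity (higher is better). The
    0.3/0.3/0.4 weights are hypothetical, not the subnet's actual values.
    """
    kld_term = 1.0 / (1.0 + kld)
    fad_term = 1.0 / (1.0 + fad)
    return 0.3 * kld_term + 0.3 * fad_term + 0.4 * clap

# A miner matching the facebook/musicgen-medium benchmark row
# (FAD 5.14, KLD 1.38, text consistency 0.28) would land around 0.29:
score = combine_scores(kld=1.38, fad=5.14, clap=0.28)
```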
6 | 7 | ## Installation 8 | Follow these steps to install the necessary components: 9 | 10 | **Set Conda Environment** 11 | ```bash 12 | mkdir -p ~/miniconda3 13 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh 14 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 15 | rm -rf ~/miniconda3/miniconda.sh 16 | ~/miniconda3/bin/conda init bash 17 | ~/miniconda3/bin/conda init zsh 18 | conda create -n {conda-env} python=3.10 -y 19 | conda activate {conda-env} 20 | ``` 21 | **Install Repo** 22 | ```bash 23 | git clone https://github.com/UncleTensor/BittAudio.git 24 | cd BittAudio 25 | pip install -e . 26 | pip install -r requirements.txt 27 | ``` 28 | **Install pm2** 29 | ```bash 30 | sudo apt install nodejs npm 31 | sudo npm install pm2 -g 32 | ``` 33 | 34 | ### Recommended GPU Configuration 35 | - NVIDIA RTX A6000 GPUs are recommended for optimal performance. 36 | 37 | ### Running a Miner 38 | - To operate a miner, run the miner.py script with the necessary configuration. 39 | 40 | ### Miner Commands 41 | ```bash 42 | pm2 start neurons/miner.py -- \ 43 | --netuid 50 \ 44 | --wallet.name {wallet_name} \ 45 | --wallet.hotkey {hotkey_name} \ 46 | --logging.trace \ 47 | --music_path {ttm-model} \ 48 | --axon.port {machine_port} 49 | ``` 50 | 51 | ### Bittensor Miner Script Arguments: 52 | 53 | | **Category** | **Argument** | **Default Value** | **Description** | 54 | |---------------------------------|--------------------------------------|----------------------------|-----------------------------------------------------------------------------------------------------------------------| 55 | | **Text To Music Model** | `--music_model` | 'facebook/musicgen-medium' ; 'facebook/musicgen-large' | The model to use for Text-To-Music | 56 | | **Music Finetuned Model** | `--music_path` | /path/to/model | Path to a custom finetuned model for Text-To-Music | 57 | | **Network UID** | `--netuid` | Mainnet: 50 | The chain subnet UID. | 58 | | **Bittensor Subtensor Arguments** | `--subtensor.chain_endpoint` | - | Endpoint for Bittensor chain connection.| 59 | | | `--subtensor.network` | - | Bittensor network endpoint.| 60 | | **Bittensor Logging Arguments** | `--logging.debug` | - | Enable debugging logs.| 61 | | **Bittensor Wallet Arguments** | `--wallet.name` | - | Name of the wallet.| 62 | | | `--wallet.hotkey` | - | Hotkey path for the wallet.| 63 | | **Bittensor Axon Arguments** | `--axon.port` | - | Port number for the axon server.| 64 | | **PM2 process name** | `--pm2_name` | 'SN50Miner' | Name for the pm2 process for Auto Update. | 65 | 66 | 67 | 68 | 69 | 70 | ### License 71 | Refer to the main README for the MIT License details. 72 | -------------------------------------------------------------------------------- /docs/validator.md: -------------------------------------------------------------------------------- 1 | # Audio Generation Subnetwork Validator Guide 2 | 3 | Welcome to the Validator's guide for the Audio Generation Subnetwork within the Bittensor network. This document provides instructions for setting up and running a Validator node in the network. 4 | 5 | ## Overview 6 | Validators initiate the audio generation process by providing prompts to the Miners and evaluate the quality of the generated audio. They play a crucial role in maintaining the quality standards of the network. The prompts are generated with the help of the Corcel API, a product of Subnet 18, which provides an infinite range of prompts for Text-To-Music.
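Alongside Corcel, the validator code in this repository (`ttm/ttm.py`, shown later in this dump) also draws prompts and reference audio from a Hugging Face validation dataset. A condensed sketch of that sampling step:

```python
import random
from datasets import load_dataset

# Same dataset the validator downloads reference audio from (see the main README)
dataset = load_dataset("etechgrid/ttm-validation-dataset", split="train")

sample = dataset[random.randint(0, len(dataset) - 1)]
prompt = sample["Prompts"]        # text prompt forwarded to miners
reference = sample["File_Path"]   # dict holding 'array' and 'sampling_rate' for scoring
print(f"Querying miners with prompt: {prompt!r}")
```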
7 | 8 | ## Installation 9 | Follow these steps to install the necessary components: 10 | 11 | **Set Conda Environment** 12 | ```bash 13 | mkdir -p ~/miniconda3 14 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh 15 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 16 | rm -rf ~/miniconda3/miniconda.sh 17 | ~/miniconda3/bin/conda init bash 18 | bash 19 | ~/miniconda3/bin/conda init zsh 20 | conda create -n {conda-env} python=3.10 -y 21 | conda activate {conda-env} 22 | ``` 23 | **Install Repo** 24 | ```bash 25 | sudo apt update 26 | sudo apt install build-essential -y 27 | git clone https://github.com/UncleTensor/BittAudio.git 28 | cd BittAudio 29 | pip install -e . 30 | pip install audiocraft 31 | pip install laion_clap==1.1.4 32 | pip install git+https://github.com/haoheliu/audioldm_eval 33 | pip install git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt 34 | sudo mkdir -p /tmp/music 35 | wandb login 36 | ``` 37 | **Install pm2** 38 | ```bash 39 | sudo apt install nodejs npm 40 | sudo npm install pm2 -g 41 | ``` 42 | 43 | ## Validator Command for Auto Update 44 | ```bash 45 | pm2 start scripts/start_valid.py -- \ 46 | --pm2_name {name} \ 47 | --netuid 50 \ 48 | --wallet.name {wallet_name} \ 49 | --wallet.hotkey {hotkey_name} \ 50 | --subtensor.network {finney} 51 | ``` 52 | 53 | ### Bittensor Validator Script Arguments: 54 | 55 | | **Category** | **Argument** | **Default Value** | **Description** | 56 | |---------------------------------|--------------------------------------|----------------------------|-----------------------------------------------------------------------------------------------------------------------| 57 | | **Configuration Arguments** | `--alpha` | 0.1 | The weight moving average scoring. | 58 | | | `--netuid` | Mainnet: 50 | The chain subnet UID. | 59 | | **Bittensor Subtensor Arguments** | `--subtensor.chain_endpoint` | - | Endpoint for Bittensor chain connection. | 60 | | | `--subtensor.network` | - | Bittensor network endpoint. | 61 | | **Bittensor Logging Arguments** | `--logging.debug` | - | Enable debugging logs. | 62 | | **Bittensor Wallet Arguments** | `--wallet.name` | - | Name of the wallet. | 63 | | | `--wallet.hotkey` | - | Hotkey path for the wallet. | 64 | | **PM2 process name** | `--pm2_name` | 'validator' | Name for the pm2 process for Auto Update. | 65 | 66 | ## Miner logs on the Validator 67 | 68 | To inspect miner scoring logs on the validator, go to the ~/.pm2/logs directory and grep for them 69 | as follows: 70 | 71 | sudo grep -a -A 10 "Raw score for hotkey:5DXTGaAQm99AEAvhMRqWQ77b1aob4mAXwX" ~/.pm2/logs/validator-out.log 72 | 73 | sudo grep -a -A 10 "Normalized score for hotkey:5DXTGaAQm99AEAvhMRqWQ77b1aob4mAXwX" ~/.pm2/logs/validator-out.log 74 | 75 | ### License 76 | Refer to the main README for the MIT License details.
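### Appendix: Score Smoothing

One detail from the arguments table worth spelling out: `--alpha` drives a moving average over miner scores, so a single good or bad response only nudges a miner's standing. A minimal sketch of that update rule follows; the value 0.1 matches the argparse default in `ttm/aimodel.py`, and whether alpha weights the new round or the history is an implementation detail assumed here, not taken from this excerpt.

```python
import numpy as np

ALPHA = 0.1  # matches the default in ttm/aimodel.py

def update_scores(scores: np.ndarray, new_scores: np.ndarray) -> np.ndarray:
    """Blend one round of fresh evaluations into the running scores."""
    return (1 - ALPHA) * scores + ALPHA * new_scores

scores = np.zeros(256)  # one slot per miner UID
scores = update_scores(scores, new_scores=np.random.rand(256))
# Normalized scores are what ultimately become weights on chain:
weights = scores / (scores.sum() or 1.0)
```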
77 | 78 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.3" 2 | version_split = __version__.split(".") 3 | __spec_version__ = ( 4 | (1000 * int(version_split[0])) 5 | + (10 * int(version_split[1])) 6 | + (1 * int(version_split[2])) 7 | ) 8 | 9 | MIN_STAKE = 10000 -------------------------------------------------------------------------------- /lib/hashing.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import json 4 | from datetime import datetime 5 | import bittensor as bt 6 | 7 | hashes_file = 'music_hashes.json' 8 | 9 | # In-memory cache for fast lookup 10 | cache = set() 11 | 12 | def calculate_audio_hash(audio_data: bytes) -> str: 13 | """Calculate a SHA256 hash of the given audio data.""" 14 | return hashlib.sha256(audio_data).hexdigest() 15 | 16 | def load_hashes_to_cache(): 17 | """Load existing hashes from JSON file into in-memory cache.""" 18 | if os.path.exists(hashes_file): 19 | with open(hashes_file, 'r') as file: 20 | data = json.load(file) 21 | for entry in data: 22 | cache.add(entry['hash']) # Add hash to in-memory cache 23 | 24 | def save_hash_to_file(hash_value: str, timestamp: str, miner_id: str = None): 25 | """Save the new hash to the JSON file and in-memory cache.""" 26 | cache.add(hash_value) # Add to cache for fast lookups 27 | if os.path.exists(hashes_file): 28 | with open(hashes_file, 'r+') as file: 29 | data = json.load(file) 30 | data.append({'hash': hash_value, 'miner_id': miner_id, 'timestamp': timestamp}) 31 | file.seek(0) 32 | json.dump(data, file) 33 | else: 34 | # If the file doesn't exist, create it with the initial hash entry 35 | with open(hashes_file, 'w') as file: 36 | json.dump([{'hash': hash_value, 'miner_id': miner_id, 'timestamp': timestamp}], file) 37 | 38 | 39 | def check_duplicate_music(hash_value: str) -> bool: 40 | """Check if the given hash already exists in the in-memory cache.""" 41 | return hash_value in cache 42 | 43 | def process_miner_music(miner_id: str, audio_data: bytes): 44 | """Process music sent by a miner and check for duplicates.""" 45 | audio_hash = calculate_audio_hash(audio_data) # Calculate the audio hash 46 | 47 | if check_duplicate_music(audio_hash): # Check if it's a duplicate 48 | bt.logging.info(f"Duplicate music detected from miner: {miner_id}") 49 | return # Do nothing if it's a duplicate 50 | else: 51 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 52 | save_hash_to_file(audio_hash, timestamp, miner_id) # Save the hash to the file and cache; arguments follow the (hash, timestamp, miner_id) signature above 53 | bt.logging.info(f"Music processed and saved successfully for miner: {miner_id}") 54 | return audio_hash 55 | -------------------------------------------------------------------------------- /min_compute.yml: -------------------------------------------------------------------------------- 1 | # Use this document to specify the minimum compute requirements. 2 | # This document will be used to generate a list of recommended hardware for your subnet. 3 | 4 | # This is intended to give a rough estimate of the minimum requirements 5 | # so that the user can make an informed decision about whether or not 6 | # they want to run a miner or validator on their machine.
7 | 8 | # NOTE: Specification for miners may be different from validators 9 | 10 | version: '1.0' # update this version key as needed, ideally should match your release version 11 | 12 | compute_spec: 13 | 14 | miner: 15 | 16 | cpu: 17 | min_cores: 8 # Minimum number of CPU cores 18 | min_speed: 3.0GHz # Minimum speed per core 19 | architecture: x86_64 # Architecture type (e.g., x86_64, arm64) 20 | 21 | gpu: 22 | required: true # Does the application require a GPU? 23 | min_vram: 24GB # Minimum GPU VRAM 24 | cuda_cores: 1024 # Minimum number of CUDA cores (if applicable) 25 | min_compute_capability: 6.0 # Minimum CUDA compute capability 26 | recommended_gpu: "NVIDIA A6000" # provide a recommended GPU to purchase/rent 27 | 28 | memory: 29 | min_ram: 32GB # Minimum RAM 30 | min_swap: 4GB # Minimum swap space 31 | ram_type: "DDR4" # RAM type (e.g., DDR4, DDR3, etc.) 32 | 33 | storage: 34 | min_space: 100GB # Minimum free storage space 35 | type: SSD # Preferred storage type (e.g., SSD, HDD) 36 | iops: 1000 # Minimum I/O operations per second (if applicable) 37 | 38 | os: 39 | name: Ubuntu # Name of the preferred operating system(s) 40 | version: "20.04" # Version of the preferred operating system(s) 41 | 42 | validator: 43 | 44 | cpu: 45 | min_cores: 8 # Minimum number of CPU cores 46 | min_speed: 3.0GHz # Minimum speed per core 47 | architecture: x86_64 # Architecture type (e.g., x86_64, arm64) 48 | 49 | gpu: 50 | required: true # Does the application require a GPU? 51 | min_vram: 24GB # Minimum GPU VRAM 52 | cuda_cores: 1024 # Minimum number of CUDA cores (if applicable) 53 | min_compute_capability: 6.0 # Minimum CUDA compute capability 54 | recommended_gpu: "NVIDIA A6000" # provide a recommended GPU to purchase/rent 55 | 56 | memory: 57 | min_ram: 32GB # Minimum RAM 58 | min_swap: 4GB # Minimum swap space 59 | ram_type: "DDR4" # RAM type (e.g., DDR4, DDR3, etc.) 60 | 61 | storage: 62 | min_space: 100GB # Minimum free storage space 63 | type: SSD # Preferred storage type (e.g., SSD, HDD) 64 | iops: 1000 # Minimum I/O operations per second (if applicable) 65 | 66 | os: 67 | name: Ubuntu # Name of the preferred operating system(s) 68 | version: ">=20.04" # Version of the preferred operating system(s) 69 | 70 | network_spec: 71 | bandwidth: 72 | download: ">=200Mbps" # Minimum download bandwidth 73 | upload: ">=100Mbps" # Minimum upload bandwidth 74 | -------------------------------------------------------------------------------- /neurons/__init__.py: -------------------------------------------------------------------------------- 1 | from . import validator 2 | from . 
import miner -------------------------------------------------------------------------------- /neurons/miner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import lib 4 | import time 5 | import torch 6 | import typing 7 | import argparse 8 | import traceback 9 | import torchaudio 10 | import bittensor as bt 11 | import ttm.protocol as protocol 12 | # from ttm.protocol import MusicGeneration 13 | from scipy.io.wavfile import write as write_wav 14 | from transformers import AutoProcessor, MusicgenForConditionalGeneration 15 | 16 | # Set the project root path 17 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 18 | audio_subnet_path = os.path.abspath(project_root) 19 | 20 | # Add the project root and 'AudioSubnet' directories to sys.path 21 | sys.path.insert(0, project_root) 22 | sys.path.insert(0, audio_subnet_path) 23 | 24 | class MusicGenerator: 25 | def __init__(self, model_path="facebook/musicgen-medium"): 26 | """Initializes the MusicGenerator with a specified model path.""" 27 | self.model_name = model_path 28 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | 30 | # Load the processor and model 31 | self.processor = AutoProcessor.from_pretrained(self.model_name) 32 | self.model = MusicgenForConditionalGeneration.from_pretrained(self.model_name).to(self.device) 33 | 34 | def generate_music(self, prompt, token): 35 | """Generates music based on a given prompt and token count.""" 36 | try: 37 | inputs = self.processor(text=[prompt], padding=True, return_tensors="pt").to(self.device) 38 | audio_values = self.model.generate(**inputs, max_new_tokens=token) 39 | return audio_values[0, 0].cpu().numpy() 40 | except Exception as e: 41 | print(f"Error occurred with {self.model_name}: {e}") 42 | return None 43 | 44 | # Configuration setup 45 | def get_config(): 46 | parser = argparse.ArgumentParser() 47 | 48 | # Add model selection for Text-to-Music 49 | parser.add_argument("--music_model", default='facebook/musicgen-medium', help="The model to be used for Music Generation.") 50 | parser.add_argument("--music_path", default=None, help="Path to a custom finetuned model for Music Generation.") 51 | 52 | # Add Bittensor specific arguments 53 | parser.add_argument("--netuid", type=int, default=50, help="The chain subnet uid.") 54 | bt.subtensor.add_args(parser) 55 | bt.logging.add_args(parser) 56 | bt.wallet.add_args(parser) 57 | bt.axon.add_args(parser) 58 | 59 | config = bt.config(parser) 60 | 61 | # Set up logging paths 62 | config.full_path = os.path.expanduser(f"{config.logging.logging_dir}/{config.wallet.name}/{config.wallet.hotkey}/netuid{config.netuid}/miner") 63 | 64 | # Ensure the logging directory exists 65 | if not os.path.exists(config.full_path): 66 | os.makedirs(config.full_path, exist_ok=True) 67 | 68 | return config 69 | 70 | # Main function 71 | def main(config): 72 | bt.logging(config=config, logging_dir=config.full_path) 73 | bt.logging.info(f"Running TTM miner for subnet: {config.netuid} on network: {config.subtensor.chain_endpoint}") 74 | 75 | # Text-to-Music Model Setup 76 | try: 77 | if config.music_path: 78 | bt.logging.info(f"Using custom model for Text-To-Music from: {config.music_path}") 79 | ttm_models = MusicGenerator(model_path=config.music_path) 80 | elif config.music_model in ["facebook/musicgen-medium", "facebook/musicgen-large"]: 81 | bt.logging.info(f"Using Text-To-Music model: {config.music_model}") 82 | ttm_models = 
MusicGenerator(model_path=config.music_model) 83 | else: 84 | bt.logging.error(f"Invalid music model: {config.music_model}") 85 | exit(1) 86 | except Exception as e: 87 | bt.logging.error(f"Error initializing Text-To-Music model: {e}") 88 | exit(1) 89 | 90 | # Bittensor object setup 91 | wallet = bt.wallet(config=config) 92 | bt.logging.info(f"Wallet: {wallet}") 93 | 94 | subtensor = bt.subtensor(config=config) 95 | bt.logging.info(f"Subtensor: {subtensor}") 96 | 97 | metagraph = subtensor.metagraph(config.netuid) 98 | bt.logging.info(f"Metagraph: {metagraph}") 99 | 100 | if wallet.hotkey.ss58_address not in metagraph.hotkeys: 101 | bt.logging.error("Miner not registered. Run btcli register and try again.") 102 | exit() 103 | 104 | # Check the miner's subnet UID 105 | my_subnet_uid = metagraph.hotkeys.index(wallet.hotkey.ss58_address) 106 | 107 | ######################## Text to Music Processing ######################## 108 | 109 | def music_blacklist_fn(synapse: protocol.MusicGeneration) -> typing.Tuple[bool, str]: 110 | if synapse.dendrite.hotkey not in metagraph.hotkeys: 111 | bt.logging.trace(f"Blacklisting unrecognized hotkey {synapse.dendrite.hotkey}") 112 | return True, "Unrecognized hotkey" 113 | elif synapse.dendrite.hotkey in metagraph.hotkeys and metagraph.S[metagraph.hotkeys.index(synapse.dendrite.hotkey)] < lib.MIN_STAKE: 114 | # Ignore requests from entities with low stake. 115 | bt.logging.trace( 116 | f"Blacklisting hotkey {synapse.dendrite.hotkey} with low stake" 117 | ) 118 | return True, "Low stake" 119 | else: 120 | return False, "Accepted" 121 | 122 | # The priority function determines the request handling order. 123 | def music_priority_fn(synapse: protocol.MusicGeneration) -> float: 124 | caller_uid = metagraph.hotkeys.index(synapse.dendrite.hotkey) 125 | priority = float(metagraph.S[caller_uid]) 126 | bt.logging.trace(f"Prioritizing {synapse.dendrite.hotkey} with stake: {priority}") 127 | return priority 128 | 129 | def convert_music_to_tensor(audio_file): 130 | """Convert the audio file to a tensor.""" 131 | try: 132 | _, file_extension = os.path.splitext(audio_file) 133 | if file_extension.lower() in ['.wav', '.mp3']: 134 | audio, sample_rate = torchaudio.load(audio_file) 135 | return audio[0].tolist() # Convert to tensor/list 136 | else: 137 | bt.logging.error(f"Unsupported file format: {file_extension}") 138 | return None 139 | except Exception as e: 140 | bt.logging.error(f"Error converting file: {e}") 141 | 142 | def ProcessMusic(synapse: protocol.MusicGeneration) -> protocol.MusicGeneration: 143 | bt.logging.info(f"Generating music with model: {config.music_path if config.music_path else config.music_model}") 144 | print(f"synapse.text_input: {synapse.text_input}") 145 | print(f"synapse.duration: {synapse.duration}") 146 | music = ttm_models.generate_music(synapse.text_input, synapse.duration) 147 | 148 | if music is None: 149 | bt.logging.error("No music generated!") 150 | return None 151 | try: 152 | sampling_rate = 32000 153 | write_wav("random_sample.wav", rate=sampling_rate, data=music) 154 | bt.logging.success("Music generated and saved to random_sample.wav") 155 | music_tensor = convert_music_to_tensor("random_sample.wav") 156 | synapse.music_output = music_tensor 157 | return synapse 158 | except Exception as e: 159 | bt.logging.error(f"Error processing music output: {e}") 160 | return None 161 | 162 | ######################## Attach Axon and Serve ######################## 163 | 164 | axon = bt.axon(wallet=wallet, config=config) 165 | 
bt.logging.info(f"Axon {axon}") 166 | 167 | # Attach forward function for TTM processing 168 | axon.attach( 169 | forward_fn=ProcessMusic, 170 | blacklist_fn=music_blacklist_fn, 171 | priority_fn=music_priority_fn, 172 | ) 173 | 174 | # Serve the axon on the network 175 | bt.logging.info(f"Serving axon on network: {config.subtensor.chain_endpoint} with netuid: {config.netuid}") 176 | axon.serve(netuid=config.netuid, subtensor=subtensor) 177 | 178 | # Start the miner's axon 179 | bt.logging.info(f"Starting axon server on port: {config.axon.port}") 180 | axon.start() 181 | 182 | # Keep the miner running 183 | bt.logging.info("Starting main loop") 184 | step = 0 185 | while True: 186 | try: 187 | # Periodically update knowledge of the network graph 188 | if step % 500 == 0: 189 | metagraph = subtensor.metagraph(config.netuid) 190 | log = ( 191 | f"Step:{step} | " 192 | f"Block:{metagraph.block.item()} | " 193 | f"Stake:{metagraph.S[my_subnet_uid]:.6f} | " 194 | f"Rank:{metagraph.R[my_subnet_uid]:.6f} | " 195 | f"Trust:{metagraph.T[my_subnet_uid]} | " 196 | f"Consensus:{metagraph.C[my_subnet_uid]:.6f} | " 197 | f"Incentive:{metagraph.I[my_subnet_uid]:.6f} | " 198 | f"Emission:{metagraph.E[my_subnet_uid]}" 199 | ) 200 | bt.logging.info(log) 201 | step += 1 202 | time.sleep(1) 203 | 204 | # Stop the miner safely 205 | except KeyboardInterrupt: 206 | axon.stop() 207 | break 208 | 209 | # Log any unexpected errors 210 | except Exception as e: 211 | bt.logging.error(f"Unexpected error: {e}\n{traceback.format_exc()}") 212 | continue 213 | 214 | # Entry point 215 | if __name__ == "__main__": 216 | config = get_config() 217 | main(config) -------------------------------------------------------------------------------- /neurons/validator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import asyncio 4 | import datetime as dt 5 | import wandb 6 | import bittensor as bt 7 | import uvicorn 8 | from pyngrok import ngrok # Import ngrok from pyngrok 9 | 10 | # Set the project root path 11 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 12 | # Set the 'AudioSubnet' directory path 13 | audio_subnet_path = os.path.abspath(project_root) 14 | 15 | # Add the project root and 'AudioSubnet' directories to sys.path 16 | sys.path.insert(0, project_root) 17 | sys.path.insert(0, audio_subnet_path) 18 | 19 | from ttm.ttm import MusicGenerationService 20 | from ttm.aimodel import AIModelService 21 | 22 | 23 | class AIModelController(): 24 | def __init__(self): 25 | self.aimodel = AIModelService() 26 | self.music_generation_service = MusicGenerationService() 27 | self.current_service = self.music_generation_service 28 | self.last_run_start_time = dt.datetime.now() 29 | self.wandb_run = None # ensure the attribute exists before check_and_update_wandb_run() first runs 30 | async def run_fastapi_with_ngrok(self, app): 31 | # Setup ngrok tunnel 32 | ngrok_tunnel = ngrok.connect(38287, bind_tls=True) 33 | print('Public URL:', ngrok_tunnel.public_url) 34 | # Create and start the uvicorn server as a background task 35 | config = uvicorn.Config(app=app, host="0.0.0.0", port=38287) # Ensure port matches ngrok's 36 | server = uvicorn.Server(config) 37 | # No need to await here, as we want this to run in the background 38 | task = asyncio.create_task(server.serve()) 39 | return ngrok_tunnel, task # Returning task if you need to cancel it later 40 | 41 | 42 | async def run_services(self): 43 | while True: 44 | self.check_and_update_wandb_run() 45 | await self.music_generation_service.run_async() 46 | 47 | def
check_and_update_wandb_run(self): 48 | # Calculate the time difference between now and the last run start time 49 | current_time = dt.datetime.now() 50 | time_diff = current_time - self.last_run_start_time 51 | # Check if 4 hours have passed since the last run start time 52 | if time_diff.total_seconds() >= 4 * 3600: # 4 hours * 3600 seconds/hour 53 | self.last_run_start_time = current_time # Update the last run start time to now 54 | if self.wandb_run: 55 | wandb.finish() # End the current run 56 | self.new_wandb_run() # Start a new run 57 | 58 | def new_wandb_run(self): 59 | now = dt.datetime.now() 60 | run_id = now.strftime("%Y-%m-%d_%H-%M-%S") 61 | name = f"Validator-{self.aimodel.uid}-{run_id}" 62 | commit = self.aimodel.get_git_commit_hash() 63 | self.wandb_run = wandb.init( 64 | name=name, 65 | project="AudioSubnet_Valid", 66 | entity="subnet16team", 67 | config={ 68 | "uid": self.aimodel.uid, 69 | "hotkey": self.aimodel.wallet.hotkey.ss58_address, 70 | "run_name": run_id, 71 | "type": "Validator", 72 | "tao (stake)": self.aimodel.metagraph.neurons[self.aimodel.uid].stake.tao, 73 | "commit": commit, 74 | }, 75 | tags=self.aimodel.sys_info, 76 | allow_val_change=True, 77 | anonymous="allow", 78 | ) 79 | bt.logging.debug(f"Started a new wandb run: {name}") 80 | 81 | async def setup_and_run(controller): 82 | tasks, ngrok_tunnel = [], None # ngrok_tunnel stays None unless the optional app is started 83 | secret_key = os.getenv("AUTH_SECRET_KEY") 84 | if os.path.exists(os.path.join(project_root, 'app')) and secret_key: 85 | app = create_app(secret_key) # create_app is assumed to be provided by the optional 'app' package 86 | # Start FastAPI with ngrok without blocking 87 | ngrok_tunnel, server_task = await controller.run_fastapi_with_ngrok(app) 88 | tasks.append(server_task) # Keep track of the server task if you need to cancel it later 89 | 90 | # Start service-related tasks 91 | service_task = asyncio.create_task(controller.run_services()) 92 | tasks.append(service_task) 93 | 94 | # Wait for all tasks to complete 95 | await asyncio.gather(*tasks) 96 | 97 | # Cleanup, in case you need to close ngrok or other resources 98 | if ngrok_tunnel: ngrok_tunnel.close() 99 | 100 | async def main(): 101 | controller = AIModelController() 102 | await setup_and_run(controller) 103 | 104 | if __name__ == "__main__": 105 | asyncio.run(main()) 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bittensor==7.4.0 2 | pydantic 3 | psutil 4 | GPUtil 5 | inflect 6 | huggingface_hub 7 | librosa 8 | torchaudio 9 | scipy 10 | transformers 11 | datasets==3.0.1 12 | wandb==0.18.5 13 | pyngrok==7.2.0 14 | tabulate==0.9.0 15 | resampy==0.4.3 16 | aiofiles==23.2.1 17 | altair==5.4.1 18 | annotated-types==0.7.0 19 | antlr4-python3-runtime==4.9.3 20 | anyio==4.6.0 21 | attrs==24.2.0 22 | audioread==3.0.1 23 | av==11.0.0 24 | blis==0.7.11 25 | catalogue==2.0.10 26 | certifi==2024.7.4 27 | cffi==1.17.1 28 | charset-normalizer==3.3.2 29 | click==8.1.7 30 | cloudpathlib==0.19.0 31 | cloudpickle==3.0.0 32 | colorlog==6.8.2 33 | confection==0.1.5 34 | contourpy==1.2.1 35 | cycler==0.12.1 36 | cymem==2.0.8 37 | decorator==5.1.1 38 | demucs==4.0.1 39 | docopt==0.6.2 40 | dora_search==0.1.12 41 | einops==0.8.0 42 | encodec==0.1.1 43 | exceptiongroup==1.2.2 44 | fastapi==0.110.1 45 | ffmpy==0.4.0 46 | filelock==3.16.1 47 | flashy==0.0.2 48 | fonttools==4.54.1 49 | fsspec 50 | gradio==4.43.0 51 | gradio_client==1.3.0 52 | h11==0.14.0 53 | httpcore==1.0.6 54 | httpx==0.27.2 55 | huggingface-hub==0.25.1 56 | hydra-colorlog==1.2.0 57 | hydra-core==1.3.2 58 |
idna==3.10 59 | importlib_resources==6.4.5 60 | Jinja2==3.1.4 61 | joblib==1.4.2 62 | jsonschema==4.23.0 63 | jsonschema-specifications==2023.12.1 64 | julius==0.2.7 65 | kiwisolver==1.4.7 66 | lameenc==1.7.0 67 | langcodes==3.4.1 68 | language_data==1.2.0 69 | lazy_loader==0.4 70 | librosa==0.10.0 71 | lightning-utilities==0.11.7 72 | llvmlite==0.42.0 73 | marisa-trie==1.2.0 74 | markdown-it-py==3.0.0 75 | MarkupSafe==2.1.5 76 | matplotlib==3.8.3 77 | mdurl==0.1.2 78 | mpmath==1.3.0 79 | msgpack==1.1.0 80 | murmurhash==1.0.10 81 | narwhals==1.9.0 82 | networkx==3.2.1 83 | num2words==0.5.13 84 | numba==0.59.1 85 | numpy==1.26.4 86 | nvidia-cublas-cu12==12.1.3.1 87 | nvidia-cuda-cupti-cu12==12.1.105 88 | nvidia-cuda-nvrtc-cu12==12.1.105 89 | nvidia-cuda-runtime-cu12==12.1.105 90 | nvidia-cudnn-cu12==8.9.2.26 91 | nvidia-cufft-cu12==11.0.2.54 92 | nvidia-curand-cu12==10.3.2.106 93 | nvidia-cusolver-cu12==11.4.5.107 94 | nvidia-cusparse-cu12==12.1.0.106 95 | nvidia-nccl-cu12==2.18.1 96 | nvidia-nvjitlink-cu12==12.6.77 97 | nvidia-nvtx-cu12==12.1.105 98 | omegaconf==2.3.0 99 | openunmix==1.3.0 100 | orjson==3.10.7 101 | packaging==24.1 102 | pandas==2.2.1 103 | pesq==0.0.4 104 | pillow==10.4.0 105 | pip==23.0.1 106 | platformdirs==4.3.6 107 | pooch==1.8.2 108 | preshed==3.0.9 109 | protobuf==5.28.2 110 | pycparser==2.22 111 | pydantic==2.9.2 112 | pydantic_core==2.23.4 113 | pydub==0.25.1 114 | Pygments==2.18.0 115 | pyparsing==3.1.4 116 | pystoi==0.4.1 117 | python-dateutil==2.9.0.post0 118 | python-multipart==0.0.12 119 | pytz==2024.2 120 | PyYAML==6.0.2 121 | referencing==0.35.1 122 | regex==2024.9.11 123 | requests==2.32.3 124 | retrying==1.3.4 125 | rich==13.9.2 126 | rpds-py==0.20.0 127 | ruff==0.6.9 128 | safetensors==0.4.5 129 | scikit-learn==1.4.2 130 | scipy==1.13.0 131 | semantic-version==2.10.0 132 | sentencepiece==0.2.0 133 | setuptools==70.0.0 134 | shellingham==1.5.4 135 | six==1.16.0 136 | smart-open==7.0.5 137 | sniffio==1.3.1 138 | soundfile==0.12.1 139 | soxr==0.5.0.post1 140 | spacy==3.7.5 141 | spacy-legacy==3.0.12 142 | spacy-loggers==1.0.5 143 | srsly==2.4.8 144 | starlette==0.37.2 145 | submitit==1.5.2 146 | sympy==1.13.3 147 | thinc==8.2.5 148 | threadpoolctl==3.5.0 149 | tokenizers==0.15.2 150 | tomlkit==0.12.0 151 | torch==2.1.0 152 | torchaudio==2.1.0 153 | torchdata==0.7.0 154 | torchmetrics==1.3.2 155 | torchtext==0.16.0 156 | torchvision==0.16.0 157 | tqdm==4.66.5 158 | transformers==4.38.2 159 | treetable==0.2.5 160 | triton 161 | typer==0.12.5 162 | typing_extensions==4.12.2 163 | tzdata==2024.2 164 | urllib3==2.2.3 165 | uvicorn==0.31.0 166 | wasabi==1.1.3 167 | weasel==0.4.1 168 | websockets==11.0.3 169 | wrapt==1.16.0 170 | xformers==0.0.22.post7 171 | zipp==3.20.2 -------------------------------------------------------------------------------- /scripts/check_compatibility.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ]; then 4 | echo "Please provide a Python version as an argument." 
5 | exit 1 6 | fi 7 | 8 | python_version="$1" 9 | all_passed=true 10 | 11 | GREEN='\033[0;32m' 12 | YELLOW='\033[0;33m' 13 | RED='\033[0;31m' 14 | NC='\033[0m' # No Color 15 | 16 | check_compatibility() { 17 | all_supported=0 18 | 19 | while read -r requirement; do 20 | # Skip lines starting with git+ 21 | if [[ "$requirement" == git+* ]]; then 22 | continue 23 | fi 24 | 25 | package_name=$(echo "$requirement" | awk -F'[!=<>]' '{print $1}' | awk -F'[' '{print $1}') # Strip off brackets 26 | echo -n "Checking $package_name... " 27 | 28 | url="https://pypi.org/pypi/$package_name/json" 29 | response=$(curl -s $url) 30 | status_code=$(curl -s -o /dev/null -w "%{http_code}" $url) 31 | 32 | if [ "$status_code" != "200" ]; then 33 | echo -e "${RED}Information not available for $package_name. Failure.${NC}" 34 | all_supported=1 35 | continue 36 | fi 37 | 38 | classifiers=$(echo "$response" | jq -r '.info.classifiers[]') 39 | requires_python=$(echo "$response" | jq -r '.info.requires_python') 40 | 41 | base_version="Programming Language :: Python :: ${python_version%%.*}" 42 | specific_version="Programming Language :: Python :: $python_version" 43 | 44 | if echo "$classifiers" | grep -q "$specific_version" || echo "$classifiers" | grep -q "$base_version"; then 45 | echo -e "${GREEN}Supported${NC}" 46 | elif [ "$requires_python" != "null" ]; then 47 | if echo "$requires_python" | grep -Eq "==$python_version|>=$python_version|<=$python_version"; then 48 | echo -e "${GREEN}Supported${NC}" 49 | else 50 | echo -e "${RED}Not compatible with Python $python_version due to constraint $requires_python.${NC}" 51 | all_supported=1 52 | fi 53 | else 54 | echo -e "${YELLOW}Warning: Specific version not listed, assuming compatibility${NC}" 55 | fi 56 | done < requirements.txt 57 | 58 | return $all_supported 59 | } 60 | 61 | echo "Checking compatibility for Python $python_version..." 62 | check_compatibility 63 | if [ $? -eq 0 ]; then 64 | echo -e "${GREEN}All requirements are compatible with Python $python_version.${NC}" 65 | else 66 | echo -e "${RED}All requirements are NOT compatible with Python $python_version.${NC}" 67 | all_passed=false 68 | fi 69 | 70 | echo "" 71 | if $all_passed; then 72 | echo -e "${GREEN}All tests passed.${NC}" 73 | else 74 | echo -e "${RED}All tests did not pass.${NC}" 75 | exit 1 76 | fi 77 | -------------------------------------------------------------------------------- /scripts/check_requirements_changes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if requirements files have changed in the last commit 4 | if git diff --name-only HEAD~1 | grep -E 'requirements.txt|requirements.txt'; then 5 | echo "Requirements files have changed. Running compatibility checks..." 6 | echo 'export REQUIREMENTS_CHANGED="true"' >> $BASH_ENV 7 | else 8 | echo "Requirements files have not changed. Skipping compatibility checks..." 9 | echo 'export REQUIREMENTS_CHANGED="false"' >> $BASH_ENV 10 | fi 11 | -------------------------------------------------------------------------------- /scripts/start_valid.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script runs a validator process and automatically updates it when a new version is released. 
3 | Command-line arguments will be forwarded to validator (`neurons/validator.py`), so you can pass 4 | them like this: 5 | python3 scripts/start_validator.py --wallet.name=my-wallet 6 | Auto-updates are enabled by default and will make sure that the latest version is always running 7 | by pulling the latest version from git and upgrading python packages. This is done periodically. 8 | Local changes may prevent the update, but they will be preserved. 9 | 10 | The script will use the same virtual environment as the one used to run it. If you want to run 11 | validator within virtual environment, run this auto-update script from the virtual environment. 12 | 13 | Pm2 is required for this script. This script will start a pm2 process using the name provided by 14 | the --pm2_name argument. 15 | """ 16 | import argparse 17 | import logging 18 | import subprocess 19 | import sys 20 | import time 21 | from datetime import timedelta 22 | from shlex import split 23 | from pathlib import Path 24 | from typing import List 25 | 26 | log = logging.getLogger(__name__) 27 | UPDATES_CHECK_TIME = timedelta(minutes=5) 28 | ROOT_DIR = Path(__file__).parent.parent 29 | 30 | 31 | 32 | def get_version() -> str: 33 | """Extract the version as current git commit hash""" 34 | result = subprocess.run( 35 | split("git rev-parse HEAD"), 36 | check=True, 37 | capture_output=True, 38 | cwd=ROOT_DIR, 39 | ) 40 | commit = result.stdout.decode().strip() 41 | assert len(commit) == 40, f"Invalid commit hash: {commit}" 42 | return commit[:8] 43 | 44 | 45 | def start_validator_process(pm2_name: str, args: List[str]) -> subprocess.Popen: 46 | """ 47 | Spawn a new python process running neurons.validator. 48 | `sys.executable` ensures the same python interpreter is used as the one 49 | used to run this auto-updater. 50 | """ 51 | assert sys.executable, "Failed to get python executable" 52 | 53 | log.info("Starting validator process with pm2, name: %s", pm2_name) 54 | process = subprocess.Popen( 55 | ( 56 | "pm2", 57 | "start", 58 | sys.executable, 59 | "--name", 60 | pm2_name, 61 | "--", 62 | "-m", 63 | "neurons.validator", 64 | *args, # Added to include additional arguments 65 | ), 66 | cwd=ROOT_DIR, 67 | ) 68 | process.pm2_name = pm2_name 69 | 70 | return process 71 | 72 | 73 | 74 | def stop_validator_process(process: subprocess.Popen) -> None: 75 | """Stop the validator process""" 76 | subprocess.run( 77 | ("pm2", "delete", process.pm2_name), cwd=ROOT_DIR, check=True 78 | ) 79 | 80 | 81 | def pull_latest_version() -> None: 82 | """ 83 | Pull the latest version from git. 84 | This uses `git pull --rebase`, so if any changes were made to the local repository, 85 | this will try to apply them on top of origin's changes. This is intentional, as we 86 | don't want to overwrite any local changes. However, if there are any conflicts, 87 | this will abort the rebase and return to the original state. 88 | The conflicts are expected to happen rarely since validator is expected 89 | to be used as-is. 90 | """ 91 | try: 92 | subprocess.run( 93 | split("git pull --rebase --autostash"), check=True, cwd=ROOT_DIR 94 | ) 95 | except subprocess.CalledProcessError as exc: 96 | log.error("Failed to pull, reverting: %s", exc) 97 | subprocess.run(split("git rebase --abort"), check=True, cwd=ROOT_DIR) 98 | 99 | 100 | def upgrade_packages() -> None: 101 | """ 102 | Upgrade python packages by running `pip install --upgrade -r requirements.txt`. 103 | Notice: this won't work if some package in `requirements.txt` is downgraded. 
104 | Ignored as this is unlikely to happen. 105 | """ 106 | 107 | log.info("Upgrading packages") 108 | try: 109 | subprocess.run( 110 | split(f"{sys.executable} -m pip install -e ."), 111 | check=True, 112 | cwd=ROOT_DIR, 113 | ) 114 | except subprocess.CalledProcessError as exc: 115 | log.error("Failed to upgrade packages, proceeding anyway. %s", exc) 116 | 117 | 118 | def main(pm2_name: str, args: List[str]) -> None: 119 | """ 120 | Run the validator process and automatically update it when a new version is released. 121 | This will check for updates every `UPDATES_CHECK_TIME` and update the validator 122 | if a new version is available. Update is performed as simple `git pull --rebase`. 123 | """ 124 | 125 | validator = start_validator_process(pm2_name, args) 126 | current_version = latest_version = get_version() 127 | log.info("Current version: %s", current_version) 128 | 129 | try: 130 | while True: 131 | pull_latest_version() 132 | latest_version = get_version() 133 | log.info("Latest version: %s", latest_version) 134 | 135 | if latest_version != current_version: 136 | log.info( 137 | "Upgraded to latest version: %s -> %s", 138 | current_version, 139 | latest_version, 140 | ) 141 | upgrade_packages() 142 | 143 | stop_validator_process(validator) 144 | validator = start_validator_process(pm2_name, args) 145 | current_version = latest_version 146 | 147 | time.sleep(UPDATES_CHECK_TIME.total_seconds()) 148 | 149 | finally: 150 | stop_validator_process(validator) 151 | 152 | 153 | if __name__ == "__main__": 154 | logging.basicConfig( 155 | level=logging.INFO, 156 | format="%(asctime)s %(levelname)s %(message)s", 157 | handlers=[logging.StreamHandler(sys.stdout)], 158 | ) 159 | 160 | parser = argparse.ArgumentParser( 161 | description="Automatically update and restart the validator process when a new version is released.", 162 | epilog="Example usage: python start_validator.py --pm2_name 'validator' --wallet_name 'wallet1' --wallet_hotkey 'key123'", 163 | ) 164 | 165 | parser.add_argument( 166 | "--pm2_name", default="validator", help="Name of the PM2 process." 167 | ) 168 | 169 | flags, extra_args = parser.parse_known_args() 170 | 171 | main(flags.pm2_name, extra_args) 172 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2023 Yuma Rao 3 | # Copyright © 2023 Omega Labs, Inc. 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 6 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 7 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 8 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 11 | # the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 14 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL 15 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 16 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 17 | # DEALINGS IN THE SOFTWARE. 18 | 19 | import re 20 | import os 21 | import codecs 22 | from os import path 23 | from io import open 24 | from setuptools import setup, find_packages 25 | 26 | 27 | def read_requirements(path): 28 | with open(path, "r") as f: 29 | requirements = f.read().splitlines() 30 | return requirements 31 | 32 | 33 | requirements = read_requirements("requirements.txt") 34 | here = path.abspath(path.dirname(__file__)) 35 | 36 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 37 | long_description = f.read() 38 | 39 | # loading version from setup.py 40 | with codecs.open( 41 | os.path.join(here, "lib/__init__.py"), encoding="utf-8" 42 | ) as init_file: 43 | version_match = re.search( 44 | r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M 45 | ) 46 | version_string = version_match.group(1) 47 | 48 | setup( 49 | name="ttm_bittensor_subnet", 50 | version=version_string, 51 | description="ttm_bittensor_subnet", 52 | long_description=long_description, 53 | long_description_content_type="text/markdown", 54 | url="https://github.com/UncleTensor/BittAudio.git", 55 | author="ttm", 56 | packages=find_packages(), 57 | include_package_data=True, 58 | author_email="", 59 | license="MIT", 60 | python_requires=">=3.10", 61 | install_requires=requirements, 62 | classifiers=[ 63 | "Development Status :: 3 - Alpha", 64 | "Intended Audience :: Developers", 65 | "Topic :: Software Development :: Build Tools", 66 | "License :: OSI Approved :: MIT License", 67 | "Programming Language :: Python :: 3 :: Only", 68 | "Programming Language :: Python :: 3.8", 69 | "Programming Language :: Python :: 3.9", 70 | "Programming Language :: Python :: 3.10", 71 | "Topic :: Scientific/Engineering", 72 | "Topic :: Scientific/Engineering :: Mathematics", 73 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 74 | "Topic :: Software Development", 75 | "Topic :: Software Development :: Libraries", 76 | "Topic :: Software Development :: Libraries :: Python Modules", 77 | ], 78 | ) -------------------------------------------------------------------------------- /ttm/aimodel.py: -------------------------------------------------------------------------------- 1 | import bittensor as bt 2 | import pandas as pd 3 | import subprocess 4 | import platform 5 | import argparse 6 | import inflect 7 | import psutil 8 | import GPUtil 9 | import sys 10 | import os 11 | import re 12 | from lib import __spec_version__ as spec_version 13 | 14 | 15 | class AIModelService: 16 | _scores = None 17 | _base_initialized = False # Class-level flag for one-time initialization 18 | version: int = spec_version # Adjust version as necessary 19 | 20 | def __init__(self): 21 | self.config = self.get_config() 22 | self.sys_info = self.get_system_info() 23 | self.setup_paths() 24 | self.setup_logging() 25 | self.wallet = bt.wallet(config=self.config) 26 | self.subtensor = bt.subtensor(config=self.config) 27 | self.dendrite = bt.dendrite(wallet=self.wallet) 28 | self.metagraph = self.subtensor.metagraph(self.config.netuid) 29 | self.p = inflect.engine() 30 | 31 | if not AIModelService._base_initialized: 32 | bt.logging.info(f"Wallet: {self.wallet}") 33 | bt.logging.info(f"Subtensor: {self.subtensor}") 34 | bt.logging.info(f"Dendrite: {self.dendrite}") 35 | 
bt.logging.info(f"Metagraph: {self.metagraph}") 36 | AIModelService._base_initialized = True 37 | 38 | if AIModelService._scores is None: 39 | AIModelService._scores = self.metagraph.E.copy() 40 | self.scores = AIModelService._scores 41 | self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address) 42 | 43 | def get_config(self): 44 | parser = argparse.ArgumentParser() 45 | 46 | parser.add_argument("--alpha", default=0.1, type=float, help="The weight moving average scoring.") 47 | parser.add_argument("--custom", default="my_custom_value", help="Adds a custom value to the parser.") 48 | parser.add_argument("--subtensor.network", type=str, help="The logging directory.") 49 | parser.add_argument("--netuid", default=50, type=int, help="The chain subnet uid.") 50 | parser.add_argument("--wallet.name", type=str, help="The wallet name.") 51 | parser.add_argument("--wallet.hotkey", type=str, help="The wallet hotkey.") 52 | 53 | # Add Bittensor specific arguments 54 | bt.subtensor.add_args(parser) 55 | bt.logging.add_args(parser) 56 | bt.wallet.add_args(parser) 57 | 58 | # Parse and return the config 59 | config = bt.config(parser) 60 | return config 61 | 62 | def priority_uids(self, metagraph): 63 | hotkeys = metagraph.hotkeys # List of hotkeys 64 | coldkeys = metagraph.coldkeys # List of coldkeys 65 | UIDs = range(len(hotkeys)) # Assuming UID is the index of neurons 66 | stakes = metagraph.S.numpy() # Total stake 67 | emissions = metagraph.E.numpy() # Emission 68 | 69 | # Create a DataFrame from the metagraph data 70 | df = pd.DataFrame({ 71 | "UID": UIDs, 72 | "HOTKEY": hotkeys, 73 | "COLDKEY": coldkeys, 74 | "STAKE": stakes, 75 | "EMISSION": emissions, 76 | "AXON": metagraph.axons, 77 | }) 78 | 79 | # Filter and sort the DataFrame 80 | df = df[df['STAKE'] < 500] 81 | df = df.sort_values(by=["EMISSION"], ascending=False) 82 | uid = df.iloc[0]['UID'] 83 | axon_info = df.iloc[0]['AXON'] 84 | 85 | result = [(uid, axon_info)] 86 | return result 87 | 88 | def get_system_info(self): 89 | system_info = { 90 | "OS Version": platform.platform(), 91 | "CPU Count": os.cpu_count(), 92 | "RAM": f"{psutil.virtual_memory().total / (1024**3):.2f} GB", 93 | } 94 | 95 | gpus = GPUtil.getGPUs() 96 | if gpus: 97 | system_info["GPU"] = gpus[0].name 98 | 99 | # Convert dictionary to list of strings for logging purposes 100 | tags = [f"{key}: {value}" for key, value in system_info.items()] 101 | return tags 102 | 103 | def setup_paths(self): 104 | # Set the project root path 105 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 106 | 107 | # Add project root to sys.path 108 | sys.path.insert(0, project_root) 109 | 110 | def convert_numeric_values(self, input_prompt): 111 | # Regular expression to identify date patterns 112 | date_pattern = r'(\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b)|(\b\d{2,4}[/-]\d{1,2}[/-]\d{1,2}\b)' 113 | 114 | # Regular expression to find ordinal numbers 115 | ordinal_pattern = r'\b\d{1,2}(st|nd|rd|th)\b' 116 | 117 | # Regular expression to find numeric values with and without commas, excluding those part of date patterns and ordinals 118 | numeric_pattern = r'\b(? List: 42 | """ 43 | Processes and returns the music_output into a format ready for audio rendering or further analysis. 
/ttm/ttm.py:
--------------------------------------------------------------------------------
from lib.hashing import load_hashes_to_cache, check_duplicate_music, save_hash_to_file
from ttm.ttm_score import MusicQualityEvaluator
from ttm.protocol import MusicGeneration
from ttm.aimodel import AIModelService
from datasets import load_dataset
from datetime import datetime
from tabulate import tabulate
import bittensor as bt
import soundfile as sf
import numpy as np
import torchaudio
import contextlib
import traceback
import asyncio
import hashlib
import random
import torch
import wandb
import wave
import lib
import sys
import os
import re


# Set the project root path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
audio_subnet_path = os.path.abspath(project_root)  # same path as project_root
sys.path.insert(0, project_root)
sys.path.insert(0, audio_subnet_path)
reference_dir = "audio_files"  # Directory to save the reference audio files


class MusicGenerationService(AIModelService):
    def __init__(self):
        super().__init__()
        self.load_prompts()
        self.total_dendrites_per_query = 10
        self.minimum_dendrites_per_query = 3
        self.current_block = self.subtensor.block
        self.last_updated_block = self.current_block - (self.current_block % 100)
        self.last_reset_weights_block = self.current_block
        self.filtered_axon = []
        self.combinations = []
        self.duration = None
        self.lock = asyncio.Lock()
        self.audio_path = None
        # Load hashes from file to cache at startup
        load_hashes_to_cache()
    def load_prompts(self):
        # Load the validation dataset (swap in any other dataset name if needed)
        dataset = load_dataset("etechgrid/ttm-validation-dataset", split="train")
        random_index = random.randint(0, len(dataset) - 1)
        self.random_sample = dataset[random_index]
        # Check that the prompt exists in the sample
        if 'Prompts' in self.random_sample:
            prompt = self.random_sample['Prompts']
            bt.logging.info(f"Returning the prompt: {prompt}")
        else:
            bt.logging.error("'Prompts' key not found in the sample.")
            return None  # Return None if no prompt found

        # Check if audio data exists and save it
        if 'File_Path' in self.random_sample and isinstance(self.random_sample['File_Path'], dict):
            file_path = self.random_sample['File_Path']
            if 'array' in file_path and 'sampling_rate' in file_path:
                audio_array = file_path['array']
                sample_rate = file_path['sampling_rate']

                # Save the audio to a file
                os.makedirs(reference_dir, exist_ok=True)  # Create output directory if it doesn't exist
                audio_path = os.path.join(reference_dir, "random_sample.wav")

                try:
                    # Save the audio data using soundfile
                    sf.write(audio_path, audio_array, sample_rate)
                    bt.logging.info(f"Audio saved successfully at: {audio_path}")
                    self.audio_path = audio_path

                    # Read the audio file into a numerical array
                    audio_data, sample_rate = sf.read(self.audio_path)

                    # Convert the numerical array to a tensor
                    speech_tensor = torch.Tensor(audio_data)

                    # Normalize the audio data
                    audio_data = speech_tensor / torch.max(torch.abs(speech_tensor))
                    audio_hash = hashlib.sha256(audio_data.numpy().tobytes()).hexdigest()
                    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    # Check if the music hash is a duplicate
                    if check_duplicate_music(audio_hash):
                        bt.logging.info("Duplicate reference audio detected for validator; skipping hash.")
                    else:
                        try:
                            save_hash_to_file(audio_hash, timestamp)
                            bt.logging.info("Music hash processed and saved successfully for validator.")
                        except Exception as e:
                            bt.logging.error(f"Error saving audio hash: {e}")

                except Exception as e:
                    bt.logging.error(f"Error saving audio file: {e}")
            else:
                bt.logging.error("Invalid audio data in 'File_Path'. Expected 'array' and 'sampling_rate'.")
                return None
        else:
            bt.logging.error("'File_Path' not found or invalid format in the sample.")
            return None

        return prompt  # Return the prompt after saving the audio file
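    # Editor's note — an illustrative sketch of the duplicate check above, not part
    # of the original file. The dedup logic reduces to hashing the normalized
    # waveform bytes:
    #
    #   import hashlib, torch
    #   samples = torch.rand(32000)                        # hypothetical one second of audio
    #   samples = samples / torch.max(torch.abs(samples))  # same normalization as load_prompts
    #   digest = hashlib.sha256(samples.numpy().tobytes()).hexdigest()
    #
    # Byte-identical waveforms map to the same digest, which is what
    # check_duplicate_music() compares against the cached hashes.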
    async def run_async(self):
        step = 0
        while True:
            try:
                await self.main_loop_logic(step)
                step += 1
                if step % 10 == 0:
                    self.metagraph.sync(subtensor=self.subtensor)
                    bt.logging.info("🔄 Syncing metagraph with subtensor.")
            except KeyboardInterrupt:
                print("Keyboard interrupt detected. Exiting MusicGenerationService.")
                break
            except Exception as e:
                print(f"An error occurred in MusicGenerationService: {e}")
                traceback.print_exc()

    async def main_loop_logic(self, step):
        g_prompt = None
        try:
            # Load prompt from the dataset using the load_prompts function
            bt.logging.info(f"Using prompt from HuggingFace Dataset for Text-To-Music at Step: {step}")
            g_prompt = self.load_prompts()

            if isinstance(g_prompt, str):
                g_prompt = self.convert_numeric_values(g_prompt)

            # Ensure prompt length does not exceed 256 characters
            while isinstance(g_prompt, str) and len(g_prompt) > 256:
                bt.logging.error('Prompt length exceeds 256 characters. Skipping current prompt.')
                g_prompt = self.load_prompts()  # Reload another prompt
                if isinstance(g_prompt, str):
                    g_prompt = self.convert_numeric_values(g_prompt)

            # Get filtered axons and query the network
            filtered_axons = self.get_filtered_axons_from_combinations()
            responses = self.query_network(filtered_axons, g_prompt)
            try:
                self.process_responses(filtered_axons, responses, g_prompt)
            except Exception as e:
                bt.logging.error(f"Error while processing responses: {e}")

            if self.last_reset_weights_block + 50 < self.current_block:
                bt.logging.info("Resetting weights for validators and nodes without IPs")
                self.last_reset_weights_block = self.current_block
                # Zero out the scores of nodes that do not serve an IP
                self.scores = torch.Tensor(self.scores)  # Convert NumPy array to PyTorch tensor
                self.scores = self.scores * torch.Tensor([self.metagraph.neurons[uid].axon_info.ip != '0.0.0.0' for uid in self.metagraph.uids])

        except Exception as e:
            bt.logging.error(f"An error occurred in main loop logic: {e}")

    def query_network(self, filtered_axons, prompt, duration=15):
        """Queries the network with filtered axons and prompt."""
        # Map the requested clip length to a token budget and a matching timeout
        if duration == 15:
            self.duration = 755
            self.time_out = 100
        elif duration == 30:
            self.duration = 1510
            self.time_out = 200

        responses = self.dendrite.query(
            filtered_axons,
            MusicGeneration(text_input=prompt, duration=self.duration),
            deserialize=True,
            timeout=self.time_out,
        )
        return responses
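    # Editor's note — how the constants above relate, as an illustration only:
    # with token count ~ duration_seconds * 50.2 (see handle_music_output), the
    # mapping works out to roughly
    #
    #   15 s  ->  duration=755  (~15 * 50.2 tokens),  time_out=100
    #   30 s  ->  duration=1510 (~30 * 50.2 tokens),  time_out=200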
    def update_block(self):
        self.current_block = self.subtensor.block
        if self.current_block - self.last_updated_block > 120:
            bt.logging.info(f"Updating weights. Last update was at block: {self.last_updated_block}")
            bt.logging.info(f"Current block for weight update is: {self.current_block}")
            self.update_weights(self.scores)
            self.last_updated_block = self.current_block
        else:
            bt.logging.info(f"Skipping weight update. Last update was at block: {self.last_updated_block}")
            bt.logging.info(f"Current block is: {self.current_block}")
            bt.logging.info(f"Next update will be at block: {self.last_updated_block + 120}")

    def process_responses(self, filtered_axons, responses, prompt):
        """Processes responses received from the network."""
        for axon, response in zip(filtered_axons, responses):
            if response is not None and isinstance(response, MusicGeneration):
                self.process_response(axon, response, prompt)

        bt.logging.info(f"Scores after update in TTM: {self.scores}")
        self.update_block()

    def process_response(self, axon, response, prompt, api=False):
        try:
            if response is not None and isinstance(response, MusicGeneration) and response.music_output is not None and response.dendrite.status_code == 200:
                bt.logging.success(f"Received music output from {axon.hotkey}")
                music_output = response.music_output
                if api:
                    file = self.handle_music_output(axon, music_output, prompt, response.model_name)
                    return file
                else:
                    self.handle_music_output(axon, music_output, prompt, response.model_name)
            elif response is not None and response.dendrite.status_code != 403:
                self.punish(axon, service="Text-To-Music", punish_message=response.dendrite.status_message)
            else:
                pass

        except Exception as e:
            bt.logging.error(f'An error occurred while handling music output: {e}')
    def handle_music_output(self, axon, music_output, prompt, model_name):
        # Handle the music output received from the miners
        try:
            # Convert the list to a tensor
            speech_tensor = torch.Tensor(music_output)
            bt.logging.info("Converted music output to tensor successfully.")
        except Exception as e:
            bt.logging.error(f"Error converting music output to tensor: {e}")
            return

        try:
            # Normalize the audio data
            audio_data = speech_tensor / torch.max(torch.abs(speech_tensor))
            bt.logging.info("Normalized the audio data.")
        except Exception as e:
            bt.logging.error(f"Error normalizing audio data: {e}")
            return

        try:
            # Convert to 32-bit PCM
            audio_data_int_ = (audio_data * 2147483647).type(torch.IntTensor)
            bt.logging.info("Converted audio data to 32-bit PCM.")

            # Add an extra dimension to make it a 2D tensor
            audio_data_int = audio_data_int_.unsqueeze(0)
            bt.logging.info("Added an extra dimension to audio data.")
        except Exception as e:
            bt.logging.error(f"Error converting audio data to 32-bit PCM: {e}")
            return

        try:
            # Reuse the reference file's name for the generated clip
            file_name = os.path.basename(self.audio_path)
            bt.logging.info(f"Saving audio file as: {file_name}")

            # Save the audio data as a .wav file
            os.makedirs('/tmp/music/', exist_ok=True)  # ensure the output directory exists
            output_path = os.path.join('/tmp/music/', file_name)
            sampling_rate = 32000
            torchaudio.save(output_path, src=audio_data_int, sample_rate=sampling_rate)
            bt.logging.info(f"Saved audio file to {output_path}")
        except Exception as e:
            bt.logging.error(f"Error saving audio file: {e}")
            return

        try:
            # Calculate the audio hash
            audio_hash = hashlib.sha256(audio_data.numpy().tobytes()).hexdigest()
            bt.logging.info("Calculated audio hash.")
        except Exception as e:
            bt.logging.error(f"Error calculating audio hash: {e}")
            return

        try:
            # Check if the music hash is a duplicate
            if check_duplicate_music(audio_hash):
                bt.logging.info(f"Duplicate music detected from miner: {axon.hotkey}. Issuing punishment.")
                self.punish(axon, service="Text-To-Music", punish_message="Duplicate music detected")
            else:
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                save_hash_to_file(audio_hash, axon.hotkey, timestamp)
                bt.logging.info(f"Music hash processed and saved successfully for miner: {axon.hotkey}")
        except Exception as e:
            bt.logging.error(f"Error checking or saving music hash: {e}")
            return

        try:
            # Log the audio to wandb
            uid_in_metagraph = self.metagraph.hotkeys.index(axon.hotkey)
            audio_data_np = np.array(audio_data_int_)
            wandb.log({
                f"TTM prompt: {prompt[:100]} ....": wandb.Audio(audio_data_np, caption=f'For HotKey: {axon.hotkey[:10]} and uid {uid_in_metagraph}', sample_rate=sampling_rate)
            })
            bt.logging.success(f"TTM Audio file uploaded to wandb successfully for Hotkey {axon.hotkey} and UID {uid_in_metagraph}")
        except Exception as e:
            bt.logging.error(f"Error uploading TTM audio file to wandb: {e}")

        try:
            # Get audio duration
            duration = self.get_duration(output_path)
            token = duration * 50.2  # approximate token count for the clip
            bt.logging.info(f"The duration of the audio file is {duration} seconds.")
        except Exception as e:
            bt.logging.error(f"Error calculating audio duration: {e}")
            return

        try:
            reference_dir = self.audio_path
            score, table1, table2 = self.score_output("/tmp/music/", reference_dir, prompt)
            if duration < 15:
                score = self.score_adjustment(score, duration)
                bt.logging.info(f"Score penalized for duration shorter than required: {score}")
            else:
                bt.logging.info("Duration is at least 15 seconds; no penalty applied to the score.")
        except Exception as e:
            bt.logging.error(f"Error scoring the output: {e}")
            return

        try:
            current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            tabulated_str = tabulate(table1, headers=[f"Raw score for hotkey: {axon.hotkey}", current_datetime], tablefmt="grid")
            print(tabulated_str)
            print("\n")
            tabulated_str2 = tabulate(table2, headers=[f"Normalized score for hotkey: {axon.hotkey}", current_datetime], tablefmt="grid")
            print(tabulated_str2)
            bt.logging.info(f"Aggregated Score for hotkey {axon.hotkey}: {score}")
            self.update_score(axon, score, service="Text-To-Music")
        except Exception as e:
            bt.logging.error(f"Error generating score tables or updating score: {e}")
            return

        return output_path


    def get_duration(self, wav_file_path):
        """Returns the duration of the audio file in seconds."""
        with contextlib.closing(wave.open(wav_file_path, 'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            return frames / float(rate)
    def score_adjustment(self, score, duration):
        """Adjusts the score based on the duration of the generated audio."""
        conditions = [
            (lambda d: 14.5 <= d < 15, 0.9),
            (lambda d: 14 <= d < 14.5, 0.8),
            (lambda d: 13.5 <= d < 14, 0.7),
            (lambda d: 13 <= d < 13.5, 0.6),
            (lambda d: 12.5 <= d < 13, 0.0),
        ]
        for condition, multiplier in conditions:
            if condition(duration):
                return score * multiplier
        # Durations below 12.5 s match no bracket and fall through unadjusted
        return score

    def score_output(self, output_path, reference_dir, prompt):
        """Evaluates and returns the score for the generated music output."""
        try:
            score_object = MusicQualityEvaluator()
            return score_object.evaluate_music_quality(output_path, reference_dir, prompt)
        except Exception as e:
            bt.logging.error(f"Error scoring output: {e}")
            return 0.0, [], []  # callers unpack (score, table1, table2)


    def get_filtered_axons_from_combinations(self):
        if not self.combinations:
            self.get_filtered_axons()

        current_combination = self.combinations.pop(0)
        bt.logging.info(f"Current Combination for TTM: {current_combination}")
        filtered_axons = [self.metagraph.axons[i] for i in current_combination]

        return filtered_axons


    def get_filtered_axons(self):
        # Get the uids of all miners in the network.
        uids = self.metagraph.uids.tolist()
        queryable_uids = (self.metagraph.total_stake >= 0)
        # Remove the weights of miners that are not queryable.
        queryable_uids = torch.Tensor(queryable_uids) * torch.Tensor([self.metagraph.neurons[uid].axon_info.ip != '0.0.0.0' for uid in uids])

        active_miners = torch.sum(queryable_uids)
        dendrites_per_query = self.total_dendrites_per_query

        # If there are no active miners, treat it as one to avoid division by zero
        if active_miners == 0:
            active_miners = 1
        # If there are fewer than dendrites_per_query * 3 active miners, query a third of them per batch
        if active_miners < self.total_dendrites_per_query * 3:
            dendrites_per_query = int(active_miners / 3)
        else:
            dendrites_per_query = self.total_dendrites_per_query

        # Never query fewer than the configured minimum
        if dendrites_per_query < self.minimum_dendrites_per_query:
            dendrites_per_query = self.minimum_dendrites_per_query
        # Zip uids and queryable_uids, keep only the queryable uids, and unzip them
        zipped_uids = list(zip(uids, queryable_uids))
        filtered_zipped_uids = list(filter(lambda x: x[1], zipped_uids))
        filtered_uids = [item[0] for item in filtered_zipped_uids] if filtered_zipped_uids else []
        subset_length = min(dendrites_per_query, len(filtered_uids))
        # Shuffle the order of members
        random.shuffle(filtered_uids)
        # Split the shuffled uids into batches of subset_length until all are covered
        while filtered_uids:
            subset = filtered_uids[:subset_length]
            self.combinations.append(subset)
            filtered_uids = filtered_uids[subset_length:]
        return self.combinations
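    # Editor's note — a worked example of the batching above, with hypothetical
    # numbers: with 25 queryable miners and total_dendrites_per_query = 10,
    # active_miners (25) is below 10 * 3, so dendrites_per_query = int(25 / 3) = 8;
    # the shuffled uids are split into batches of sizes [8, 8, 8, 1], and each call
    # to get_filtered_axons_from_combinations() pops one batch to query.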
    def update_weights(self, scores):
        """
        Sets the validator weights to the metagraph hotkeys based on the scores it has received from the miners.
        The weights determine the trust and incentive level the validator assigns to miner nodes on the network.
        """

        # Convert scores to a PyTorch tensor and check for NaN values
        weights = torch.tensor(scores)
        if torch.isnan(weights).any():
            bt.logging.warning(
                "Scores contain NaN values. This may be due to a lack of responses from miners, or a bug in your reward functions."
            )

        # Normalize scores to get raw weights
        raw_weights = torch.nn.functional.normalize(weights, p=1, dim=0)
        bt.logging.info("raw_weights", raw_weights)

        # Convert uids to a PyTorch tensor
        uids = torch.tensor(self.metagraph.uids)

        bt.logging.info("raw_weight_uids", uids)
        try:
            # Convert tensors to NumPy arrays for processing if required by the process_weights_for_netuid function
            uids_np = uids.numpy() if isinstance(uids, torch.Tensor) else uids
            raw_weights_np = raw_weights.numpy() if isinstance(raw_weights, torch.Tensor) else raw_weights

            # Process the raw weights and uids based on subnet limitations
            (processed_weight_uids, processed_weights) = bt.utils.weight_utils.process_weights_for_netuid(
                uids=uids_np,  # Ensure this is a NumPy array
                weights=raw_weights_np,  # Ensure this is a NumPy array
                netuid=self.config.netuid,
                subtensor=self.subtensor,
                metagraph=self.metagraph,
            )
            bt.logging.info("processed_weights", processed_weights)
            bt.logging.info("processed_weight_uids", processed_weight_uids)
        except Exception as e:
            bt.logging.error(f"An error occurred while processing weights within update_weights: {e}")
            return

        # Convert processed weights and uids back to PyTorch tensors if needed for further processing
        processed_weight_uids = torch.tensor(processed_weight_uids) if isinstance(processed_weight_uids, np.ndarray) else processed_weight_uids
        processed_weights = torch.tensor(processed_weights) if isinstance(processed_weights, np.ndarray) else processed_weights

        # Convert weights and uids to uint16 format for emission
        uint_uids, uint_weights = bt.utils.weight_utils.convert_weights_and_uids_for_emit(
            uids=processed_weight_uids, weights=processed_weights
        )
        bt.logging.info("uint_weights", uint_weights)
        bt.logging.info("uint_uids", uint_uids)

        # Set the weights on the Bittensor network
        try:
            result, msg = self.subtensor.set_weights(
                wallet=self.wallet,
                netuid=self.config.netuid,
                uids=uint_uids,
                weights=uint_weights,
                wait_for_finalization=False,
                wait_for_inclusion=False,
                version_key=self.version,
            )

            if result:
                bt.logging.info(f"Weights set on the chain successfully! {result}")
            else:
                bt.logging.error(f"Failed to set weights: {msg}")
        except Exception as e:
            bt.logging.error(f"An error occurred while setting weights: {e}")
--------------------------------------------------------------------------------
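A small self-contained illustration of the L1 normalization used in update_weights above; the scores are hypothetical, not values from the subnet:

import torch

scores = torch.tensor([0.0, 2.0, 6.0, 2.0])
raw_weights = torch.nn.functional.normalize(scores, p=1, dim=0)
print(raw_weights)  # tensor([0.0000, 0.2000, 0.6000, 0.2000]) -- the weights sum to 1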
/ttm/ttm_score.py:
--------------------------------------------------------------------------------
import os
import torch
import torchaudio
import numpy as np
import bittensor as bt
from huggingface_hub import hf_hub_download
from audiocraft.metrics import PasstKLDivergenceMetric
from audiocraft.metrics import CLAPTextConsistencyMetric
from audioldm_eval.metrics.fad import FrechetAudioDistance


class MetricEvaluator:
    @staticmethod
    def calculate_kld(generated_audio_dir, target_audio_dir):
        try:
            # Get the single audio file path in each directory
            generate = next((f for f in os.listdir(generated_audio_dir) if os.path.isfile(os.path.join(generated_audio_dir, f))), None)
            target = next((f for f in os.listdir(target_audio_dir) if os.path.isfile(os.path.join(target_audio_dir, f))), None)

            if generate is None or target is None:
                bt.logging.error("Generated or target audio file not found.")
                return None

            # Load the predicted and target audio files
            target_waveform, target_sr = torchaudio.load(os.path.join(target_audio_dir, target))
            generated_waveform, generated_sr = torchaudio.load(os.path.join(generated_audio_dir, generate))

            # Ensure sample rates match
            if target_sr != generated_sr:
                resampler = torchaudio.transforms.Resample(orig_freq=generated_sr, new_freq=target_sr)
                generated_waveform = resampler(generated_waveform)
                generated_sr = target_sr

            # Truncate or pad waveforms to match lengths
            min_length = min(target_waveform.shape[-1], generated_waveform.shape[-1])
            target_waveform = target_waveform[..., :min_length]
            generated_waveform = generated_waveform[..., :min_length]

            # Ensure that the audio tensors are in the shape [batch_size, channels, length]
            target_waveform = target_waveform.unsqueeze(0)  # Add batch dimension
            generated_waveform = generated_waveform.unsqueeze(0)  # Add batch dimension

            # The sizes of the waveforms
            sizes = torch.tensor([target_waveform.shape[-1]])

            # The sample rates (a single value, since both waveforms now match)
            sample_rates = torch.tensor([target_sr])

            # Initialize the PasstKLDivergenceMetric
            kld_metric = PasstKLDivergenceMetric()

            # Move tensors to the appropriate device if needed
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            target_waveform = target_waveform.to(device)
            generated_waveform = generated_waveform.to(device)
            sizes = sizes.to(device)
            sample_rates = sample_rates.to(device)
            kld_metric = kld_metric.to(device)

            # Update the metric
            kld_metric.update(preds=generated_waveform, targets=target_waveform, sizes=sizes, sample_rates=sample_rates)

            # Compute the PasstKLDivergenceMetric score
            kld = kld_metric.compute()
            return kld['kld_both']

        except Exception as e:
            import traceback
            traceback_str = traceback.format_exc()
            bt.logging.error(f"Error during KLD calculation: {e}\n{traceback_str}")
            return None


    @staticmethod
    def calculate_fad(generated_audio_dir, target_audio_dir):
        # Initialize the Frechet Audio Distance calculator
        fad_calculator = FrechetAudioDistance()

        # Calculate the FAD score between the two directories
        fad_score = fad_calculator.score(
            background_dir=generated_audio_dir,  # Generated audio directory
            eval_dir=target_audio_dir,  # Target audio directory
            store_embds=False,  # Set to True to store embeddings for later reuse
            limit_num=1,  # Limit the number of files to process; None means no limit
            recalculate=True  # Set to True to recalculate embeddings
        )

        # Extract the FAD score from the dictionary
        fad_value = fad_score['frechet_audio_distance']

        # Clamp the value to 0 if it's negative
        fad = max(0, fad_value)
        return fad

    @staticmethod
    def calculate_consistency(generated_audio_dir, text):
        try:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            pt_file = hf_hub_download(repo_id="lukewys/laion_clap", filename="music_audioset_epoch_15_esc_90.14.pt")
            clap_metric = CLAPTextConsistencyMetric(pt_file, model_arch='HTSAT-base').to(device)

            def convert_audio(audio, from_rate, to_rate, to_channels):
                resampler = torchaudio.transforms.Resample(orig_freq=from_rate, new_freq=to_rate)
                audio = resampler(audio)
                if to_channels == 1:
                    audio = audio.mean(dim=0, keepdim=True)
                return audio

            # Get the single audio file path in the directory
            file_name = next((f for f in os.listdir(generated_audio_dir) if os.path.isfile(os.path.join(generated_audio_dir, f))), None)
            if file_name is None:
                raise FileNotFoundError("No audio file found in the directory.")

            file_path = os.path.join(generated_audio_dir, file_name)

            # Load and process the audio (downmix to mono; the sample rate is left unchanged)
            audio, sr = torchaudio.load(file_path)
            audio = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=1)

            # Calculate consistency score
            clap_metric.update(audio.unsqueeze(0), [text], torch.tensor([audio.shape[1]]), torch.tensor([sr]))
            consistency_score = clap_metric.compute()

            return consistency_score
        except Exception as e:
            bt.logging.error(f"Error during consistency calculation: {e}")
            return None
class Normalizer:
    @staticmethod
    def normalize_kld(kld_score):
        if kld_score is not None:
            if 0 <= kld_score <= 1:
                normalized_kld = (1 - kld_score)  # Lower KLD is better, so invert it
            elif 1 < kld_score <= 2:
                normalized_kld = 0.5 * (2 - kld_score)  # Scale down between 0.5 and 0
            else:
                normalized_kld = 0  # Anything > 2 is considered bad
        else:
            normalized_kld = 0
        return normalized_kld

    @staticmethod
    def normalize_fad(fad_score):
        if fad_score is not None:
            if 0 <= fad_score <= 5:
                normalized_fad = (5 - fad_score) / 5  # Normalize between 0 and 1 (lower FAD is better)
            elif 5 < fad_score <= 10:
                normalized_fad = 0.5 * (10 - fad_score) / 5  # Scale down between 0.5 and 0
            else:
                normalized_fad = 0  # Anything > 10 is considered bad
        else:
            normalized_fad = 0
        return normalized_fad


class Aggregator:
    @staticmethod
    def geometric_mean(scores):
        """Calculate the geometric mean of the scores, avoiding any non-positive values."""
        scores = [max(score, 0.0001) for score in scores.values()]  # Replace non-positive values to avoid math errors
        product = np.prod(scores)
        return product ** (1.0 / len(scores))
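# Editor's note — how the normalizers and aggregator above combine, with
# hypothetical raw scores:
#
#   kld = 0.4  ->  normalize_kld  ->  1 - 0.4       = 0.60
#   fad = 2.0  ->  normalize_fad  ->  (5 - 2.0) / 5 = 0.60
#   geometric_mean({'KLD': 0.60, 'FAD': 0.60}) = sqrt(0.60 * 0.60) = 0.60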
class MusicQualityEvaluator:
    def __init__(self):
        self.metric_evaluator = MetricEvaluator()
        self.normalizer = Normalizer()
        self.aggregator = Aggregator()

    def get_directory(self, path):
        return os.path.dirname(path)

    def evaluate_music_quality(self, generated_audio, target_audio, text=None):

        generated_audio_dir = self.get_directory(generated_audio)
        target_audio_dir = self.get_directory(target_audio)

        bt.logging.info(f"Generated audio directory: {generated_audio_dir}")
        bt.logging.info(f"Target audio directory: {target_audio_dir}")

        # Default each raw score to None so a failed metric cannot leave it undefined
        kld_score = None
        fad_score = None
        consistency_score = None

        try:
            kld_score = self.metric_evaluator.calculate_kld(generated_audio_dir, target_audio_dir)
        except Exception:
            bt.logging.error("Failed to calculate KLD")

        try:
            fad_score = self.metric_evaluator.calculate_fad(generated_audio_dir, target_audio_dir)
        except Exception:
            bt.logging.error("Failed to calculate FAD")

        try:
            consistency_score = self.metric_evaluator.calculate_consistency(generated_audio_dir, text)
        except Exception:
            bt.logging.error("Failed to calculate Consistency score")

        # Normalize scores and calculate aggregate score
        normalized_kld = self.normalizer.normalize_kld(kld_score)
        normalized_fad = self.normalizer.normalize_fad(fad_score)

        aggregate_quality = self.aggregator.geometric_mean({'KLD': normalized_kld, 'FAD': normalized_fad})
        aggregate_score = self.aggregator.geometric_mean({'quality': aggregate_quality, 'normalized_consistency': consistency_score}) if consistency_score is not None and consistency_score > 0.1 else 0

        # Table of raw scores
        table1 = [
            ["Metric", "Raw Score"],
            ["KLD Score", kld_score],
            ["FAD Score", fad_score],
            ["Consistency Score", consistency_score]
        ]

        # Table of normalized scores
        table2 = [
            ["Metric", "Normalized Score"],
            ["Normalized KLD", normalized_kld],
            ["Normalized FAD", normalized_fad],
            ["Consistency Score", consistency_score]
        ]

        return aggregate_score, table1, table2
--------------------------------------------------------------------------------
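A minimal usage sketch for the evaluator above. The file paths and prompt are hypothetical; note that evaluate_music_quality takes file paths but each metric scans the file's parent directory:

from ttm.ttm_score import MusicQualityEvaluator

evaluator = MusicQualityEvaluator()
score, raw_table, normalized_table = evaluator.evaluate_music_quality(
    "/tmp/music/random_sample.wav",       # generated clip
    "audio_files/random_sample.wav",      # reference clip saved by the validator
    text="calm piano with soft strings",  # prompt for the CLAP consistency check
)
print(score)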