├── .DS_Store
├── LICENSE
├── README.md
├── docs
├── bittaudio.jpg
├── miner.md
└── validator.md
├── lib
├── __init__.py
└── hashing.py
├── min_compute.yml
├── neurons
├── __init__.py
├── miner.py
└── validator.py
├── requirements.txt
├── scripts
├── check_compatibility.sh
├── check_requirements_changes.sh
└── start_valid.py
├── setup.py
└── ttm
├── aimodel.py
├── protocol.py
├── ttm.py
└── ttm_score.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UncleTensor/BittAudio/d964362e50733ce185bcbacdd1c172aa26121f1b/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Muhammad Farhan Aslam
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BittAudio (SN 50) | Audio Generation Subnet on Bittensor
2 | 
3 | The main goal of BittAudio is to establish a decentralized platform that incentivizes the creation, distribution, and monetization of AI audio content, such as:
4 | - Text-to-Music (TTM)
5 |
6 | Validators and miners work together to ensure high-quality outputs, fostering innovation and rewarding contributions in the audio domain.
7 | By introducing an audio generation service such as Text-to-Music, this subnetwork expands the range of services available within the Bittensor ecosystem. This diversification enhances the utility and appeal of the Bittensor platform to a broader audience, including creators, influencers, developers, and end-users interested in audio content.
8 |
9 | ## Validators & Miners Interaction
10 | - Validators initiate requests filled with the required data and encrypt them with a symmetric key.
11 | - Requests are signed with the validator’s private key to certify authenticity.
12 | - Miners decrypt the requests, verify the signatures to ensure authenticity, process the requests, and then send back the results, encrypted and signed for security.
13 |
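14 | Conceptually, a single exchange looks like the sketch below. This is a minimal illustration rather than the subnet's actual transport layer: the `Fernet` key and the `//Alice` keypair are stand-ins, and in practice Bittensor's axon/dendrite stack performs the equivalent signing and verification internally.
15 |
16 | ```python
17 | from cryptography.fernet import Fernet
18 | import bittensor as bt
19 |
20 | validator_hotkey = bt.Keypair.create_from_uri("//Alice")  # stand-in keypair
21 | fernet = Fernet(Fernet.generate_key())                    # shared symmetric key
22 |
23 | request = b'{"text_input": "lofi beat", "duration": 750}'
24 | ciphertext = fernet.encrypt(request)           # validator encrypts the request
25 | signature = validator_hotkey.sign(ciphertext)  # and signs it for authenticity
26 |
27 | # Miner side: verify the signature before decrypting and processing.
28 | assert validator_hotkey.verify(ciphertext, signature)
29 | plaintext = fernet.decrypt(ciphertext)
30 | ```
31 |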
14 | **Validators** are responsible for initiating the generation process by providing prompts to the Miners on the network. These prompts serve as the input for TTM service. The Validators then evaluate the quality of the generated audio and reward the Miners based on the output quality.
15 | Please refer to the [Validator Documentation](docs/validator.md)
16 |
17 | **Miners** in the Audio Subnetwork are tasked with generating audio from the text prompts received from the Validators. Leveraging advanced TTM models, miners aim to produce high-fidelity music. The quality of the generated audio is crucial, as it directly influences the miners' rewards.
18 | Please refer to the [Miner Documentation](docs/miner.md)
19 |
20 | ## Workflow
21 |
22 | 1. **Prompt Generation:** The Validators generate TTM prompts and distribute them to the Miners on the network.
23 |
24 | 2. **Audio Processing:** Miners receive the prompts and utilize TTM models to convert the text into audio (music).
25 |
26 | 3. **Quality Evaluation:** The Validator assesses the quality of the generated audio, considering factors such as clarity, naturalness, and adherence to the prompt.
27 |
28 | 4. **Reward Distribution:** Based on the quality assessment, the Validator rewards Miners accordingly. Miners with consistently higher-quality outputs receive a larger share of rewards.
29 |
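30 | For orientation, below is a minimal sketch of the validator-side query built on the `MusicGeneration` synapse from `ttm/protocol.py`. The field names mirror what `neurons/miner.py` reads (`text_input`, `duration`); the wallet names, prompt, duration value, and timeout are illustrative assumptions.
31 |
32 | ```python
33 | import bittensor as bt
34 | from ttm.protocol import MusicGeneration
35 |
36 | wallet = bt.wallet(name="validator", hotkey="default")  # hypothetical wallet
37 | dendrite = bt.dendrite(wallet=wallet)
38 | metagraph = bt.subtensor().metagraph(50)
39 |
40 | # The miner passes `duration` to max_new_tokens (~50 tokens/s of audio).
41 | synapse = MusicGeneration(text_input="calm piano over soft rain", duration=750)
42 | responses = dendrite.query(metagraph.axons, synapse, timeout=120)
43 | ```
44 |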
30 | ## Benefits
31 |
32 | - **Decentralized Text-to-Audio:** The subnetwork decentralizes the Text-to-Music process, distributing the workload among participating Miners.
33 |
34 | - **Quality Incentives:** The incentive mechanism encourages Miners to continually improve the quality of their generated audio.
35 |
36 | - **Bittensor Network Integration:** Leveraging the Bittensor network ensures secure and transparent interactions between Validators and Miners.
37 |
38 | Join BittAudio and contribute to the advancement of decentralized Text-to-Music technology within the Bittensor ecosystem.
39 |
40 | ## SOTA Benchmarking for Audio Evaluation
41 |
42 | To ensure the quality of audio generated by miners, we use a SOTA (State-of-the-Art) benchmarking process. Validators download sound files from the following link:
43 | https://huggingface.co/datasets/etechgrid/ttm-validation-dataset
44 |
45 | The output generated by miners is evaluated using three different metrics:
46 |
47 | 1. **Kullback-Leibler Divergence (KLD):** Measures the divergence between two probability distributions, allowing us to assess the distribution similarity between generated music and the original distribution of audio data.
48 |
49 | 2. **Frechet Audio Distance (FAD):** Calculates the difference between the statistical distribution of generated audio and real-world audio. This metric provides a robust evaluation of the overall quality, including the structure and timbre of generated music.
50 |
51 | 3. **CLAP Metric:** Evaluates text consistency by determining how well the generated audio adheres to the original text prompt. It measures the semantic alignment between the input text and the generated output.
52 |
53 | Together, these metrics provide a comprehensive evaluation of audio quality. This system is more robust than the previous one, which combined CLAP, SNR (Signal-to-Noise Ratio), and HNR (Harmonic-to-Noise Ratio).
54 |
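55 | As a concrete illustration of the KLD idea, the sketch below compares the class-probability profiles of a generated clip and a reference clip. It is a toy example: the 527-way distributions stand in for the output of an audio tagger (e.g. PaSST over the AudioSet classes) and are filled with random values here.
56 |
57 | ```python
58 | import torch
59 | import torch.nn.functional as F
60 |
61 | # Stand-in class distributions for a generated and a reference clip.
62 | p_generated = torch.softmax(torch.randn(527), dim=-1)
63 | p_reference = torch.softmax(torch.randn(527), dim=-1)
64 |
65 | # KL(reference || generated): lower means the generated clip's class
66 | # profile sits closer to the reference distribution.
67 | kld = F.kl_div(p_generated.log(), p_reference, reduction="sum")
68 | print(f"KLD: {kld.item():.4f}")
69 | ```
70 |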
55 | ### Benchmark Milestone:
56 | | Model | Frechet Audio Distance | KLD | Text Consistency | Chroma Cosine Similarity |
57 | |--------------------------|------------------------|------|------------------|--------------------------|
58 | | facebook/musicgen-small | 4.88 | 1.42 | 0.27 | - |
59 | | facebook/musicgen-medium | 5.14 | 1.38 | 0.28 | - |
60 | | facebook/musicgen-large | 5.48 | 1.37 | 0.28 | - |
61 | | facebook/musicgen-melody | 4.93 | 1.41 | 0.27 | 0.44 |
62 |
63 | For more information on these models, you can visit:
64 | https://huggingface.co/facebook/musicgen-small
65 |
66 | ## Installation
67 | ```bash
68 | git clone https://github.com/UncleTensor/BittAudio.git
69 | cd BittAudio
70 | pip install -e .
71 | pip install -r requirements.txt
72 | wandb login
73 | ```
74 |
75 | ## Recommended GPU Configuration
76 |
77 | It is recommended to use NVIDIA RTX A6000 GPUs at minimum for both Validators and Miners.
78 |
79 |
80 | **Evaluation Mechanism:**
81 | The evaluation mechanism involves the Validators querying miners on the network with random prompts and receiving TTM responses. These responses are scored for quality, and the weights on the Bittensor network are updated accordingly. The scoring is conducted using the quality evaluator in the ttm module.
82 |
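83 | The validator exposes an `--alpha` flag described as "the weight moving average scoring", which suggests raw scores are smoothed with an exponential moving average before weights are set on chain. A hedged sketch of that bookkeeping (the array size, UID, and score values are placeholders):
84 |
85 | ```python
86 | import torch
87 |
88 | alpha = 0.9                # smoothing factor (--alpha; docs list 0.9)
89 | scores = torch.zeros(256)  # one running score per UID (size illustrative)
90 |
91 | def update_score(uid: int, raw_score: float) -> None:
92 |     # Old history dominates; each new result nudges the running score.
93 |     scores[uid] = alpha * scores[uid] + (1 - alpha) * raw_score
94 |
95 | update_score(12, 0.80)
96 | weights = torch.nn.functional.normalize(scores, p=1, dim=0)  # sums to 1
97 | ```
98 |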
83 | **Miner/Validator Hardware Specs:**
84 | The hardware requirements for miners and validators vary depending on the complexity and resource demands of the selected TTM models. Typically, a machine equipped with a capable CPU and GPU, along with sufficient VRAM and RAM, is necessary. The amount of disk space required will depend on the size of the models and any additional data.
85 |
86 | **How to Run a Validator:**
87 | To operate a validator, you need to run the validator.py script with the required command-line arguments. This script initiates the setup of Bittensor objects, establishes a connection to the network, queries miners, scores their responses, and updates weights accordingly.
88 |
89 | **How to Run a Miner:**
90 | To operate a miner, run the miner.py script with the necessary configuration. This process involves initializing Bittensor objects, establishing a connection to the network, and processing incoming TTM requests.
91 |
92 | **TTM Models Supported:**
93 | The code incorporates various Text-to-Music models. The specific requirements for each model, including CPU, GPU VRAM, RAM, and disk space, are not explicitly stated in the provided code. For these types of requirements, it may be necessary to consult the documentation or delve into the implementation details of these models.
94 |
95 | In general, the resource demands of TTM models can vary significantly. Larger models often necessitate more powerful GPUs and additional system resources. It is advisable to consult the documentation or model repository for the specific requirements of each model. Additionally, if GPU acceleration is employed, having a compatible GPU with enough VRAM is typically advantageous for faster processing.
96 |
97 | ## License
98 | This repository is licensed under the MIT License.
99 |
100 | ```text
101 | MIT License
102 |
103 | Copyright (c) 2024 Opentensor
104 |
105 | Permission is hereby granted, free of charge, to any person obtaining a copy
106 | of this software and associated documentation files (the "Software"), to deal
107 | in the Software without restriction, including without limitation the rights
108 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
109 | copies of the Software, and to permit persons to whom the Software is
110 | furnished to do so, subject to the following conditions:
111 |
112 | The above copyright notice and this permission notice shall be included in all
113 | copies or substantial portions of the Software.
114 |
115 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
116 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
117 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
118 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
119 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
120 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
121 | SOFTWARE.
122 |
123 | ```
124 |
--------------------------------------------------------------------------------
/docs/bittaudio.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UncleTensor/BittAudio/d964362e50733ce185bcbacdd1c172aa26121f1b/docs/bittaudio.jpg
--------------------------------------------------------------------------------
/docs/miner.md:
--------------------------------------------------------------------------------
1 | # Audio Generation Subnetwork Miner Guide
2 | Welcome to the Miner's guide for the Audio Generation Subnetwork within the Bittensor network. This document provides instructions for setting up and running a Miner node in the network.
3 |
4 | ## Overview
5 | Miners in the Audio Subnetwork are responsible for generating audio from text prompts received from Validators. Utilizing advanced text-to-music models, miners aim to produce high-fidelity, natural-sounding music. The quality of the generated audio directly influences the rewards miners receive.
6 |
7 | ## Installation
8 | Follow these steps to install the necessary components:
9 |
10 | **Set Conda Environment**
11 | ```bash
12 | mkdir -p ~/miniconda3
13 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
14 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
15 | rm -rf ~/miniconda3/miniconda.sh
16 | ~/miniconda3/bin/conda init bash
17 | ~/miniconda3/bin/conda init zsh
18 | conda create -n {conda-env} python=3.10 -y
19 | conda activate {conda-env}
20 | ```
21 | **Install Repo**
22 | ```bash
23 | git clone https://github.com/UncleTensor/BittAudio.git
24 | cd BittAudio
25 | pip install -e .
26 | pip install -r requirements.txt
27 | ```
28 | **Install pm2**
29 | ```bash
30 | sudo apt install nodejs npm
31 | sudo npm install pm2 -g
32 | ```
33 |
34 | ### Recommended GPU Configuration
35 | - NVIDIA RTX A6000 GPUs are recommended for optimal performance.
36 |
37 | ### Running a Miner
38 | - To operate a miner, run the miner.py script with the necessary configuration.
39 |
40 | ### Miner Commands
41 | ```bash
42 | pm2 start neurons/miner.py -- \
43 | --netuid 50 \
44 | --wallet.name {wallet_name} \
45 | --wallet.hotkey {hotkey_name} \
46 | --logging.trace \
47 | --music_path {ttm-model} \
48 | --axon.port {machine_port}
49 | ```
50 |
51 | ### Bittensor Miner Script Arguments:
52 |
53 | | **Category** | **Argument** | **Default Value** | **Description** |
54 | |---------------------------------|--------------------------------------|----------------------------|-----------------------------------------------------------------------------------------------------------------------|
55 | | **Text To Music Model**         | `--music_model`                      | 'facebook/musicgen-medium' | The Hugging Face model to use for Text-To-Music: 'facebook/musicgen-medium' or 'facebook/musicgen-large'. |
56 | | **Music Finetuned Model**       | `--music_path`                       | None                       | Path to a custom finetuned model, used instead of `--music_model` when provided. |
57 | | **Network UID** | `--netuid` | Mainnet: 50 | The chain subnet UID. |
58 | | **Bittensor Subtensor Arguments** | `--subtensor.chain_endpoint` | - | Endpoint for Bittensor chain connection.|
59 | | | `--subtensor.network` | - | Bittensor network endpoint.|
60 | | **Bittensor Logging Arguments** | `--logging.debug` | - | Enable debugging logs.|
61 | | **Bittensor Wallet Arguments** | `--wallet.name` | - | Name of the wallet.|
62 | | | `--wallet.hotkey` | - | Hotkey path for the wallet.|
63 | | **Bittensor Axon Arguments** | `--axon.port` | - | Port number for the axon server.|
64 | | **PM2 process name** | `--pm2_name` | 'SN50Miner' | Name for the pm2 process for Auto Update. |
65 |
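66 | `--music_path` should point at a checkpoint directory in the Hugging Face format. The sketch below mirrors how the `MusicGenerator` class in `neurons/miner.py` loads and runs such a model; the path and prompt are placeholders.
67 |
68 | ```python
69 | import torch
70 | from transformers import AutoProcessor, MusicgenForConditionalGeneration
71 |
72 | model_path = "/path/to/finetuned-musicgen"  # placeholder; value of --music_path
73 | device = "cuda" if torch.cuda.is_available() else "cpu"
74 |
75 | processor = AutoProcessor.from_pretrained(model_path)
76 | model = MusicgenForConditionalGeneration.from_pretrained(model_path).to(device)
77 |
78 | inputs = processor(text=["ambient synth pad"], padding=True, return_tensors="pt").to(device)
79 | audio = model.generate(**inputs, max_new_tokens=256)  # ~5 s at ~50 tokens/s
80 | ```
81 |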
66 |
67 |
68 |
69 |
70 | ### License
71 | Refer to the main README for the MIT License details.
72 |
--------------------------------------------------------------------------------
/docs/validator.md:
--------------------------------------------------------------------------------
1 | # Audio Generation Subnetwork Validator Guide
2 |
3 | Welcome to the Validator's guide for the Audio Generation Subnetwork within the Bittensor network. This document provides instructions for setting up and running a Validator node in the network.
4 |
5 | ## Overview
6 | Validators initiate the audio generation process by providing prompts to the Miners and evaluate the quality of the generated audio. They play a crucial role in maintaining the quality standards of the network. Prompts are generated with the help of the Corcel API, a product of Subnet 18, which provides a virtually unlimited range of prompts for Text-To-Music.
7 |
8 | ## Installation
9 | Follow these steps to install the necessary components:
10 |
11 | **Set Conda Environment**
12 | ```bash
13 | mkdir -p ~/miniconda3
14 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
15 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
16 | rm -rf ~/miniconda3/miniconda.sh
17 | ~/miniconda3/bin/conda init bash
18 | bash
19 | ~/miniconda3/bin/conda init zsh
20 | conda create -n {conda-env} python=3.10 -y
21 | conda activate {conda-env}
22 | ```
23 | **Install Repo**
24 | ```bash
25 | sudo apt update
26 | sudo apt install build-essential -y
27 | git clone https://github.com/UncleTensor/BittAudio.git
28 | cd BittAudio
29 | pip install -e .
30 | pip install audiocraft
31 | pip install laion_clap==1.1.4
32 | pip install git+https://github.com/haoheliu/audioldm_eval
33 | pip install git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt
34 | sudo mkdir -p /tmp/music
35 | wandb login
36 | ```
37 | **Install pm2**
38 | ```bash
39 | sudo apt install nodejs npm
40 | sudo npm install pm2 -g
41 | ```
42 |
43 | ## Validator Command for Auto Update
44 | ```bash
45 | pm2 start scripts/start_valid.py -- \
46 | --pm2_name {name} \
47 | --netuid 50 \
48 | --wallet.name {wallet_name} \
49 | --wallet.hotkey {hotkey_name} \
50 |   --subtensor.network finney
51 | ```
52 |
53 | ### Bittensor Validator Script Arguments:
54 |
55 | | **Category** | **Argument** | **Default Value** | **Description** |
56 | |---------------------------------|--------------------------------------|----------------------------|-----------------------------------------------------------------------------------------------------------------------|
57 | | **Configuration Arguments** | `--alpha` | 0.9 | The weight moving average scoring. |
58 | | | `--netuid` | Mainnet: 50 | The chain subnet UID. |
59 | | **Bittensor Subtensor Arguments** | `--subtensor.chain_endpoint` | - | Endpoint for Bittensor chain connection. |
60 | | | `--subtensor.network` | - | Bittensor network endpoint. |
61 | | **Bittensor Logging Arguments** | `--logging.debug` | - | Enable debugging logs. |
62 | | **Bittensor Wallet Arguments** | `--wallet.name` | - | Name of the wallet. |
63 | | | `--wallet.hotkey` | - | Hotkey path for the wallet. |
64 | | **PM2 process name**            | `--pm2_name`                         | 'validator'                | Name for the pm2 process for Auto Update. |
65 |
66 | ## Miner Logs in the Validator
67 |
68 | To inspect miner scoring logs on the validator, search the pm2 logs under `~/.pm2/logs` as follows:
69 |
70 | ```bash
71 | sudo grep -a -A 10 "Raw score for hotkey:5DXTGaAQm99AEAvhMRqWQ77b1aob4mAXwX" ~/.pm2/logs/validator-out.log
72 | sudo grep -a -A 10 "Normalized score for hotkey:5DXTGaAQm99AEAvhMRqWQ77b1aob4mAXwX" ~/.pm2/logs/validator-out.log
73 | ```
74 |
75 | ### License
76 | Refer to the main README for the MIT License details.
77 |
78 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.1.3"
2 | version_split = __version__.split(".")
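3 | # Encodes the semantic version as a single integer for on-chain comparison,
4 | # e.g. "1.1.3" -> 1000*1 + 10*1 + 1*3 = 1013.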
3 | __spec_version__ = (
4 | (1000 * int(version_split[0]))
5 | + (10 * int(version_split[1]))
6 | + (1 * int(version_split[2]))
7 | )
8 |
9 | MIN_STAKE = 10000
--------------------------------------------------------------------------------
/lib/hashing.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import json
4 | from datetime import datetime
5 | import bittensor as bt
6 |
7 | hashes_file = 'music_hashes.json'
8 |
9 | # In-memory cache for fast lookup
10 | cache = set()
11 |
12 | def calculate_audio_hash(audio_data: bytes) -> str:
13 | """Calculate a SHA256 hash of the given audio data."""
14 | return hashlib.sha256(audio_data).hexdigest()
15 |
16 | def load_hashes_to_cache():
17 | """Load existing hashes from JSON file into in-memory cache."""
18 | if os.path.exists(hashes_file):
19 | with open(hashes_file, 'r') as file:
20 | data = json.load(file)
21 | for entry in data:
22 | cache.add(entry['hash']) # Add hash to in-memory cache
23 |
24 | def save_hash_to_file(hash_value: str, timestamp: str, miner_id: str = None):
25 | """Save the new hash to the JSON file and in-memory cache."""
26 | cache.add(hash_value) # Add to cache for fast lookups
27 | if os.path.exists(hashes_file):
28 | with open(hashes_file, 'r+') as file:
29 | data = json.load(file)
30 | data.append({'hash': hash_value, 'miner_id': miner_id, 'timestamp': timestamp})
31 | file.seek(0)
32 | json.dump(data, file)
33 | else:
34 | # If the file doesn't exist, create it with the initial hash entry
35 | with open(hashes_file, 'w') as file:
36 | json.dump([{'hash': hash_value, 'miner_id': miner_id, 'timestamp': timestamp}], file)
37 |
38 |
39 | def check_duplicate_music(hash_value: str) -> bool:
40 | """Check if the given hash already exists in the in-memory cache."""
41 | return hash_value in cache
42 |
43 | def process_miner_music(miner_id: str, audio_data: bytes):
44 | """Process music sent by a miner and check for duplicates."""
45 | audio_hash = calculate_audio_hash(audio_data) # Calculate the audio hash
46 |
47 | if check_duplicate_music(audio_hash): # Check if it's a duplicate
48 | bt.logging.info(f"Duplicate music detected from miner: {miner_id}")
49 | return # Do nothing if it's a duplicate
50 | else:
51 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
52 |         save_hash_to_file(audio_hash, timestamp, miner_id)  # Save the hash to the file and cache (argument order matches the function signature)
53 | bt.logging.info(f"Music processed and saved successfully for miner: {miner_id}")
54 | return audio_hash
55 |
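56 | if __name__ == "__main__":
57 |     # Minimal usage sketch with a hypothetical miner id and payload: warm the
58 |     # cache from disk, then hash and record one clip (skipped if duplicate).
59 |     load_hashes_to_cache()
60 |     process_miner_music("miner-hotkey-placeholder", b"fake-audio-bytes")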
--------------------------------------------------------------------------------
/min_compute.yml:
--------------------------------------------------------------------------------
1 | # Use this document to specify the minimum compute requirements.
2 | # This document will be used to generate a list of recommended hardware for your subnet.
3 |
4 | # This is intended to give a rough estimate of the minimum requirements
5 | # so that the user can make an informed decision about whether or not
6 | # they want to run a miner or validator on their machine.
7 |
8 | # NOTE: Specification for miners may be different from validators
9 |
10 | version: '1.0' # update this version key as needed, ideally should match your release version
11 |
12 | compute_spec:
13 |
14 | miner:
15 |
16 | cpu:
17 | min_cores: 8 # Minimum number of CPU cores
18 | min_speed: 3.0GHz # Minimum speed per core
19 | architecture: x86_64 # Architecture type (e.g., x86_64, arm64)
20 |
21 | gpu:
22 | required: true # Does the application require a GPU?
23 | min_vram: 24GB # Minimum GPU VRAM
24 | cuda_cores: 1024 # Minimum number of CUDA cores (if applicable)
25 | min_compute_capability: 6.0 # Minimum CUDA compute capability
26 | recommended_gpu: "NVIDIA A6000" # provide a recommended GPU to purchase/rent
27 |
28 | memory:
29 | min_ram: 32GB # Minimum RAM
30 | min_swap: 4GB # Minimum swap space
31 | ram_type: "DDR4" # RAM type (e.g., DDR4, DDR3, etc.)
32 |
33 | storage:
34 | min_space: 100GB # Minimum free storage space
35 | type: SSD # Preferred storage type (e.g., SSD, HDD)
36 | iops: 1000 # Minimum I/O operations per second (if applicable)
37 |
38 | os:
39 | name: Ubuntu # Name of the preferred operating system(s)
40 | version: "20.04" # Version of the preferred operating system(s)
41 |
42 | validator:
43 |
44 | cpu:
45 | min_cores: 8 # Minimum number of CPU cores
46 | min_speed: 3.0GHz # Minimum speed per core
47 | architecture: x86_64 # Architecture type (e.g., x86_64, arm64)
48 |
49 | gpu:
50 | required: true # Does the application require a GPU?
51 | min_vram: 24GB # Minimum GPU VRAM
52 | cuda_cores: 1024 # Minimum number of CUDA cores (if applicable)
53 | min_compute_capability: 6.0 # Minimum CUDA compute capability
54 | recommended_gpu: "NVIDIA A6000" # provide a recommended GPU to purchase/rent
55 |
56 | memory:
57 | min_ram: 32GB # Minimum RAM
58 | min_swap: 4GB # Minimum swap space
59 | ram_type: "DDR4" # RAM type (e.g., DDR4, DDR3, etc.)
60 |
61 | storage:
62 | min_space: 100GB # Minimum free storage space
63 | type: SSD # Preferred storage type (e.g., SSD, HDD)
64 | iops: 1000 # Minimum I/O operations per second (if applicable)
65 |
66 | os:
67 | name: Ubuntu # Name of the preferred operating system(s)
68 | version: ">=20.04" # Version of the preferred operating system(s)
69 |
70 | network_spec:
71 | bandwidth:
72 | download: ">=200Mbps" # Minimum download bandwidth
73 | upload: ">=100Mbps" # Minimum upload bandwidth
74 |
--------------------------------------------------------------------------------
/neurons/__init__.py:
--------------------------------------------------------------------------------
1 | from . import validator
2 | from . import miner
--------------------------------------------------------------------------------
/neurons/miner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import lib
4 | import time
5 | import torch
6 | import typing
7 | import argparse
8 | import traceback
9 | import torchaudio
10 | import bittensor as bt
11 | import ttm.protocol as protocol
12 | # from ttm.protocol import MusicGeneration
13 | from scipy.io.wavfile import write as write_wav
14 | from transformers import AutoProcessor, MusicgenForConditionalGeneration
15 |
16 | # Set the project root path
17 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
18 | audio_subnet_path = os.path.abspath(project_root)
19 |
20 | # Add the project root and 'AudioSubnet' directories to sys.path
21 | sys.path.insert(0, project_root)
22 | sys.path.insert(0, audio_subnet_path)
23 |
24 | class MusicGenerator:
25 | def __init__(self, model_path="facebook/musicgen-medium"):
26 | """Initializes the MusicGenerator with a specified model path."""
27 | self.model_name = model_path
28 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29 |
30 | # Load the processor and model
31 | self.processor = AutoProcessor.from_pretrained(self.model_name)
32 | self.model = MusicgenForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
33 |
34 | def generate_music(self, prompt, token):
35 | """Generates music based on a given prompt and token count."""
36 | try:
37 | inputs = self.processor(text=[prompt], padding=True, return_tensors="pt").to(self.device)
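38 |             # `token` arrives as synapse.duration and is passed straight to
39 |             # max_new_tokens; MusicGen emits roughly 50 audio tokens per second.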
38 | audio_values = self.model.generate(**inputs, max_new_tokens=token)
39 | return audio_values[0, 0].cpu().numpy()
40 | except Exception as e:
41 | print(f"Error occurred with {self.model_name}: {e}")
42 | return None
43 |
44 | # Configuration setup
45 | def get_config():
46 | parser = argparse.ArgumentParser()
47 |
48 | # Add model selection for Text-to-Music
49 | parser.add_argument("--music_model", default='facebook/musicgen-medium', help="The model to be used for Music Generation.")
50 | parser.add_argument("--music_path", default=None, help="Path to a custom finetuned model for Music Generation.")
51 |
52 | # Add Bittensor specific arguments
53 | parser.add_argument("--netuid", type=int, default=50, help="The chain subnet uid.")
54 | bt.subtensor.add_args(parser)
55 | bt.logging.add_args(parser)
56 | bt.wallet.add_args(parser)
57 | bt.axon.add_args(parser)
58 |
59 | config = bt.config(parser)
60 |
61 | # Set up logging paths
62 | config.full_path = os.path.expanduser(f"{config.logging.logging_dir}/{config.wallet.name}/{config.wallet.hotkey}/netuid{config.netuid}/miner")
63 |
64 | # Ensure the logging directory exists
65 | if not os.path.exists(config.full_path):
66 | os.makedirs(config.full_path, exist_ok=True)
67 |
68 | return config
69 |
70 | # Main function
71 | def main(config):
72 | bt.logging(config=config, logging_dir=config.full_path)
73 | bt.logging.info(f"Running TTM miner for subnet: {config.netuid} on network: {config.subtensor.chain_endpoint}")
74 |
75 | # Text-to-Music Model Setup
76 | try:
77 | if config.music_path:
78 | bt.logging.info(f"Using custom model for Text-To-Music from: {config.music_path}")
79 | ttm_models = MusicGenerator(model_path=config.music_path)
80 | elif config.music_model in ["facebook/musicgen-medium", "facebook/musicgen-large"]:
81 | bt.logging.info(f"Using Text-To-Music model: {config.music_model}")
82 | ttm_models = MusicGenerator(model_path=config.music_model)
83 | else:
84 | bt.logging.error(f"Invalid music model: {config.music_model}")
85 | exit(1)
86 | except Exception as e:
87 | bt.logging.error(f"Error initializing Text-To-Music model: {e}")
88 | exit(1)
89 |
90 | # Bittensor object setup
91 | wallet = bt.wallet(config=config)
92 | bt.logging.info(f"Wallet: {wallet}")
93 |
94 | subtensor = bt.subtensor(config=config)
95 | bt.logging.info(f"Subtensor: {subtensor}")
96 |
97 | metagraph = subtensor.metagraph(config.netuid)
98 | bt.logging.info(f"Metagraph: {metagraph}")
99 |
100 | if wallet.hotkey.ss58_address not in metagraph.hotkeys:
101 | bt.logging.error("Miner not registered. Run btcli register and try again.")
102 | exit()
103 |
104 | # Check the miner's subnet UID
105 | my_subnet_uid = metagraph.hotkeys.index(wallet.hotkey.ss58_address)
106 |
107 | ######################## Text to Music Processing ########################
108 |
109 | def music_blacklist_fn(synapse: protocol.MusicGeneration) -> typing.Tuple[bool, str]:
110 | if synapse.dendrite.hotkey not in metagraph.hotkeys:
111 | bt.logging.trace(f"Blacklisting unrecognized hotkey {synapse.dendrite.hotkey}")
112 | return True, "Unrecognized hotkey"
113 | elif synapse.dendrite.hotkey in metagraph.hotkeys and metagraph.S[metagraph.hotkeys.index(synapse.dendrite.hotkey)] < lib.MIN_STAKE:
114 | # Ignore requests from entities with low stake.
115 | bt.logging.trace(
116 | f"Blacklisting hotkey {synapse.dendrite.hotkey} with low stake"
117 | )
118 | return True, "Low stake"
119 | else:
120 | return False, "Accepted"
121 |
122 | # The priority function determines the request handling order.
123 | def music_priority_fn(synapse: protocol.MusicGeneration) -> float:
124 | caller_uid = metagraph.hotkeys.index(synapse.dendrite.hotkey)
125 | priority = float(metagraph.S[caller_uid])
126 | bt.logging.trace(f"Prioritizing {synapse.dendrite.hotkey} with stake: {priority}")
127 | return priority
128 |
129 | def convert_music_to_tensor(audio_file):
130 | """Convert the audio file to a tensor."""
131 | try:
132 | _, file_extension = os.path.splitext(audio_file)
133 | if file_extension.lower() in ['.wav', '.mp3']:
134 | audio, sample_rate = torchaudio.load(audio_file)
135 | return audio[0].tolist() # Convert to tensor/list
136 | else:
137 | bt.logging.error(f"Unsupported file format: {file_extension}")
138 | return None
139 | except Exception as e:
140 | bt.logging.error(f"Error converting file: {e}")
141 |
142 | def ProcessMusic(synapse: protocol.MusicGeneration) -> protocol.MusicGeneration:
143 | bt.logging.info(f"Generating music with model: {config.music_path if config.music_path else config.music_model}")
144 | print(f"synapse.text_input: {synapse.text_input}")
145 | print(f"synapse.duration: {synapse.duration}")
146 | music = ttm_models.generate_music(synapse.text_input, synapse.duration)
147 |
148 | if music is None:
149 | bt.logging.error("No music generated!")
150 | return None
151 | try:
152 | sampling_rate = 32000
153 | write_wav("random_sample.wav", rate=sampling_rate, data=music)
154 | bt.logging.success("Music generated and saved to random_sample.wav")
155 | music_tensor = convert_music_to_tensor("random_sample.wav")
156 | synapse.music_output = music_tensor
157 | return synapse
158 | except Exception as e:
159 | bt.logging.error(f"Error processing music output: {e}")
160 | return None
161 |
162 | ######################## Attach Axon and Serve ########################
163 |
164 | axon = bt.axon(wallet=wallet, config=config)
165 | bt.logging.info(f"Axon {axon}")
166 |
167 | # Attach forward function for TTM processing
168 | axon.attach(
169 | forward_fn=ProcessMusic,
170 | blacklist_fn=music_blacklist_fn,
171 | priority_fn=music_priority_fn,
172 | )
173 |
174 | # Serve the axon on the network
175 | bt.logging.info(f"Serving axon on network: {config.subtensor.chain_endpoint} with netuid: {config.netuid}")
176 | axon.serve(netuid=config.netuid, subtensor=subtensor)
177 |
178 | # Start the miner's axon
179 | bt.logging.info(f"Starting axon server on port: {config.axon.port}")
180 | axon.start()
181 |
182 | # Keep the miner running
183 | bt.logging.info("Starting main loop")
184 | step = 0
185 | while True:
186 | try:
187 | # Periodically update knowledge of the network graph
188 | if step % 500 == 0:
189 | metagraph = subtensor.metagraph(config.netuid)
190 | log = (
191 | f"Step:{step} | "
192 | f"Block:{metagraph.block.item()} | "
193 | f"Stake:{metagraph.S[my_subnet_uid]:.6f} | "
194 | f"Rank:{metagraph.R[my_subnet_uid]:.6f} | "
195 | f"Trust:{metagraph.T[my_subnet_uid]} | "
196 | f"Consensus:{metagraph.C[my_subnet_uid]:.6f} | "
197 | f"Incentive:{metagraph.I[my_subnet_uid]:.6f} | "
198 | f"Emission:{metagraph.E[my_subnet_uid]}"
199 | )
200 | bt.logging.info(log)
201 | step += 1
202 | time.sleep(1)
203 |
204 | # Stop the miner safely
205 | except KeyboardInterrupt:
206 | axon.stop()
207 | break
208 |
209 | # Log any unexpected errors
210 | except Exception as e:
211 |             bt.logging.error(f"Unexpected error: {traceback.format_exc()}")
212 | continue
213 |
214 | # Entry point
215 | if __name__ == "__main__":
216 | config = get_config()
217 | main(config)
--------------------------------------------------------------------------------
/neurons/validator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import asyncio
4 | import datetime as dt
5 | import wandb
6 | import bittensor as bt
7 | import uvicorn
8 | from pyngrok import ngrok # Import ngrok from pyngrok
9 |
10 | # Set the project root path
11 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
12 | # Set the 'AudioSubnet' directory path
13 | audio_subnet_path = os.path.abspath(project_root)
14 |
15 | # Add the project root and 'AudioSubnet' directories to sys.path
16 | sys.path.insert(0, project_root)
17 | sys.path.insert(0, audio_subnet_path)
18 |
19 | from ttm.ttm import MusicGenerationService
20 | from ttm.aimodel import AIModelService
21 |
22 |
23 | class AIModelController():
24 | def __init__(self):
25 | self.aimodel = AIModelService()
26 | self.music_generation_service = MusicGenerationService()
27 | self.current_service = self.music_generation_service
28 |         self.last_run_start_time = dt.datetime.now()
29 |         self.wandb_run = None  # checked in check_and_update_wandb_run() before wandb.finish()
29 |
30 | async def run_fastapi_with_ngrok(self, app):
31 | # Setup ngrok tunnel
32 | ngrok_tunnel = ngrok.connect(38287, bind_tls=True)
33 | print('Public URL:', ngrok_tunnel.public_url)
34 | # Create and start the uvicorn server as a background task
35 | config = uvicorn.Config(app=app, host="0.0.0.0", port=38287) # Ensure port matches ngrok's
36 | server = uvicorn.Server(config)
37 | # No need to await here, as we want this to run in the background
38 | task = asyncio.create_task(server.serve())
39 | return ngrok_tunnel, task # Returning task if you need to cancel it later
40 |
41 |
42 | async def run_services(self):
43 | while True:
44 | self.check_and_update_wandb_run()
45 | await self.music_generation_service.run_async()
46 |
47 | def check_and_update_wandb_run(self):
48 | # Calculate the time difference between now and the last run start time
49 | current_time = dt.datetime.now()
50 | time_diff = current_time - self.last_run_start_time
51 | # Check if 4 hours have passed since the last run start time
52 | if time_diff.total_seconds() >= 4 * 3600: # 4 hours * 3600 seconds/hour
53 | self.last_run_start_time = current_time # Update the last run start time to now
54 | if self.wandb_run:
55 | wandb.finish() # End the current run
56 | self.new_wandb_run() # Start a new run
57 |
58 | def new_wandb_run(self):
59 | now = dt.datetime.now()
60 | run_id = now.strftime("%Y-%m-%d_%H-%M-%S")
61 | name = f"Validator-{self.aimodel.uid}-{run_id}"
62 | commit = self.aimodel.get_git_commit_hash()
63 | self.wandb_run = wandb.init(
64 | name=name,
65 | project="AudioSubnet_Valid",
66 | entity="subnet16team",
67 | config={
68 | "uid": self.aimodel.uid,
69 | "hotkey": self.aimodel.wallet.hotkey.ss58_address,
70 | "run_name": run_id,
71 | "type": "Validator",
72 | "tao (stake)": self.aimodel.metagraph.neurons[self.aimodel.uid].stake.tao,
73 | "commit": commit,
74 | },
75 | tags=self.aimodel.sys_info,
76 | allow_val_change=True,
77 | anonymous="allow",
78 | )
79 | bt.logging.debug(f"Started a new wandb run: {name}")
80 |
81 | async def setup_and_run(controller):
82 | tasks = []
83 | secret_key = os.getenv("AUTH_SECRET_KEY")
84 | if os.path.exists(os.path.join(project_root, 'app')) and secret_key:
85 |         app = create_app(secret_key)  # NOTE: create_app is expected from the optional 'app' package checked above; it is not defined in this repository snapshot
86 | # Start FastAPI with ngrok without blocking
87 | ngrok_tunnel, server_task = await controller.run_fastapi_with_ngrok(app)
88 | tasks.append(server_task) # Keep track of the server task if you need to cancel it later
89 |
90 | # Start service-related tasks
91 | service_task = asyncio.create_task(controller.run_services())
92 | tasks.append(service_task)
93 |
94 | # Wait for all tasks to complete
95 | await asyncio.gather(*tasks)
96 |
97 | # Cleanup, in case you need to close ngrok or other resources
98 | ngrok_tunnel.close()
99 |
100 | async def main():
101 | controller = AIModelController()
102 | await setup_and_run(controller)
103 |
104 | if __name__ == "__main__":
105 | asyncio.run(main())
106 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bittensor==7.4.0
2 | psutil
3 | GPUtil
4 | inflect
11 | datasets==3.0.1
12 | wandb==0.18.5
13 | pyngrok==7.2.0
14 | tabulate==0.9.0
15 | resampy==0.4.3
16 | aiofiles==23.2.1
17 | altair==5.4.1
18 | annotated-types==0.7.0
19 | antlr4-python3-runtime==4.9.3
20 | anyio==4.6.0
21 | attrs==24.2.0
22 | audioread==3.0.1
23 | av==11.0.0
24 | blis==0.7.11
25 | catalogue==2.0.10
26 | certifi==2024.7.4
27 | cffi==1.17.1
28 | charset-normalizer==3.3.2
29 | click==8.1.7
30 | cloudpathlib==0.19.0
31 | cloudpickle==3.0.0
32 | colorlog==6.8.2
33 | confection==0.1.5
34 | contourpy==1.2.1
35 | cycler==0.12.1
36 | cymem==2.0.8
37 | decorator==5.1.1
38 | demucs==4.0.1
39 | docopt==0.6.2
40 | dora_search==0.1.12
41 | einops==0.8.0
42 | encodec==0.1.1
43 | exceptiongroup==1.2.2
44 | fastapi==0.110.1
45 | ffmpy==0.4.0
46 | filelock==3.16.1
47 | flashy==0.0.2
48 | fonttools==4.54.1
49 | fsspec
50 | gradio==4.43.0
51 | gradio_client==1.3.0
52 | h11==0.14.0
53 | httpcore==1.0.6
54 | httpx==0.27.2
55 | huggingface-hub==0.25.1
56 | hydra-colorlog==1.2.0
57 | hydra-core==1.3.2
58 | idna==3.10
59 | importlib_resources==6.4.5
60 | Jinja2==3.1.4
61 | joblib==1.4.2
62 | jsonschema==4.23.0
63 | jsonschema-specifications==2023.12.1
64 | julius==0.2.7
65 | kiwisolver==1.4.7
66 | lameenc==1.7.0
67 | langcodes==3.4.1
68 | language_data==1.2.0
69 | lazy_loader==0.4
70 | librosa==0.10.0
71 | lightning-utilities==0.11.7
72 | llvmlite==0.42.0
73 | marisa-trie==1.2.0
74 | markdown-it-py==3.0.0
75 | MarkupSafe==2.1.5
76 | matplotlib==3.8.3
77 | mdurl==0.1.2
78 | mpmath==1.3.0
79 | msgpack==1.1.0
80 | murmurhash==1.0.10
81 | narwhals==1.9.0
82 | networkx==3.2.1
83 | num2words==0.5.13
84 | numba==0.59.1
85 | numpy==1.26.4
86 | nvidia-cublas-cu12==12.1.3.1
87 | nvidia-cuda-cupti-cu12==12.1.105
88 | nvidia-cuda-nvrtc-cu12==12.1.105
89 | nvidia-cuda-runtime-cu12==12.1.105
90 | nvidia-cudnn-cu12==8.9.2.26
91 | nvidia-cufft-cu12==11.0.2.54
92 | nvidia-curand-cu12==10.3.2.106
93 | nvidia-cusolver-cu12==11.4.5.107
94 | nvidia-cusparse-cu12==12.1.0.106
95 | nvidia-nccl-cu12==2.18.1
96 | nvidia-nvjitlink-cu12==12.6.77
97 | nvidia-nvtx-cu12==12.1.105
98 | omegaconf==2.3.0
99 | openunmix==1.3.0
100 | orjson==3.10.7
101 | packaging==24.1
102 | pandas==2.2.1
103 | pesq==0.0.4
104 | pillow==10.4.0
105 | pip==23.0.1
106 | platformdirs==4.3.6
107 | pooch==1.8.2
108 | preshed==3.0.9
109 | protobuf==5.28.2
110 | pycparser==2.22
111 | pydantic==2.9.2
112 | pydantic_core==2.23.4
113 | pydub==0.25.1
114 | Pygments==2.18.0
115 | pyparsing==3.1.4
116 | pystoi==0.4.1
117 | python-dateutil==2.9.0.post0
118 | python-multipart==0.0.12
119 | pytz==2024.2
120 | PyYAML==6.0.2
121 | referencing==0.35.1
122 | regex==2024.9.11
123 | requests==2.32.3
124 | retrying==1.3.4
125 | rich==13.9.2
126 | rpds-py==0.20.0
127 | ruff==0.6.9
128 | safetensors==0.4.5
129 | scikit-learn==1.4.2
130 | scipy==1.13.0
131 | semantic-version==2.10.0
132 | sentencepiece==0.2.0
133 | setuptools==70.0.0
134 | shellingham==1.5.4
135 | six==1.16.0
136 | smart-open==7.0.5
137 | sniffio==1.3.1
138 | soundfile==0.12.1
139 | soxr==0.5.0.post1
140 | spacy==3.7.5
141 | spacy-legacy==3.0.12
142 | spacy-loggers==1.0.5
143 | srsly==2.4.8
144 | starlette==0.37.2
145 | submitit==1.5.2
146 | sympy==1.13.3
147 | thinc==8.2.5
148 | threadpoolctl==3.5.0
149 | tokenizers==0.15.2
150 | tomlkit==0.12.0
151 | torch==2.1.0
152 | torchaudio==2.1.0
153 | torchdata==0.7.0
154 | torchmetrics==1.3.2
155 | torchtext==0.16.0
156 | torchvision==0.16.0
157 | tqdm==4.66.5
158 | transformers==4.38.2
159 | treetable==0.2.5
160 | triton
161 | typer==0.12.5
162 | typing_extensions==4.12.2
163 | tzdata==2024.2
164 | urllib3==2.2.3
165 | uvicorn==0.31.0
166 | wasabi==1.1.3
167 | weasel==0.4.1
168 | websockets==11.0.3
169 | wrapt==1.16.0
170 | xformers==0.0.22.post7
171 | zipp==3.20.2
--------------------------------------------------------------------------------
/scripts/check_compatibility.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -z "$1" ]; then
4 | echo "Please provide a Python version as an argument."
5 | exit 1
6 | fi
7 |
8 | python_version="$1"
9 | all_passed=true
10 |
11 | GREEN='\033[0;32m'
12 | YELLOW='\033[0;33m'
13 | RED='\033[0;31m'
14 | NC='\033[0m' # No Color
15 |
16 | check_compatibility() {
17 | all_supported=0
18 |
19 | while read -r requirement; do
20 | # Skip lines starting with git+
21 | if [[ "$requirement" == git+* ]]; then
22 | continue
23 | fi
24 |
25 | package_name=$(echo "$requirement" | awk -F'[!=<>]' '{print $1}' | awk -F'[' '{print $1}') # Strip off brackets
26 | echo -n "Checking $package_name... "
27 |
28 | url="https://pypi.org/pypi/$package_name/json"
29 | response=$(curl -s $url)
30 | status_code=$(curl -s -o /dev/null -w "%{http_code}" $url)
31 |
32 | if [ "$status_code" != "200" ]; then
33 | echo -e "${RED}Information not available for $package_name. Failure.${NC}"
34 | all_supported=1
35 | continue
36 | fi
37 |
38 | classifiers=$(echo "$response" | jq -r '.info.classifiers[]')
39 | requires_python=$(echo "$response" | jq -r '.info.requires_python')
40 |
41 | base_version="Programming Language :: Python :: ${python_version%%.*}"
42 | specific_version="Programming Language :: Python :: $python_version"
43 |
44 | if echo "$classifiers" | grep -q "$specific_version" || echo "$classifiers" | grep -q "$base_version"; then
45 | echo -e "${GREEN}Supported${NC}"
46 | elif [ "$requires_python" != "null" ]; then
47 | if echo "$requires_python" | grep -Eq "==$python_version|>=$python_version|<=$python_version"; then
48 | echo -e "${GREEN}Supported${NC}"
49 | else
50 | echo -e "${RED}Not compatible with Python $python_version due to constraint $requires_python.${NC}"
51 | all_supported=1
52 | fi
53 | else
54 | echo -e "${YELLOW}Warning: Specific version not listed, assuming compatibility${NC}"
55 | fi
56 | done < requirements.txt
57 |
58 | return $all_supported
59 | }
60 |
61 | echo "Checking compatibility for Python $python_version..."
62 | check_compatibility
63 | if [ $? -eq 0 ]; then
64 | echo -e "${GREEN}All requirements are compatible with Python $python_version.${NC}"
65 | else
66 |     echo -e "${RED}Some requirements are NOT compatible with Python $python_version.${NC}"
67 | all_passed=false
68 | fi
69 |
70 | echo ""
71 | if $all_passed; then
72 | echo -e "${GREEN}All tests passed.${NC}"
73 | else
74 |     echo -e "${RED}Some tests failed.${NC}"
75 | exit 1
76 | fi
77 |
--------------------------------------------------------------------------------
/scripts/check_requirements_changes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check if requirements files have changed in the last commit
4 | if git diff --name-only HEAD~1 | grep -E 'requirements.txt'; then
5 | echo "Requirements files have changed. Running compatibility checks..."
6 | echo 'export REQUIREMENTS_CHANGED="true"' >> $BASH_ENV
7 | else
8 | echo "Requirements files have not changed. Skipping compatibility checks..."
9 | echo 'export REQUIREMENTS_CHANGED="false"' >> $BASH_ENV
10 | fi
11 |
--------------------------------------------------------------------------------
/scripts/start_valid.py:
--------------------------------------------------------------------------------
1 | """
2 | This script runs a validator process and automatically updates it when a new version is released.
3 | Command-line arguments will be forwarded to the validator (`neurons/validator.py`), so you can pass
4 | them like this:
5 |     python3 scripts/start_valid.py --wallet.name=my-wallet
6 | Auto-updates are enabled by default and will make sure that the latest version is always running
7 | by pulling the latest version from git and upgrading python packages. This is done periodically.
8 | Local changes may prevent the update, but they will be preserved.
9 |
10 | The script will use the same virtual environment as the one used to run it. If you want to run
11 | the validator within a virtual environment, run this auto-update script from that environment.
12 |
13 | Pm2 is required for this script. This script will start a pm2 process using the name provided by
14 | the --pm2_name argument.
15 | """
16 | import argparse
17 | import logging
18 | import subprocess
19 | import sys
20 | import time
21 | from datetime import timedelta
22 | from shlex import split
23 | from pathlib import Path
24 | from typing import List
25 |
26 | log = logging.getLogger(__name__)
27 | UPDATES_CHECK_TIME = timedelta(minutes=5)
28 | ROOT_DIR = Path(__file__).parent.parent
29 |
30 |
31 |
32 | def get_version() -> str:
33 | """Extract the version as current git commit hash"""
34 | result = subprocess.run(
35 | split("git rev-parse HEAD"),
36 | check=True,
37 | capture_output=True,
38 | cwd=ROOT_DIR,
39 | )
40 | commit = result.stdout.decode().strip()
41 | assert len(commit) == 40, f"Invalid commit hash: {commit}"
42 | return commit[:8]
43 |
44 |
45 | def start_validator_process(pm2_name: str, args: List[str]) -> subprocess.Popen:
46 | """
47 | Spawn a new python process running neurons.validator.
48 | `sys.executable` ensures the same python interpreter is used as the one
49 | used to run this auto-updater.
50 | """
51 | assert sys.executable, "Failed to get python executable"
52 |
53 | log.info("Starting validator process with pm2, name: %s", pm2_name)
54 | process = subprocess.Popen(
55 | (
56 | "pm2",
57 | "start",
58 | sys.executable,
59 | "--name",
60 | pm2_name,
61 | "--",
62 | "-m",
63 | "neurons.validator",
64 | *args, # Added to include additional arguments
65 | ),
66 | cwd=ROOT_DIR,
67 | )
68 | process.pm2_name = pm2_name
69 |
70 | return process
71 |
72 |
73 |
74 | def stop_validator_process(process: subprocess.Popen) -> None:
75 | """Stop the validator process"""
76 | subprocess.run(
77 | ("pm2", "delete", process.pm2_name), cwd=ROOT_DIR, check=True
78 | )
79 |
80 |
81 | def pull_latest_version() -> None:
82 | """
83 | Pull the latest version from git.
84 | This uses `git pull --rebase`, so if any changes were made to the local repository,
85 | this will try to apply them on top of origin's changes. This is intentional, as we
86 | don't want to overwrite any local changes. However, if there are any conflicts,
87 | this will abort the rebase and return to the original state.
88 | The conflicts are expected to happen rarely since validator is expected
89 | to be used as-is.
90 | """
91 | try:
92 | subprocess.run(
93 | split("git pull --rebase --autostash"), check=True, cwd=ROOT_DIR
94 | )
95 | except subprocess.CalledProcessError as exc:
96 | log.error("Failed to pull, reverting: %s", exc)
97 | subprocess.run(split("git rebase --abort"), check=True, cwd=ROOT_DIR)
98 |
99 |
100 | def upgrade_packages() -> None:
101 | """
102 | Upgrade python packages by running `pip install --upgrade -r requirements.txt`.
103 | Notice: this won't work if some package in `requirements.txt` is downgraded.
104 | Ignored as this is unlikely to happen.
105 | """
106 |
107 | log.info("Upgrading packages")
108 | try:
109 | subprocess.run(
110 | split(f"{sys.executable} -m pip install -e ."),
111 | check=True,
112 | cwd=ROOT_DIR,
113 | )
114 | except subprocess.CalledProcessError as exc:
115 | log.error("Failed to upgrade packages, proceeding anyway. %s", exc)
116 |
117 |
118 | def main(pm2_name: str, args: List[str]) -> None:
119 | """
120 | Run the validator process and automatically update it when a new version is released.
121 | This will check for updates every `UPDATES_CHECK_TIME` and update the validator
122 | if a new version is available. Update is performed as simple `git pull --rebase`.
123 | """
124 |
125 | validator = start_validator_process(pm2_name, args)
126 | current_version = latest_version = get_version()
127 | log.info("Current version: %s", current_version)
128 |
129 | try:
130 | while True:
131 | pull_latest_version()
132 | latest_version = get_version()
133 | log.info("Latest version: %s", latest_version)
134 |
135 | if latest_version != current_version:
136 | log.info(
137 | "Upgraded to latest version: %s -> %s",
138 | current_version,
139 | latest_version,
140 | )
141 | upgrade_packages()
142 |
143 | stop_validator_process(validator)
144 | validator = start_validator_process(pm2_name, args)
145 | current_version = latest_version
146 |
147 | time.sleep(UPDATES_CHECK_TIME.total_seconds())
148 |
149 | finally:
150 | stop_validator_process(validator)
151 |
152 |
153 | if __name__ == "__main__":
154 | logging.basicConfig(
155 | level=logging.INFO,
156 | format="%(asctime)s %(levelname)s %(message)s",
157 | handlers=[logging.StreamHandler(sys.stdout)],
158 | )
159 |
160 | parser = argparse.ArgumentParser(
161 | description="Automatically update and restart the validator process when a new version is released.",
162 |         epilog="Example usage: python scripts/start_valid.py --pm2_name 'validator' --wallet.name 'wallet1' --wallet.hotkey 'key123'",
163 | )
164 |
165 | parser.add_argument(
166 | "--pm2_name", default="validator", help="Name of the PM2 process."
167 | )
168 |
169 | flags, extra_args = parser.parse_known_args()
170 |
171 | main(flags.pm2_name, extra_args)
172 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | # Copyright © 2023 Yuma Rao
3 | # Copyright © 2023 Omega Labs, Inc.
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation
7 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
8 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
11 | # the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
14 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
16 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
17 | # DEALINGS IN THE SOFTWARE.
18 |
19 | import re
20 | import os
21 | import codecs
22 | from os import path
23 | from io import open
24 | from setuptools import setup, find_packages
25 |
26 |
27 | def read_requirements(path):
28 | with open(path, "r") as f:
29 | requirements = f.read().splitlines()
30 | return requirements
31 |
32 |
33 | requirements = read_requirements("requirements.txt")
34 | here = path.abspath(path.dirname(__file__))
35 |
36 | with open(path.join(here, "README.md"), encoding="utf-8") as f:
37 | long_description = f.read()
38 |
39 | # loading version from lib/__init__.py
40 | with codecs.open(
41 | os.path.join(here, "lib/__init__.py"), encoding="utf-8"
42 | ) as init_file:
43 | version_match = re.search(
44 | r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M
45 | )
46 | version_string = version_match.group(1)
47 |
48 | setup(
49 | name="ttm_bittensor_subnet",
50 | version=version_string,
51 | description="ttm_bittensor_subnet",
52 | long_description=long_description,
53 | long_description_content_type="text/markdown",
54 | url="https://github.com/UncleTensor/BittAudio.git",
55 | author="ttm",
56 | packages=find_packages(),
57 | include_package_data=True,
58 | author_email="",
59 | license="MIT",
60 | python_requires=">=3.10",
61 | install_requires=requirements,
62 | classifiers=[
63 | "Development Status :: 3 - Alpha",
64 | "Intended Audience :: Developers",
65 | "Topic :: Software Development :: Build Tools",
66 | "License :: OSI Approved :: MIT License",
67 | "Programming Language :: Python :: 3 :: Only",
70 | "Programming Language :: Python :: 3.10",
71 | "Topic :: Scientific/Engineering",
72 | "Topic :: Scientific/Engineering :: Mathematics",
73 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
74 | "Topic :: Software Development",
75 | "Topic :: Software Development :: Libraries",
76 | "Topic :: Software Development :: Libraries :: Python Modules",
77 | ],
78 | )
--------------------------------------------------------------------------------
/ttm/aimodel.py:
--------------------------------------------------------------------------------
1 | import bittensor as bt
2 | import pandas as pd
3 | import subprocess
4 | import platform
5 | import argparse
6 | import inflect
7 | import psutil
8 | import GPUtil
9 | import sys
10 | import os
11 | import re
12 | from lib import __spec_version__ as spec_version
13 |
14 |
15 | class AIModelService:
16 | _scores = None
17 | _base_initialized = False # Class-level flag for one-time initialization
18 | version: int = spec_version # Adjust version as necessary
19 |
20 | def __init__(self):
21 | self.config = self.get_config()
22 | self.sys_info = self.get_system_info()
23 | self.setup_paths()
24 | self.setup_logging()
25 | self.wallet = bt.wallet(config=self.config)
26 | self.subtensor = bt.subtensor(config=self.config)
27 | self.dendrite = bt.dendrite(wallet=self.wallet)
28 | self.metagraph = self.subtensor.metagraph(self.config.netuid)
29 | self.p = inflect.engine()
30 |
31 | if not AIModelService._base_initialized:
32 | bt.logging.info(f"Wallet: {self.wallet}")
33 | bt.logging.info(f"Subtensor: {self.subtensor}")
34 | bt.logging.info(f"Dendrite: {self.dendrite}")
35 | bt.logging.info(f"Metagraph: {self.metagraph}")
36 | AIModelService._base_initialized = True
37 |
38 | if AIModelService._scores is None:
39 | AIModelService._scores = self.metagraph.E.copy()
40 | self.scores = AIModelService._scores
41 | self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address)
42 |
43 | def get_config(self):
44 | parser = argparse.ArgumentParser()
45 |
46 | parser.add_argument("--alpha", default=0.1, type=float, help="The weight moving average scoring.")
47 | parser.add_argument("--custom", default="my_custom_value", help="Adds a custom value to the parser.")
48 | parser.add_argument("--subtensor.network", type=str, help="The logging directory.")
49 | parser.add_argument("--netuid", default=50, type=int, help="The chain subnet uid.")
50 | parser.add_argument("--wallet.name", type=str, help="The wallet name.")
51 | parser.add_argument("--wallet.hotkey", type=str, help="The wallet hotkey.")
52 |
53 | # Add Bittensor specific arguments
54 | bt.subtensor.add_args(parser)
55 | bt.logging.add_args(parser)
56 | bt.wallet.add_args(parser)
57 |
58 | # Parse and return the config
59 | config = bt.config(parser)
60 | return config
61 |
62 | def priority_uids(self, metagraph):
63 | hotkeys = metagraph.hotkeys # List of hotkeys
64 | coldkeys = metagraph.coldkeys # List of coldkeys
65 | UIDs = range(len(hotkeys)) # Assuming UID is the index of neurons
66 | stakes = metagraph.S.numpy() # Total stake
67 | emissions = metagraph.E.numpy() # Emission
68 |
69 | # Create a DataFrame from the metagraph data
70 | df = pd.DataFrame({
71 | "UID": UIDs,
72 | "HOTKEY": hotkeys,
73 | "COLDKEY": coldkeys,
74 | "STAKE": stakes,
75 | "EMISSION": emissions,
76 | "AXON": metagraph.axons,
77 | })
78 |
79 | # Filter and sort the DataFrame
80 | df = df[df['STAKE'] < 500]
81 | df = df.sort_values(by=["EMISSION"], ascending=False)
82 | uid = df.iloc[0]['UID']
83 | axon_info = df.iloc[0]['AXON']
84 |
85 | result = [(uid, axon_info)]
86 | return result
87 |
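# Worked example for priority_uids (made-up values): with stakes [900, 120, 40]
# and emissions [0.5, 0.2, 0.7] for UIDs [0, 1, 2], UID 0 is filtered out
# (STAKE >= 500) and UID 2 has the highest emission, so the method returns
# [(2, axon_info_of_uid_2)].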
88 | def get_system_info(self):
89 | system_info = {
90 | "OS Version": platform.platform(),
91 | "CPU Count": os.cpu_count(),
92 | "RAM": f"{psutil.virtual_memory().total / (1024**3):.2f} GB",
93 | }
94 |
95 | gpus = GPUtil.getGPUs()
96 | if gpus:
97 | system_info["GPU"] = gpus[0].name
98 |
99 | # Convert dictionary to list of strings for logging purposes
100 | tags = [f"{key}: {value}" for key, value in system_info.items()]
101 | return tags
102 |
103 | def setup_paths(self):
104 | # Set the project root path
105 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
106 |
107 | # Add project root to sys.path
108 | sys.path.insert(0, project_root)
109 |
110 | def convert_numeric_values(self, input_prompt):
111 | # Regular expression to identify date patterns
112 | date_pattern = r'(\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b)|(\b\d{2,4}[/-]\d{1,2}[/-]\d{1,2}\b)'
113 |
114 | # Regular expression to find ordinal numbers
115 | ordinal_pattern = r'\b\d{1,2}(st|nd|rd|th)\b'
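# e.g. the ordinal pattern matches "1st", "22nd", "3rd" and "15th"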
116 |
117 | # Regular expression to find numeric values with and without commas, excluding those part of date patterns and ordinals
118 | numeric_pattern = r'\b(?
--------------------------------------------------------------------------------
/ttm/protocol.py:
--------------------------------------------------------------------------------
41 | def deserialize(self) -> List:
42 | """
43 | Processes and returns the music_output into a format ready for audio rendering or further analysis.
44 | """
45 | return self
46 |
--------------------------------------------------------------------------------
/ttm/ttm.py:
--------------------------------------------------------------------------------
1 | from lib.hashing import load_hashes_to_cache, check_duplicate_music, save_hash_to_file
2 | from ttm.ttm_score import MusicQualityEvaluator
3 | from ttm.protocol import MusicGeneration
4 | from ttm.aimodel import AIModelService
5 | from datasets import load_dataset
6 | from datetime import datetime
7 | from tabulate import tabulate
8 | import bittensor as bt
9 | import soundfile as sf
10 | import numpy as np
11 | import torchaudio
12 | import contextlib
13 | import traceback
14 | import asyncio
15 | import hashlib
16 | import random
17 | import torch
18 | import wandb
19 | import wave
20 | import lib
21 | import sys
22 | import os
23 | import re
24 |
25 |
26 | # Set the project root path
27 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
28 | audio_subnet_path = os.path.abspath(project_root)
29 | sys.path.insert(0, project_root)
30 | sys.path.insert(0, audio_subnet_path)
31 | reference_dir = "audio_files" # Directory to save the reference audio files
32 | class MusicGenerationService(AIModelService):
33 | def __init__(self):
34 | super().__init__()
35 | self.load_prompts()
36 | self.total_dendrites_per_query = 10
37 | self.minimum_dendrites_per_query = 3
38 | self.current_block = self.subtensor.block
39 | self.last_updated_block = self.current_block - (self.current_block % 100)
40 | self.last_reset_weights_block = self.current_block
41 | self.filtered_axon = []
42 | self.combinations = []
43 | self.duration = None
44 | self.lock = asyncio.Lock()
45 | self.audio_path = None
46 | # Load hashes from file to cache at startup
47 | load_hashes_to_cache()
48 |
49 | def load_prompts(self):
50 | # Load the dataset (you can change this to any other dataset name)
51 | dataset = load_dataset("etechgrid/ttm-validation-dataset", split="train") # Adjust the split if needed (train, test, etc.)
52 | random_index = random.randint(0, len(dataset) - 1)
53 | self.random_sample = dataset[random_index]
54 | # Checking if the prompt exists in the dataset
55 | if 'Prompts' in self.random_sample:
56 | prompt = self.random_sample['Prompts']
57 | bt.logging.info(f"Returning the prompt: {prompt}")
58 | else:
59 | bt.logging.error("'Prompts' key not found in the sample.")
60 | return None # Return None if no prompt is found
61 |
62 | # Check if audio data exists and save it
63 | if 'File_Path' in self.random_sample and isinstance(self.random_sample['File_Path'], dict):
64 | file_path = self.random_sample['File_Path']
65 | if 'array' in file_path and 'sampling_rate' in file_path:
66 | audio_array = file_path['array']
67 | sample_rate = file_path['sampling_rate']
68 |
69 | # Save the audio to a file
70 | os.makedirs(reference_dir, exist_ok=True) # Create the output directory if it doesn't exist
71 | audio_path = os.path.join(reference_dir, "random_sample.wav")
72 |
73 | try:
74 | # Save the audio data using soundfile
75 | sf.write(audio_path, audio_array, sample_rate)
76 | bt.logging.info(f"Audio saved successfully at: {audio_path}")
77 | self.audio_path = audio_path
78 |
79 | # Read the audio file into a numerical array
80 | audio_data, sample_rate = sf.read(self.audio_path)
81 |
82 | # Convert the numerical array to a tensor
83 | speech_tensor = torch.Tensor(audio_data)
84 |
85 | # Normalize the speech data
86 | audio_data = speech_tensor / torch.max(torch.abs(speech_tensor))
87 | audio_hash = hashlib.sha256(audio_data.numpy().tobytes()).hexdigest()
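# Note: hashing the max-abs-normalized samples makes the fingerprint invariant
# to simple volume scaling, so the same clip at a different gain still collides.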
88 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
89 | # Check if the music hash is a duplicate
90 | if check_duplicate_music(audio_hash):
91 | bt.logging.info(f"Duplicate music detected from Validator. skipping hash.")
92 | else:
93 | try:
94 | save_hash_to_file(audio_hash, timestamp)
95 | bt.logging.info(f"Music hash processed and saved successfully for Validator")
96 | except Exception as e:
97 | bt.logging.error(f"Error saving audio hash: {e}")
98 |
99 | except Exception as e:
100 | bt.logging.error(f"Error saving audio file: {e}")
101 | else:
102 | bt.logging.error("Invalid audio data in 'File_Path'. Expected 'array' and 'sampling_rate'.")
103 | return None
104 | else:
105 | bt.logging.error("'File_Path' not found or in an invalid format in the sample.")
106 | return None
107 |
108 | return prompt # Return the prompt after saving the audio file
109 |
110 |
111 | async def run_async(self):
112 | step = 0
113 | while True:
114 | try:
115 | await self.main_loop_logic(step)
116 | step += 1
117 | if step % 10 == 0:
118 | self.metagraph.sync(subtensor=self.subtensor)
119 | bt.logging.info(f"🔄 Syncing metagraph with subtensor.")
120 | except KeyboardInterrupt:
121 | print("Keyboard interrupt detected. Exiting MusicGenerationService.")
122 | break
123 | except Exception as e:
124 | print(f"An error occurred in MusicGenerationService: {e}")
125 | traceback.print_exc()
126 |
127 | async def main_loop_logic(self, step):
128 | g_prompt = None
129 | try:
130 | # Load prompt from the dataset using the load_prompts function
131 | bt.logging.info(f"Using prompt from HuggingFace Dataset for Text-To-Music at Step: {step}")
132 | g_prompt = self.load_prompts()
133 |
134 | if isinstance(g_prompt, str):
135 | g_prompt = self.convert_numeric_values(g_prompt)
136 |
137 | # Ensure prompt length does not exceed 256 characters
138 | while isinstance(g_prompt, str) and len(g_prompt) > 256:
139 | bt.logging.error('The current prompt exceeds 256 characters. Skipping it and reloading another.')
140 | g_prompt = self.load_prompts() # Reload another prompt
141 | g_prompt = self.convert_numeric_values(g_prompt)
142 |
143 | # Get filtered axons and query the network
144 | filtered_axons = self.get_filtered_axons_from_combinations()
145 | responses = self.query_network(filtered_axons, g_prompt)
146 | try:
147 | self.process_responses(filtered_axons, responses, g_prompt)
148 | except Exception as e:
149 | bt.logging.error(f"getting an error in processing response: {e}")
150 |
151 | if self.last_reset_weights_block + 50 < self.current_block:
152 | bt.logging.info(f"Resetting weights for validators and nodes without IPs")
153 | self.last_reset_weights_block = self.current_block
154 | # Zero out the scores of nodes without IPs
155 | self.scores = torch.Tensor(self.scores) # Convert NumPy array to PyTorch tensor
156 | self.scores = self.scores * torch.Tensor([self.metagraph.neurons[uid].axon_info.ip != '0.0.0.0' for uid in self.metagraph.uids])
157 |
158 | except Exception as e:
159 | bt.logging.error(f"An error occurred in main loop logic: {e}")
160 |
161 | def query_network(self, filtered_axons, prompt, duration=15):
162 | """Queries the network with the filtered axons and prompt."""
163 | # Map the requested duration (seconds) to generation tokens (~50.2 tokens/sec) and a timeout
164 | if duration == 15:
165 | self.duration = 755
166 | self.time_out = 100
167 | elif duration == 30:
168 | self.duration = 1510
169 | self.time_out = 200
170 |
171 | responses = self.dendrite.query(
172 | filtered_axons,
173 | MusicGeneration(text_input=prompt, duration=self.duration),
174 | deserialize=True,
175 | timeout=self.time_out,
176 | )
177 | return responses
178 |
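# The duration-to-token mapping above implies ~50.2 tokens per second of audio:
# 755 / 50.2 ≈ 15.0 s and 1510 / 50.2 ≈ 30.1 s. The same 50.2 factor appears in
# handle_music_output's token estimate.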
179 | def update_block(self):
180 | self.current_block = self.subtensor.block
181 | if self.current_block - self.last_updated_block > 120:
182 | bt.logging.info(f"Updating weights. Last update was at block: {self.last_updated_block}")
183 | bt.logging.info(f"Current block for weight update is: {self.current_block}")
184 | self.update_weights(self.scores)
185 | self.last_updated_block = self.current_block
186 | else:
187 | bt.logging.info(f"Skipping weight update. Last update was at block: {self.last_updated_block}")
188 | bt.logging.info(f"Current block is: {self.current_block}")
189 | bt.logging.info(f"Next update will be at block: {self.last_updated_block + 120}")
190 |
191 |
192 | def process_responses(self, filtered_axons, responses, prompt):
193 | """Processes responses received from the network."""
194 | for axon, response in zip(filtered_axons, responses):
195 | if response is not None and isinstance(response, MusicGeneration):
196 | self.process_response(axon, response, prompt)
197 |
198 | bt.logging.info(f"Scores after update in TTM: {self.scores}")
199 | self.update_block()
200 |
201 | def process_response(self, axon, response, prompt, api=False):
202 | try:
203 | if response is not None and isinstance(response, MusicGeneration) and response.music_output is not None and response.dendrite.status_code == 200:
204 | music_output = response.music_output
205 | bt.logging.success(f"Received music output from {axon.hotkey}")
206 | if api:
207 | file = self.handle_music_output(axon, music_output, prompt, response.model_name)
208 | return file
209 | else:
210 | self.handle_music_output(axon, music_output, prompt, response.model_name)
211 | elif response is not None and response.dendrite.status_code != 403:
212 | self.punish(axon, service="Text-To-Music", punish_message=response.dendrite.status_message)
213 | else:
214 | pass
215 |
216 | except Exception as e:
217 | bt.logging.error(f'An error occurred while handling music output: {e}')
218 |
219 | def handle_music_output(self, axon, music_output, prompt, model_name):
220 | # Handle the music output received from the miners
221 | try:
222 | # Convert the list to a tensor
223 | speech_tensor = torch.Tensor(music_output)
224 | bt.logging.info("Converted music output to tensor successfully.")
225 | except Exception as e:
226 | bt.logging.error(f"Error converting music output to tensor: {e}")
227 | return
228 |
229 | try:
230 | # Normalize the speech data
231 | audio_data = speech_tensor / torch.max(torch.abs(speech_tensor))
232 | bt.logging.info("Normalized the audio data.")
233 | except Exception as e:
234 | bt.logging.error(f"Error normalizing audio data: {e}")
235 | return
236 |
237 | try:
238 | # Convert to 32-bit PCM
239 | audio_data_int_ = (audio_data * 2147483647).type(torch.IntTensor)
240 | bt.logging.info("Converted audio data to 32-bit PCM.")
241 |
242 | # Add an extra dimension to make it a 2D tensor
243 | audio_data_int = audio_data_int_.unsqueeze(0)
244 | bt.logging.info("Added an extra dimension to audio data.")
245 | except Exception as e:
246 | bt.logging.error(f"Error converting audio data to 32-bit PCM: {e}")
247 | return
248 |
249 | try:
250 | # Get the .wav file from the path
251 | file_name = os.path.basename(self.audio_path)
252 | bt.logging.info(f"Saving audio file to: {file_name}")
253 | os.makedirs('/tmp/music/', exist_ok=True) # Ensure the output directory exists
254 | # Save the audio data as a .wav file
255 | output_path = os.path.join('/tmp/music/', file_name)
256 | sampling_rate = 32000
257 | torchaudio.save(output_path, src=audio_data_int, sample_rate=sampling_rate)
258 | bt.logging.info(f"Saved audio file to {output_path}")
259 | except Exception as e:
260 | bt.logging.error(f"Error saving audio file: {e}")
261 | return
262 |
263 | try:
264 | # Calculate the audio hash
265 | audio_hash = hashlib.sha256(audio_data.numpy().tobytes()).hexdigest()
266 | bt.logging.info("Calculated audio hash.")
267 | except Exception as e:
268 | bt.logging.error(f"Error calculating audio hash: {e}")
269 | return
270 |
271 | try:
272 | # Check if the music hash is a duplicate
273 | if check_duplicate_music(audio_hash):
274 | bt.logging.info(f"Duplicate music detected from miner: {axon.hotkey}. Issuing punishment.")
275 | self.punish(axon, service="Text-To-Music", punish_message="Duplicate music detected")
276 | else:
277 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
278 | save_hash_to_file(audio_hash, axon.hotkey, timestamp)
279 | bt.logging.info(f"Music hash processed and saved successfully for miner: {axon.hotkey}")
280 | except Exception as e:
281 | bt.logging.error(f"Error checking or saving music hash: {e}")
282 | return
283 |
284 | try:
285 | # Log the audio to wandb
286 | uid_in_metagraph = self.metagraph.hotkeys.index(axon.hotkey)
287 | audio_data_np = audio_data.numpy() # Use the normalized float samples; wandb.Audio expects values in [-1, 1]
288 | wandb.log({
289 | f"TTM prompt: {prompt[:100]} ....": wandb.Audio(audio_data_np, caption=f'For HotKey: {axon.hotkey[:10]} and uid {uid_in_metagraph}', sample_rate=sampling_rate)
290 | })
291 | bt.logging.success(f"TTM Audio file uploaded to wandb successfully for Hotkey {axon.hotkey} and UID {uid_in_metagraph}")
292 | except Exception as e:
293 | bt.logging.error(f"Error uploading TTM audio file to wandb: {e}")
294 |
295 | try:
296 | # Get audio duration
297 | duration = self.get_duration(output_path)
298 | token = duration * 50.2 # Approximate audio-token count (~50.2 tokens/sec); currently informational only
299 | bt.logging.info(f"The duration of the audio file is {duration} seconds.")
300 | except Exception as e:
301 | bt.logging.error(f"Error calculating audio duration: {e}")
302 | return
303 |
304 | try:
305 | reference_dir = self.audio_path # evaluate_music_quality derives the directory from this file path
306 | score, table1, table2 = self.score_output("/tmp/music/", reference_dir, prompt)
307 | if duration < 15:
308 | score = self.score_adjustment(score, duration)
309 | bt.logging.info(f"Score updated based on short duration than required: {score}")
310 | else:
311 | bt.logging.info(f"Duration is greater than 15 seconds. No need to penalize the score.")
312 | except Exception as e:
313 | bt.logging.error(f"Error scoring the output: {e}")
314 | return
315 |
316 | try:
317 | current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
318 | tabulated_str = tabulate(table1, headers=[f"Raw score for hotkey:{axon.hotkey}", current_datetime], tablefmt="grid")
319 | print(tabulated_str)
320 | print("\n")
321 | tabulated_str2 = tabulate(table2, headers=[f"Normalized score for hotkey:{axon.hotkey}", current_datetime], tablefmt="grid")
322 | print(tabulated_str2)
323 | bt.logging.info(f"Aggregated Score for hotkey {axon.hotkey}: {score}")
324 | self.update_score(axon, score, service="Text-To-Music")
325 | except Exception as e:
326 | bt.logging.error(f"Error generating score tables or updating score: {e}")
327 | return
328 |
329 | return output_path
330 |
331 |
332 | def get_duration(self, wav_file_path):
333 | """Returns the duration of the audio file in seconds."""
334 | with contextlib.closing(wave.open(wav_file_path, 'r')) as f:
335 | frames = f.getnframes()
336 | rate = f.getframerate()
337 | return frames / float(rate)
338 |
339 | def score_adjustment(self, score, duration):
340 | """Adjusts the score based on the duration of the generated audio."""
341 | conditions = [
342 | (lambda d: 14.5 <= d < 15, 0.9),
343 | (lambda d: 14 <= d < 14.5, 0.8),
344 | (lambda d: 13.5 <= d < 14, 0.7),
345 | (lambda d: 13 <= d < 13.5, 0.6),
346 | (lambda d: d < 13, 0.0), # Anything under 13 s scores zero
347 | ]
348 | for condition, multiplier in conditions:
349 | if condition(duration):
350 | return score * multiplier
351 | return score
352 |
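# Worked example for score_adjustment: a 14.2 s clip falls in the
# 14 <= d < 14.5 band, so a raw score of 0.8 is scaled to 0.8 * 0.8 = 0.64,
# while a 12.0 s clip is zeroed out entirely.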
353 | def score_output(self, output_path, reference_dir, prompt):
354 | """Evaluates and returns the score for the generated music output."""
355 | try:
356 | score_object = MusicQualityEvaluator()
357 | return score_object.evaluate_music_quality(output_path, reference_dir, prompt)
358 | except Exception as e:
359 | bt.logging.error(f"Error scoring output: {e}")
360 | return 0.0, [], [] # Match the (score, table1, table2) shape expected by callers
361 |
362 |
363 | def get_filtered_axons_from_combinations(self):
364 | if not self.combinations:
365 | self.get_filtered_axons()
366 |
367 | if self.combinations:
368 | current_combination = self.combinations.pop(0)
369 | bt.logging.info(f"Current Combination for TTM: {current_combination}")
370 | filtered_axons = [self.metagraph.axons[i] for i in current_combination]
371 | else:
372 | self.get_filtered_axons()
373 | current_combination = self.combinations.pop(0)
374 | bt.logging.info(f"Current Combination for TTM: {current_combination}")
375 | filtered_axons = [self.metagraph.axons[i] for i in current_combination]
376 |
377 | return filtered_axons
378 |
379 |
380 | def get_filtered_axons(self):
381 | # Get the uids of all miners in the network.
382 | uids = self.metagraph.uids.tolist()
383 | queryable_uids = (self.metagraph.total_stake >= 0) # Always true; the effective filter is the IP check below
384 | # Remove the weights of miners that are not queryable.
385 | queryable_uids = torch.Tensor(queryable_uids) * torch.Tensor([self.metagraph.neurons[uid].axon_info.ip != '0.0.0.0' for uid in uids])
386 |
387 | active_miners = torch.sum(queryable_uids)
388 | dendrites_per_query = self.total_dendrites_per_query
389 |
390 | # if there are no active miners, set active_miners to 1
391 | if active_miners == 0:
392 | active_miners = 1
393 | # if there are less than dendrites_per_query * 3 active miners, set dendrites_per_query to active_miners / 3
394 | if active_miners < self.total_dendrites_per_query * 3:
395 | dendrites_per_query = int(active_miners / 3)
396 | else:
397 | dendrites_per_query = self.total_dendrites_per_query
398 |
399 | # If below the minimum (3), raise it to the minimum
400 | if dendrites_per_query < self.minimum_dendrites_per_query:
401 | dendrites_per_query = self.minimum_dendrites_per_query
402 | # zip uids and queryable_uids, filter only the uids that are queryable, unzip, and get the uids
403 | zipped_uids = list(zip(uids, queryable_uids))
404 | filtered_zipped_uids = list(filter(lambda x: x[1], zipped_uids))
405 | filtered_uids = [item[0] for item in filtered_zipped_uids] if filtered_zipped_uids else []
406 | subset_length = min(dendrites_per_query, len(filtered_uids))
407 | # Shuffle the order of members
408 | random.shuffle(filtered_uids)
409 | # Generate subsets of subset_length UIDs until all are covered
410 | while filtered_uids:
411 | subset = filtered_uids[:subset_length]
412 | self.combinations.append(subset)
413 | filtered_uids = filtered_uids[subset_length:]
414 | return self.combinations
415 |
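# Worked example for get_filtered_axons: with 25 queryable miners, 25 < 10 * 3,
# so dendrites_per_query becomes int(25 / 3) = 8 and the shuffled UIDs are
# queued in self.combinations as subsets of sizes 8, 8, 8 and 1, consumed one
# per query round by get_filtered_axons_from_combinations.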
416 | def update_weights(self, scores):
417 | """
418 | Sets the validator weights to the metagraph hotkeys based on the scores it has received from the miners.
419 | The weights determine the trust and incentive level the validator assigns to miner nodes on the network.
420 | """
421 |
422 | # Convert scores to a PyTorch tensor and check for NaN values
423 | weights = torch.tensor(scores)
424 | if torch.isnan(weights).any():
425 | bt.logging.warning(
426 | "Scores contain NaN values. This may be due to a lack of responses from miners, or a bug in your reward functions."
427 | )
428 |
429 | # Normalize scores to get raw weights
430 | raw_weights = torch.nn.functional.normalize(weights, p=1, dim=0)
431 | bt.logging.info("raw_weights", raw_weights)
432 |
433 | # Convert uids to a PyTorch tensor
434 | uids = torch.tensor(self.metagraph.uids)
435 |
436 | bt.logging.info("raw_weight_uids", uids)
437 | try:
438 | # Convert tensors to NumPy arrays for processing if required by the process_weights_for_netuid function
439 | uids_np = uids.numpy() if isinstance(uids, torch.Tensor) else uids
440 | raw_weights_np = raw_weights.numpy() if isinstance(raw_weights, torch.Tensor) else raw_weights
441 |
442 | # Process the raw weights and uids based on subnet limitations
443 | (processed_weight_uids, processed_weights) = bt.utils.weight_utils.process_weights_for_netuid(
444 | uids=uids_np, # Ensure this is a NumPy array
445 | weights=raw_weights_np, # Ensure this is a NumPy array
446 | netuid=self.config.netuid,
447 | subtensor=self.subtensor,
448 | metagraph=self.metagraph,
449 | )
450 | bt.logging.info("processed_weights", processed_weights)
451 | bt.logging.info("processed_weight_uids", processed_weight_uids)
452 | except Exception as e:
453 | bt.logging.error(f"An error occurred while processing weights within update_weights: {e}")
454 | return
455 |
456 | # Convert processed weights and uids back to PyTorch tensors if needed for further processing
457 | processed_weight_uids = torch.tensor(processed_weight_uids) if isinstance(processed_weight_uids, np.ndarray) else processed_weight_uids
458 | processed_weights = torch.tensor(processed_weights) if isinstance(processed_weights, np.ndarray) else processed_weights
459 |
460 | # Convert weights and uids to uint16 format for emission
461 | uint_uids, uint_weights = bt.utils.weight_utils.convert_weights_and_uids_for_emit(
462 | uids=processed_weight_uids, weights=processed_weights
463 | )
464 | bt.logging.info("uint_weights", uint_weights)
465 | bt.logging.info("uint_uids", uint_uids)
466 |
467 | # Set the weights on the Bittensor network
468 | try:
469 | result, msg = self.subtensor.set_weights(
470 | wallet=self.wallet,
471 | netuid=self.config.netuid,
472 | uids=uint_uids,
473 | weights=uint_weights,
474 | wait_for_finalization=False,
475 | wait_for_inclusion=False,
476 | version_key=self.version,
477 | )
478 |
479 | if result:
480 | bt.logging.info(f"Weights set on the chain successfully! {result}")
481 | else:
482 | bt.logging.error(f"Failed to set weights: {msg}")
483 | except Exception as e:
484 | bt.logging.error(f"An error occurred while setting weights: {e}")
--------------------------------------------------------------------------------
/ttm/ttm_score.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchaudio
4 | import numpy as np
5 | import bittensor as bt
6 |
7 | from huggingface_hub import hf_hub_download
8 | from audiocraft.metrics import PasstKLDivergenceMetric
9 | from audiocraft.metrics import CLAPTextConsistencyMetric
10 | from audioldm_eval.metrics.fad import FrechetAudioDistance
11 |
12 |
13 |
14 | class MetricEvaluator:
15 | @staticmethod
16 | def calculate_kld(generated_audio_dir, target_audio_dir):
17 | try:
18 | # Get the single audio file path in the directory
19 | generate = next((f for f in os.listdir(generated_audio_dir) if os.path.isfile(os.path.join(generated_audio_dir, f))), None)
20 | target = next((f for f in os.listdir(target_audio_dir) if os.path.isfile(os.path.join(target_audio_dir, f))), None)
21 |
22 | if generate is None or target is None:
23 | bt.logging.error("Generated or target audio file not found.")
24 | return None
25 |
26 | # Load your predicted and target audio files
27 | target_waveform, target_sr = torchaudio.load(os.path.join(target_audio_dir, target))
28 | generated_waveform, generated_sr = torchaudio.load(os.path.join(generated_audio_dir, generate))
29 |
30 | # Ensure sample rates match
31 | if target_sr != generated_sr:
32 | resampler = torchaudio.transforms.Resample(orig_freq=generated_sr, new_freq=target_sr)
33 | generated_waveform = resampler(generated_waveform)
34 | generated_sr = target_sr
35 |
36 | # Truncate or pad waveforms to match lengths
37 | min_length = min(target_waveform.shape[-1], generated_waveform.shape[-1])
38 | target_waveform = target_waveform[..., :min_length]
39 | generated_waveform = generated_waveform[..., :min_length]
40 |
41 | # Ensure that the audio tensors are in the shape [batch_size, channels, length]
42 | target_waveform = target_waveform.unsqueeze(0) # Adding batch dimension
43 | generated_waveform = generated_waveform.unsqueeze(0) # Adding batch dimension
44 |
45 | # The sizes of the waveform
46 | sizes = torch.tensor([target_waveform.shape[-1]])
47 |
48 | # The sample rates
49 | sample_rates = torch.tensor([target_sr]) # Use just one sample rate as they should match
50 |
51 | # Initialize the PasstKLDivergenceMetric
52 | kld_metric = PasstKLDivergenceMetric()
53 |
54 | # Move tensors to the appropriate device if needed
55 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
56 | target_waveform = target_waveform.to(device)
57 | generated_waveform = generated_waveform.to(device)
58 | sizes = sizes.to(device)
59 | sample_rates = sample_rates.to(device)
60 | kld_metric = kld_metric.to(device)
61 |
62 | # Update the metric
63 | kld_metric.update(preds=generated_waveform, targets=target_waveform, sizes=sizes, sample_rates=sample_rates)
64 |
65 | # Compute the PasstKLDivergenceMetric score
66 | kld = kld_metric.compute()
67 | return kld['kld_both']
68 |
69 | except Exception as e:
70 | import traceback
71 | traceback_str = traceback.format_exc()
72 | bt.logging.error(f"Error during KLD calculation: {e}\n{traceback_str}")
73 | return None
74 |
75 |
76 | @staticmethod
77 | def calculate_fad(generated_audio_dir, target_audio_dir):
78 | # Initialize the Frechet Audio Distance calculator
79 | fad_calculator = FrechetAudioDistance()
80 |
81 | # Calculate the FAD score between the two directories
82 | fad_score = fad_calculator.score(
83 | background_dir=generated_audio_dir, # Generated audio directory
84 | eval_dir=target_audio_dir, # Target audio directory
85 | store_embds=False, # Set to True if you want to store embeddings for later reuse
86 | limit_num=1, # Limit the number of files to process, None means no limit
87 | recalculate=True # Set to True if you want to recalculate embeddings
88 | )
89 |
90 | # Extract the FAD score from the dictionary
91 | fad_value = fad_score['frechet_audio_distance']
92 |
93 | # Clamp the value to 0 if it's negative
94 | fad = max(0, fad_value)
95 | return fad
96 |
97 | @staticmethod
98 | def calculate_consistency(generated_audio_dir, text):
99 | try:
100 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
101 | pt_file = hf_hub_download(repo_id="lukewys/laion_clap", filename="music_audioset_epoch_15_esc_90.14.pt")
102 | clap_metric = CLAPTextConsistencyMetric(pt_file, model_arch='HTSAT-base').to(device)
103 |
104 | def convert_audio(audio, from_rate, to_rate, to_channels):
105 | resampler = torchaudio.transforms.Resample(orig_freq=from_rate, new_freq=to_rate)
106 | audio = resampler(audio)
107 | if to_channels == 1:
108 | audio = audio.mean(dim=0, keepdim=True)
109 | return audio
110 |
111 | # Get the single audio file path in the directory
112 | file_name = next((f for f in os.listdir(generated_audio_dir) if os.path.isfile(os.path.join(generated_audio_dir, f))), None)
113 | if file_name is None:
114 | raise FileNotFoundError("No audio file found in the directory.")
115 |
116 | file_path = os.path.join(generated_audio_dir, file_name)
117 |
118 | # Load and process the audio
119 | audio, sr = torchaudio.load(file_path)
120 | audio = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=1) # Downmix to mono; the resample is a no-op since to_rate == from_rate
121 |
122 | # Calculate consistency score
123 | clap_metric.update(audio.unsqueeze(0), [text], torch.tensor([audio.shape[1]]), torch.tensor([sr]))
124 | consistency_score = clap_metric.compute()
125 |
126 | return consistency_score
127 | except Exception as e:
128 | bt.logging.error(f"Error during consistency calculation: {e}")
129 | return None
130 |
131 | class Normalizer:
132 | @staticmethod
133 | def normalize_kld(kld_score):
134 | if kld_score is not None:
135 | if 0 <= kld_score <= 1:
136 | normalized_kld = (1 - kld_score) # Lower KLD is better, so invert: 1 - kld_score
137 | elif 1 < kld_score <= 2:
138 | normalized_kld = 0.5 * (2 - kld_score) # Scale down between 0.5 and 0
139 | else:
140 | normalized_kld = 0 # Anything > 2 is considered bad
141 | else:
142 | normalized_kld = 0
143 | return normalized_kld
144 |
145 | @staticmethod
146 | def normalize_fad(fad_score):
147 | if fad_score is not None:
148 | if 0 <= fad_score <= 5:
149 | normalized_fad = (5 - fad_score) / 5 # Normalize between 0 and 1 (higher is better)
150 | elif 5 < fad_score <= 10:
151 | normalized_fad = 0.5 * (10 - fad_score) / 5 # Scale down between 0.5 and 0
152 | else:
153 | normalized_fad = 0 # Anything > 10 is considered bad
154 | else:
155 | normalized_fad = 0
156 | return normalized_fad
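# Worked examples: KLD 0.4 -> 1 - 0.4 = 0.6 and KLD 1.5 -> 0.5 * (2 - 1.5) = 0.25;
# FAD 2.0 -> (5 - 2) / 5 = 0.6 and FAD 7.0 -> 0.5 * (10 - 7) / 5 = 0.3.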
157 |
158 |
159 | class Aggregator:
160 | @staticmethod
161 | def geometric_mean(scores):
162 | """Calculate the geometric mean of the scores, avoiding any non-positive values."""
163 | scores = [max(score, 0.0001) for score in scores.values()] # Replace non-positive values to avoid math errors
164 | product = np.prod(scores)
165 | return product ** (1.0 / len(scores))
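# e.g. geometric_mean({'KLD': 0.6, 'FAD': 0.3}) = sqrt(0.6 * 0.3) ≈ 0.424; the
# geometric mean penalizes imbalance between metrics more than an arithmetic mean.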
166 |
167 | class MusicQualityEvaluator:
168 | def __init__(self):
169 | self.metric_evaluator = MetricEvaluator()
170 | self.normalizer = Normalizer()
171 | self.aggregator = Aggregator()
172 |
173 | def get_directory(self, path):
174 | return os.path.dirname(path)
175 |
176 | def evaluate_music_quality(self, generated_audio, target_audio, text=None):
177 |
178 | generated_audio_dir = self.get_directory(generated_audio)
179 | target_audio_dir = self.get_directory(target_audio)
180 |
181 | bt.logging.info(f"Generated audio directory: {generated_audio_dir}")
182 | bt.logging.info(f"Target audio directory: {target_audio_dir}")
183 |
184 | kld_score = fad_score = consistency_score = None # Defaults so a failed metric cannot raise a NameError below
185 | try:
186 | kld_score = self.metric_evaluator.calculate_kld(generated_audio_dir, target_audio_dir)
187 | except Exception as e:
188 | bt.logging.error(f"Failed to calculate KLD: {e}")
189 |
190 | try:
191 | fad_score = self.metric_evaluator.calculate_fad(generated_audio_dir, target_audio_dir)
192 | except Exception as e:
193 | bt.logging.error(f"Failed to calculate FAD: {e}")
194 |
195 | try:
196 | consistency_score = self.metric_evaluator.calculate_consistency(generated_audio_dir, text)
197 | except Exception as e:
198 | bt.logging.error(f"Failed to calculate Consistency score: {e}")
199 |
200 | # Normalize scores and calculate the aggregate score
201 | normalized_kld = self.normalizer.normalize_kld(kld_score)
202 | normalized_fad = self.normalizer.normalize_fad(fad_score)
203 | aggregate_quality = self.aggregator.geometric_mean({'KLD': normalized_kld, 'FAD': normalized_fad})
204 | aggregate_score = self.aggregator.geometric_mean({'quality': aggregate_quality, 'normalized_consistency': consistency_score}) if consistency_score is not None and consistency_score > 0.1 else 0
205 | # Print scores in a table
206 | table1 = [
207 | ["Metric", "Raw Score"],
208 | ["KLD Score", kld_score],
209 | ["FAD Score", fad_score],
210 | ["Consistency Score", consistency_score]
211 | ]
212 |
213 | # Print table of normalized scores
214 | table2 = [
215 | ["Metric", "Normalized Score"],
216 | ["Normalized KLD", normalized_kld],
217 | ["Normalized FAD", normalized_fad],
218 | ["Consistency Score", consistency_score]
219 | ]
220 |
221 |
222 | return aggregate_score, table1, table2
--------------------------------------------------------------------------------