├── .gitignore ├── load_balancer ├── Dockerfile ├── autoscaler.py ├── balance_server.py ├── cli.py ├── commands.py ├── fly.toml ├── install.sh └── requirements.txt └── worker ├── Dockerfile ├── fly.toml └── install.sh
/.gitignore:
--------------------------------------------------------------------------------
.idea
--------------------------------------------------------------------------------
/load_balancer/Dockerfile:
--------------------------------------------------------------------------------
FROM python:3.9

WORKDIR /app
COPY . .
# jq/curl are needed by commands.py's shell pipelines; flyctl is installed
# for machine management from inside the container.
RUN apt update && \
    apt install -y jq curl && \
    curl -L https://fly.io/install.sh | sh && \
    pip install --no-cache-dir --upgrade -r ./requirements.txt
ENV PATH="/root/.fly/bin:${PATH}"
CMD ["uvicorn", "balance_server:app", "--host", "0.0.0.0", "--port", "5000"]
--------------------------------------------------------------------------------
/load_balancer/autoscaler.py:
--------------------------------------------------------------------------------
import os
import redis
import time
import logging
import threading
from datetime import datetime, timezone
from commands import Commands
# Scaling thresholds, overridable via environment (see load_balancer/fly.toml).
GPU_NUM_THRESHOLD = int(os.getenv('GPU_NUM_THRESHOLD', '2'))
GPU_TIME_THRESHOLD = int(os.getenv('GPU_TIME_THRESHOLD', '60'))
FLY_TOKEN = os.getenv('FLY_TOKEN')
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)


class Monitor(threading.Thread):
    """Scaling loop for the 'optimizedbark' Fly.io GPU worker app.

    Watches the Redis counter 'active_requests' and the per-machine
    'migs_*' free-slot keys, and starts/stops machines via `Commands`.
    NOTE(review): the capacity formula assumes each worker machine serves
    3 concurrent requests (`3 * num_machines`) — confirm against the worker.
    """

    def __init__(self, redis_con):
        super().__init__()
        self.commands = Commands()
        self.redis_con = redis_con
        # Clear any stale "marked for stop" machine id from a previous run.
        self.redis_con.set('stop_marked_gpu', '')
        # self.mongo_collection = mongo_con['bark']['log']
        self.num_gpu_log = []  # rolling window of "surplus capacity" booleans
        self.num_gpu_machines = 0

    def add_log(self, num_gpu, action=None):
        """Log a capacity sample and the scaling action taken (if any)."""
        # NOTE(review): datetime.utcnow() is naive UTC and deprecated in
        # Python 3.12+; datetime.now(timezone.utc) is already importable here.
        current_time = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")
        logging.info(f'Log Time: {current_time}, Num GPUs: {num_gpu}, Action: {action}')
        # self.mongo_collection.insert_one(
        #     {
        #         "log_time": current_time,
        #         "num_gpus": num_gpu,
        #         "action": action
        #     }
        # )

    def scale_up_if_needed(self):
        """Start one machine if free capacity dropped below the threshold.

        NOTE(review): not called anywhere in this file — run() inlines the
        same logic; this looks like a refactor leftover.
        """
        num_requests = int(self.redis_con.get('active_requests').decode('utf-8'))
        num_gpus = 3 * self.num_gpu_machines - num_requests
        if num_gpus < GPU_NUM_THRESHOLD:
            # scale up
            self.commands.start(a='optimizedbark', count=1)
            logging.info(f"Starting machine")
            self.num_gpu_machines += 1
            self.add_log(num_gpus, action="started")

    def run(self) -> None:
        # NOTE(review): the nesting below is reconstructed — the original
        # file's indentation was lost in this dump; verify against the repo.
        i = 0
        while True:
            if i == 0:
                # Re-sync the machine count with flyctl's view of reality
                # (happens at startup and each time the counter wraps).
                gpu_machines = self.commands.get_machines_by_state(a='optimizedbark', state='started')
                gpu_machines = gpu_machines.strip().split('\n')
                temp = 0
                for machine_id in gpu_machines:
                    if machine_id != '':
                        temp += 1
                self.num_gpu_machines = temp
            if i >= 600:  # every 10 minutes
                i = 0
            num_requests = int(self.redis_con.get('active_requests').decode('utf-8'))
            # Free capacity = 3 slots per machine minus in-flight requests.
            num_gpus = 3 * self.num_gpu_machines - num_requests
            self.num_gpu_log.append(num_gpus >= (GPU_NUM_THRESHOLD + 3))
            if len(self.num_gpu_log) > GPU_TIME_THRESHOLD:
                self.num_gpu_log.pop(0)
            action = None
            if all(self.num_gpu_log):
                # Sustained surplus for the whole window: pick the machine
                # with the most free slots as the scale-down candidate.
                gpu_list = []
                # NOTE(review): SCAN is cursor-based; only the first batch of
                # 'migs_*' keys is read here (the returned cursor is discarded),
                # so some machines may be missed on large keyspaces.
                _, keys = self.redis_con.scan(match='migs_*')
                for key in keys:
                    val = int(self.redis_con.get(key.decode('utf-8')).decode('utf-8'))
                    gpu_list.append((val, key.decode('utf-8')))
                gpu_list.sort(reverse=True)
                max_val, max_key = gpu_list[0]
                if max_val == 3:
                    # Fully idle machine (all 3 slots free): stop it now.
                    # max_key[5:] strips the 'migs_' prefix to get the machine id.
                    self.commands.stop_machine(a='optimizedbark', machine_id=max_key[5:])
                    logging.info(f"Stopping machine {max_key[5:]}")
                    self.num_gpu_machines -= 1
                    self.redis_con.set('stop_marked_gpu', '')
                    self.redis_con.set(max_key, 0)
                    action = f"stopped {max_key[5:]}"
                else:
                    # Not idle yet: mark it so it can drain before stopping.
                    self.redis_con.set('stop_marked_gpu', max_key[5:])
            self.add_log(num_gpus, action=action)
            if num_gpus < GPU_NUM_THRESHOLD:
                # scale up
                self.commands.start(a='optimizedbark', count=1)
                logging.info(f"Starting machine")
                self.num_gpu_machines += 1
                self.add_log(num_gpus, action="started")
            # self.redis_con.lpush('available_gpus', json.dumps({'time': time.time(), 'num_gpus': num_gpus}))
            time.sleep(1)
            i += 1


if __name__ == '__main__':
    # NOTE(review): hard-coded Redis credential fallback committed to the
    # repo — require the env var and rotate this password.
    redis_url = os.environ.get("redis_url",
                               'redis://default:eb7199cbf0f54bf5bb084f7f1d594692@fly-bark-queries.upstash.io:6379')
    r = redis.Redis.from_url(redis_url)
    monitor = Monitor(r)
    monitor.start()
    monitor.join()
--------------------------------------------------------------------------------
/load_balancer/balance_server.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
import redis
import uuid
import json
import time
import base64
import uvicorn
import os
from datetime import datetime, timezone
from google.cloud import storage
from autoscaler import Monitor
# from pymongo import MongoClient

# Service-account key used by check_voice() for GCS lookups.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "voice-npz.json"
app = FastAPI()
# NOTE(review): same hard-coded Redis credential fallback as autoscaler.py —
# require the env var instead and rotate the password.
redis_url = os.environ.get(
    "redis_url",
    'redis://default:eb7199cbf0f54bf5bb084f7f1d594692@fly-bark-queries.upstash.io:6379'
)
# mongo_uri = os.environ.get(
#     "mongo_uri",
#     "mongodb+srv://ginger:P%40ssw0rd131181@bark-log.1fit2mh.mongodb.net/?retryWrites=true&w=majority&appName=bark-log"
# )
# client = MongoClient(mongo_uri)
r_sub = redis.Redis.from_url(redis_url)  # pub/sub consumer for prediction streams
# r_sub = redis.Redis(
#     host='localhost',  # Changed to localhost
#     port=6379,
#     password=''  # Likely no password if you're just testing locally
# )
r_pub = redis.Redis.from_url(redis_url)  # producer for the ml_requests queue
r_monitor = 
redis.Redis.from_url(redis_url)
monitor = Monitor(r_monitor)
monitor.start()
# Initialise the shared in-flight request counter only if absent.
r_pub.setnx('active_requests', 0)

# r_pub = redis.Redis(
#     host='localhost',  # Changed to localhost
#     port=6379,
#     password=''  # Likely no password if you're just testing locally
# )

def get_prediction_stream(request_id):
    """Yields prediction data in real-time."""
    # Subscribe to Redis channel for real-time predictions
    pubsub = r_sub.pubsub()
    pubsub.subscribe(request_id)

    for message in pubsub.listen():
        # Check for message type to avoid initial subscription confirmation message
        if message['type'] == 'message':
            data = message['data']
            # Assuming the 'complete' signal is a message with '{"complete": true}'
            # NOTE(review): substring test — a base64 payload that happens to
            # contain the letters "complete" would end the stream early;
            # checking the parsed JSON for the key would be safer.
            if b'complete' in data:
                break
            encoded_result = json.loads(data)['prediction']
            decoded_result = base64.b64decode(encoded_result)
            yield decoded_result


def check_voice(voice):
    """Return True iff '<voice>.npz' exists in the GCS voices bucket."""
    client = storage.Client()

    # Get the bucket
    bucket = client.get_bucket('tts-voices-npz')
    blob = bucket.blob(voice + ".npz")
    return blob.exists()


@app.post("/{call_id}/synthesize", response_class=StreamingResponse)
async def predict(call_id: str, request: Request):
    """Queue a TTS job on the 'ml_requests' list and stream the audio back."""
    request_id = str(uuid.uuid4())
    data = await request.json()
    text = data.pop("text")
    voice = data.pop("voice").replace('.npz', '')
    rate = data.pop("rate") if "rate" in data.keys() else 1.0
    if not check_voice(voice):
        def stream_results():
            yield f"NO VOICE {voice}"

        return StreamingResponse(stream_results(), status_code=400)
    r_pub.lpush(
        "ml_requests",
        json.dumps(
            {
                "request_id": request_id,
                "text": text,
                "voice": voice,
                "rate": rate,
                "request_time": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
            }
        )
    )
    # NOTE(review): incremented here; the decrement is presumably done by the
    # worker — confirm, otherwise the autoscaler's capacity math drifts.
    r_pub.incr('active_requests')

    def event_stream():
        return get_prediction_stream(request_id)

    return StreamingResponse(event_stream(), media_type="application/octet-stream")


if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=5000,
        log_level="debug",
    )
--------------------------------------------------------------------------------
/load_balancer/cli.py:
--------------------------------------------------------------------------------
import fire
from commands import Commands


def run():
    # Fire-based CLI entry point exposing the Commands methods.
    commands = Commands()
    try:
        fire.Fire(commands)
    # NOTE(review): broad catch prints only the message and drops the
    # traceback — consider re-raising or logging.exception.
    except Exception as e:
        print(e)
--------------------------------------------------------------------------------
/load_balancer/commands.py:
--------------------------------------------------------------------------------
import os
import invoke

# Shared invoke context used to shell out to flyctl/jq/awk.
c = invoke.Context()


class Commands:
    """Thin wrappers around the flyctl machines CLI for a single Fly.io app."""

    def __init__(self):
        self.token = os.getenv("FLY_TOKEN")
        if not self.token:
            raise Exception("FLY_TOKEN is required")

    def count(self, *, a=None, state=None):
        """shows the number of machines"""
        app_name = a
        if not app_name:
            return print("App name is required")

        result = self.get_machines_by_state(a=app_name, state=state)
        # Count non-empty lines of machine ids.
        result = c.run(
            f'echo "{result}" | awk "NF" | wc -l',
            echo=False,
            hide="both",
            warn=True,
        )
        count = int(result.stdout)
        return count

    def stop(self, *, a=None, count=None):
        """stop `count` number of started machines"""
        app_name = a
        if not app_name:
            return print("App name is required -a")

        if not count:
            return print("Count of machines to stop is required --count")

        started_machine_ids = self.get_machines_by_state(a=app_name, state="started")
        result = c.run(
            f"echo \"{started_machine_ids}\" | awk 'NF' | head -n {count} | xargs -P 0 -L 1 -I {{id}} flyctl m stop {{id}} 
-a {app_name} -t {self.token}" 41 | ) 42 | 43 | def start(self, *, a=None, count=None): 44 | """start `count` number of stopped machines""" 45 | app_name = a 46 | if not app_name: 47 | return print("App name is required -a") 48 | 49 | stopped_machine_ids = self.get_machines_by_state(a=app_name, state="stopped") 50 | result = c.run( 51 | f"echo \"{stopped_machine_ids}\" | awk 'NF' | head -n {count} | xargs -P 0 -L 1 -I {{id}} flyctl m start {{id}} -a {app_name} -t {self.token}" 52 | ) 53 | 54 | def add(self, *, a=None, count=None): 55 | """clone `count` number of machines""" 56 | app_name = a 57 | if not app_name: 58 | return print("App name is required -a") 59 | 60 | if not count: 61 | return print("Count of machines to add is required --count") 62 | 63 | result = c.run( 64 | f"flyctl m list -q -a {app_name} -t {self.token} | awk 'NF' | head -n {count}" 65 | ) 66 | images = result.stdout.strip() 67 | result = c.run( 68 | f"yes '{images}' | head -n {count} | xargs -P 0 -L 1 -I {{id}} flyctl m clone {{id}} -a {app_name} -t {self.token}" 69 | ) 70 | 71 | def remove(self, *, a=None, count=None): 72 | """destroys `count` number of stopped machines""" 73 | app_name = a 74 | if not app_name: 75 | return print("App name is required -a") 76 | 77 | if not count: 78 | return print("Count of machines to destroy is required --count") 79 | 80 | stopped_machine_ids = self.get_machines_by_state(a=app_name, state="stopped") 81 | result = c.run( 82 | f"echo \"{stopped_machine_ids}\" | awk 'NF' | head -n {count} | xargs -P 0 -L 1 -I {{id}} flyctl m destroy {{id}} -a {app_name} -t {self.token}" 83 | ) 84 | 85 | def stop_machine(self, *, a=None, machine_id=None): 86 | app_name = a 87 | if not app_name: 88 | return print("App name is required -a") 89 | 90 | if not machine_id: 91 | return print("machine_id to stop is required --count") 92 | 93 | c.run(f"flyctl m stop {machine_id} -a {app_name} -t {self.token}") 94 | 95 | def get_machines_by_state(self, *, a=None, state=None): 96 | """get 
machines from app by state"""
        app_name = a
        if not app_name:
            return print("App name is required -a")

        if not state:
            # No state filter: return every machine id, one per line.
            result = c.run(
                f"flyctl m list -a {app_name} -t {self.token} --json | jq -r '.[] | .id' | awk 'NF'",
                echo=False,
                warn=True,
                hide="out",
            )
            if result.stderr:
                raise Exception(result.stderr)
            return result.stdout

        opts = dict(
            echo=False,
            warn=True,
            hide="out",
        )
        # Uncomment to debug
        # if state == "started":
        #     opts = dict()

        # Filter by machine state (e.g. "started"/"stopped") via jq.
        result = c.run(
            f"flyctl m list -a {app_name} -t {self.token} --json | jq -r '.[] | select(.state == \"{state}\") | .id' | awk 'NF'",
            **opts,
        )

        # Uncomment to debug thrashing
        # if state == "started":
        #     print(result)

        if result.stderr:
            raise Exception(result.stderr)
        return result.stdout
--------------------------------------------------------------------------------
/load_balancer/fly.toml:
--------------------------------------------------------------------------------
# fly.toml app configuration file generated for tts-app on 2024-03-21T17:18:26Z
#
# See https://fly.io/docs/reference/configuration/ for information about how to use this file. 
#

app = 'tts-app'
primary_region = 'ord'

[env]
# NOTE(review): plaintext deploy token and Redis password are committed to
# the repo here — move them to `fly secrets set` and rotate both.
FLY_TOKEN = "fo1_6nTnPORTHvk6FA_Qas5JyJhfRr_FRYNcHEHh0VJXtSQ"
GPU_NUM_THRESHOLD = 2
GPU_TIME_THRESHOLD = 60
redis_url = "redis://default:eb7199cbf0f54bf5bb084f7f1d594692@fly-bark-queries.upstash.io:6379"
# mongo_uri = "mongodb+srv://ginger:P%40ssw0rd131181@bark-log.1fit2mh.mongodb.net/?retryWrites=true&w=majority&appName=bark-log"

[http_service]
internal_port = 5000
force_https = true
auto_stop_machines = false
auto_start_machines = true
min_machines_running = 0
processes = ['app']

[[vm]]
memory = '4gb'
cpu_kind = 'shared'
cpus = 2
--------------------------------------------------------------------------------
/load_balancer/install.sh:
--------------------------------------------------------------------------------
fly launch -o air-297
--------------------------------------------------------------------------------
/load_balancer/requirements.txt:
--------------------------------------------------------------------------------
fastapi
redis
uvicorn
google-cloud-storage
invoke
pymongo
# Fix: cli.py imports `fire` but it was missing from the requirements,
# so running the CLI inside the image raised ImportError.
fire
--------------------------------------------------------------------------------
/worker/Dockerfile:
--------------------------------------------------------------------------------
ARG BASE_IMAGE=nvcr.io/nvidia/tensorrt
ARG BASE_TAG=24.02-py3

# Stage 1: build the TensorRT Bark engines.
FROM ${BASE_IMAGE}:${BASE_TAG} as trt_bark
LABEL authors="ginger"

WORKDIR /app

COPY TRT_Bark .

# Combine creation of models directory with subsequent commands
# Fix: spell out --no-cache-dir (the previous --no-cache only worked via
# option-prefix abbreviation); matches load_balancer/Dockerfile.
RUN mkdir models && \
    cat /etc/resolv.conf && \
    pip install nvidia-pyindex && \
    pip install -r requirements.txt --no-cache-dir && \
    pip install git+https://github.com/suno-ai/bark --no-cache-dir
RUN python3 download_bark.py && \
    python3 bark_large.py && \
    python3 bark_coarse.py && \
    python3 bark_fine.py && \
    rm -r models/bark_large/ONNX && \
    rm -r models/bark_coarse/ONNX

# Stage 2: runtime image serving the prebuilt engines.
FROM ${BASE_IMAGE}:${BASE_TAG} as fastapi_bark
LABEL authors="ginger"

WORKDIR /app

COPY OptimizedBark .

COPY --from=trt_bark /app/models models
RUN pip install nvidia-pyindex && \
    pip install -r requirements.txt --no-cache-dir && \
    python3 -c "import nltk;nltk.download('punkt')" && \
    python3 -c "from vocos import Vocos;Vocos.from_pretrained('charactr/vocos-encodec-24khz')"
EXPOSE 5000

CMD ["python3", "main.py"]
--------------------------------------------------------------------------------
/worker/fly.toml:
--------------------------------------------------------------------------------
# fly.toml app configuration file generated for optimizedbark on 2024-03-21T15:52:56Z
#
# See https://fly.io/docs/reference/configuration/ for information about how to use this file. 
#

app = 'optimizedbark'
organization = 'air-297'
primary_region = 'ord'

[build]
image = 'registry.fly.io/optimizedbark:deployment-01HSGVEAG5VFAM5X6CRC243K82'

[env]
DEFAULT_VOICE = "hey_james_reliable_1_small_coarse_fix"
# NOTE(review): plaintext Redis and MongoDB credentials are committed to the
# repo here — move them to `fly secrets set` and rotate both.
redis_url = "redis://default:eb7199cbf0f54bf5bb084f7f1d594692@fly-bark-queries.upstash.io:6379"
mongo_uri = "mongodb+srv://ginger:P%40ssw0rd131181@bark-log.1fit2mh.mongodb.net/?retryWrites=true&w=majority&appName=bark-log"

[http_service]
auto_stop_machines = false
auto_start_machines = true
min_machines_running = 1
processes = ['app']

[[vm]]
size = 'a100-40gb'
memory = '20gb'
cpu_kind = 'performance'
gpu_kind = 'a100-pcie-40gb'
cpus = 4
gpus = 1
--------------------------------------------------------------------------------
/worker/install.sh:
--------------------------------------------------------------------------------
# download OptimizedBark
wget https://github.com/gayanMatch/OptimizedBark/releases/download/release.0321/OptimizedBark_2024.0321.zip
unzip OptimizedBark_2024.0321.zip

# download TRT_Bark
wget https://github.com/gayanMatch/TRT_Bark/releases/download/release.0321/TRT_Bark_2024.03.21.zip
# Fix: the archive downloaded above is TRT_Bark_2024.03.21.zip; the previous
# command unzipped TRT_Bark_2024.0321.zip, which never exists on disk.
unzip TRT_Bark_2024.03.21.zip

DOCKER_BUILDKIT=0 docker build -t tts-worker:release.0321 .
flyctl launch --org air-297 --vm-gpu-kind a100-pcie-40gb --local-only -i tts-worker:release.0321
--------------------------------------------------------------------------------