├── .gitignore ├── .vscode └── settings.json ├── CITATION.cff ├── README.md ├── ai ├── __init__.py ├── constants.py ├── convert.py ├── converter.py ├── dataset.py ├── enums.py ├── eval.py ├── models.py ├── play.py ├── train.py └── utils.py ├── assets ├── Oka_Custom.osk ├── danser-settings.json ├── good-play-autopilot.png ├── good-play-relax.png └── skins │ ├── AngeLMegumin (RemakeNeuroTest) [AI] (Unknown).osk │ ├── AngeLMegumin (RemakeNeuroTest)(With Cursor) [MegumiWithCursor] (Unknown).osk │ ├── Moonshine 2.0 [Eclipse].osk │ ├── Oka Custom [Oka Custom (No Cursor)] .osk │ └── Oka Custom [Oka_Custom] .osk ├── experiments ├── rl_test.py ├── rt.py ├── s_test.py ├── socket_test.py ├── test_1.py ├── test_a.py ├── test_b.py ├── test_c.py ├── test_d.py ├── test_e.py ├── test_f.py ├── torch_t.py └── vel_test.py ├── main.py ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── rl ├── __init__.py ├── agent.py ├── dqn.py ├── env.py └── memory.py ├── test.py └── windows.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data/*/* 2 | /venv 3 | */__pycache__/* 4 | models/* 5 | *.pyc 6 | 7 | node_modules/* 8 | *.log 9 | node_modules 10 | pending-capture/* 11 | 12 | *.mkv 13 | *.lock -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.detectIndentation": false, 3 | "python.analysis.typeCheckingMode": "off", 4 | "[python]": { 5 | "editor.defaultFormatter": "ms-python.black-formatter" 6 | }, 7 | "python.formatting.provider": "black", 8 | "editor.formatOnSave": true 9 | } -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "" 3 | authors: 4 | - family-names: "Ebelo" 5 | given-names: "Oyintare" 6 | orcid: "https://orcid.org/0009-0001-0044-5654" 7 | title: "osu-ai" 8 | version: 1.0.0 9 | doi: 10.5281/zenodo.10208110 10 | date-released: 2023-11-26 11 | url: "https://github.com/TareHimself/osu-ai" 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Osu Neural Network Created Using Pytorch 2 | 3 | - DISCLAIMER : I am not responsible for any consequences that stem from the illicit use of the contents of this repository. 4 | 5 | ## Info 6 | 7 | - Data extracted from replays produced by a slightly modified version of [danser-go](https://github.com/Wieku/danser-go) 8 | 9 | - Aiming works but clicking does not. 10 | 11 | - All skins used are located in [assets/skins](assets/skins) 12 | 13 | - Showcase videos [Clicking (DOES NOT WORK CURRENTLY)](https://www.youtube.com/watch?v=ZgHyN98iR1M&t=5s) and [Aiming](https://www.youtube.com/watch?v=YEoSrtow8Qw). 
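- For reference, the converter (`ai/converter.py`) encodes each frame's label directly in its filename as `<project>-<time_ms>,<k1>,<k2>,<x>,<y>.png`, where `k1`/`k2` are 0/1 key states and `x`/`y` is the sampled cursor position. A filename such as `my-map-4520,0,1,812,364.png` (illustrative values only) would mean: timestamp 4520 ms, key 2 pressed, cursor at (812, 364).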
14 |
15 | ## Quick Start
16 | - The following assumes that your monitor/screen size is 1920x1080 (making this dynamic is planned)
17 | - Clone this repository along with [this](https://github.com/TareHimself/danser-go) modified version of danser
18 | - Build danser using the instructions in the repository
19 | - Copy [danser-settings.json](assets/danser-settings.json) to `"cloned danser repo"/settings/danser-settings.json`
20 | - Launch danser using the built binary
21 | - Once danser has finished importing, switch the config dropdown to the settings we copied earlier
22 | - Switch the mode to "Watch a Replay" and select the replay you want to train on
23 | - Switch the dropdown on the bottom from "Watch" to "Record"
24 | - Before recording, you can configure the settings to use a different skin.
25 | - Once ready, click "danse!" and wait for danser to generate the recording.
26 | - Once done, it will open a folder with a json file and an mkv file; both are needed to generate the dataset.
27 | - Going back to the osu-ai folder, first set up [Anaconda](https://www.anaconda.com/download)
28 | - Run the following in the terminal to create an environment
29 | ```bash
30 | conda create --name osu-ai python=3.9.12
31 | conda activate osu-ai
32 | ```
33 | - Install [Poetry](https://python-poetry.org/)
34 | - Run the following in the environment we created
35 |
36 | ```bash
37 | poetry install
38 | # For CUDA support run "poe force-cuda"
39 | # For win32 mouse support run "poe use-win32"
40 | ```
41 | - Now we can run main
42 | ```bash
43 | python main.py
44 | ```
45 | - You should see this menu
46 | ```bash
47 | What would you like to do ?
48 | [0] Train or finetune a model
49 | [1] Convert a video and json into a dataset
50 | [2] Test a model
51 | [3] Quit
52 | ```
53 | - Select "1" for Convert and name it whatever you want. For the video and json, type in the paths to the respective files we generated earlier, e.g. `a/b/c.mkv` and `a/b/c.json`.
54 | - For the number of threads I usually use 5, and we will leave the offset at 0.
55 | - For `Max images to keep in memory when writing` I usually leave it at 0 unless the video is really long
56 | - Now wait for the dataset to be generated
57 | - Now we can train. Select "0" to train, then "0" for aim. Name it whatever you want and select the dataset we just made. I usually set max epochs to a very large number since `ctrl+c` will stop training early.
58 | - And that's it.
59 |
60 | ## Example Autopilot play below on [this map](https://osu.ppy.sh/beatmapsets/765778#osu/1627148). The model was trained using [this map](https://osu.ppy.sh/beatmapsets/1721048#osu/3560542).
61 |
62 | ![goodplay](assets/good-play-autopilot.png)
63 |
64 | ## Example Relax play below on [this map](https://osu.ppy.sh/beatmapsets/1357624#osu/2809623). The model was trained using [this map](https://osu.ppy.sh/beatmapsets/1511778#osu/3287118). 
65 | 66 | ![goodplay](assets/good-play-relax.png) 67 | -------------------------------------------------------------------------------- /ai/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /ai/constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from os import getcwd, path, makedirs 3 | 4 | ASSETS_DIR = path.normpath(path.join(getcwd(), 'assets')) 5 | MODELS_DIR = path.normpath(path.join( 6 | getcwd(), 'models')) 7 | RAW_DATA_DIR = path.normpath(path.join( 8 | getcwd(), 'data', 'raw')) 9 | PROCESSED_DATA_DIR = path.normpath(path.join( 10 | getcwd(), 'data', 'processed')) 11 | 12 | CAPTURE_HEIGHT_PERCENT = 1.0 13 | 14 | 15 | PLAY_AREA_RATIO = 4 / 3 16 | 17 | FINAL_IMAGE_WIDTH = 80 # int(1920 * 0.1) 18 | # FINAL_IMAGE_WIDTH = 80 19 | FINAL_IMAGE_HEIGHT = int(FINAL_IMAGE_WIDTH / PLAY_AREA_RATIO) 20 | 21 | FINAL_PLAY_AREA_SIZE = (FINAL_IMAGE_WIDTH, FINAL_IMAGE_HEIGHT) 22 | 23 | PYTORCH_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 24 | 25 | CURRENT_STACK_NUM = 10 26 | 27 | FRAME_DELAY = 0.01 28 | 29 | MAX_THREADS_FOR_RESIZING = 20 30 | 31 | if not path.exists(RAW_DATA_DIR): 32 | makedirs(RAW_DATA_DIR) 33 | 34 | if not path.exists(PROCESSED_DATA_DIR): 35 | makedirs(PROCESSED_DATA_DIR) 36 | -------------------------------------------------------------------------------- /ai/convert.py: -------------------------------------------------------------------------------- 1 | from ai.utils import get_validated_input 2 | from ai.constants import RAW_DATA_DIR 3 | from os import path 4 | from ai.converter import ReplayConverter 5 | import traceback 6 | 7 | 8 | def start_convert(): 9 | project_name = get_validated_input( 10 | 'What Would You Like To Name This Project ?:', conversion_fn=lambda a: a.lower().strip()) 11 | rendered_path = get_validated_input( 12 | 'Path to the rendered replay video:', validate_fn=lambda a: path.exists(a.strip()), 13 | conversion_fn=lambda a: a.strip(), on_validation_error=lambda a: print("Invalid path!")) 14 | replay_json = get_validated_input( 15 | 'Path to the rendered replay json:', validate_fn=lambda a: path.exists(a.strip()), 16 | conversion_fn=lambda a: a.strip(), on_validation_error=lambda a: print("Invalid path!")) 17 | 18 | num_threads = get_validated_input("Number of threads to use when processing the video (more isn't always faster):", 19 | lambda a: a.strip().isnumeric() and 0 < int(a.strip()), lambda a: int(a.strip()), 20 | on_validation_error=lambda a: print("It must be an integer greater than zero")) 21 | 22 | offset_ms = get_validated_input("Offset in ms to apply to the dataset (e.g. -100):", 23 | lambda a: a.strip().lstrip('-+').isdigit(), lambda a: int(a.strip()), 24 | on_validation_error=lambda a: print("It must be a positive or negative integer")) 25 | 26 | max_memory = get_validated_input("Max images to keep in memory when writing. 
Default is 0 (as much as possible):", 27 | lambda a: (a.strip().lstrip('-+').isdigit() and int(a.strip()) >= 0) if len(a.strip()) > 0 else True, lambda a: int(a.strip()) if len(a.strip()) > 0 else 0, 28 | on_validation_error=lambda a: print("It must be a positive integer or left empty")) 29 | 30 | try: 31 | ReplayConverter(project_name, rendered_path, replay_json, 32 | RAW_DATA_DIR, num_writers=num_threads, frame_offset_ms=offset_ms, max_in_memory=max_memory) 33 | except: 34 | traceback.print_exc() 35 | -------------------------------------------------------------------------------- /ai/converter.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Union 3 | import cv2 4 | import json 5 | from tqdm import tqdm 6 | from queue import Queue 7 | from threading import Thread, Event 8 | import shutil 9 | import os 10 | from ai.utils import Cv2VideoContext, EventsSampler, playfield_coords_to_screen, derive_capture_params 11 | from ai.constants import CAPTURE_HEIGHT_PERCENT 12 | 13 | 14 | class ReplayConverter: 15 | """ 16 | NOTE: if the video was shot at 60fps the minimum frame offset would be 17 | """ 18 | 19 | def __init__(self, project_name: str, danser_video: str, replay_json: str, 20 | save_dir: str = "", num_writers=5, max_in_memory=0, frame_interval_ms=10, frame_offset_ms=0, 21 | video_fps=100, 22 | replay_keys_json: Union[str, None] = None, 23 | debug=False) -> None: 24 | self.project_name = project_name 25 | self.save_dir = save_dir 26 | self.danser_video = danser_video 27 | self.replay_json = replay_json 28 | self.replay_keys_json = replay_keys_json 29 | self.num_writers = num_writers 30 | self.max_in_memory = max_in_memory 31 | self.frame_interval_ms = frame_interval_ms 32 | self.frame_offset_ms = frame_offset_ms 33 | self.video_fps = video_fps 34 | self.debug = debug 35 | self.build_dataset() 36 | 37 | def build_dataset(self): 38 | 39 | with Cv2VideoContext(self.danser_video) as ctx: 40 | screen_height = int(ctx.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 41 | screen_width = int(ctx.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 42 | 43 | 44 | [capture_w,capture_h,capture_dx,capture_dy] = derive_capture_params(screen_width,screen_height) 45 | 46 | with open(self.replay_json, 'r') as f: 47 | replay_data = json.load(f) 48 | 49 | replay_keys_data = None 50 | 51 | if self.replay_keys_json is not None: 52 | with open(self.replay_keys_json, 'r') as f: 53 | replay_keys_data = json.load(f) 54 | 55 | start_time = replay_data["objects"][0][ 56 | "start"] # We assume start time will be the same for both since it should be the same map 57 | breaks = [] 58 | 59 | total_event_time_mouse = 0 60 | 61 | time_offset = start_time + self.frame_offset_ms 62 | 63 | events_keys = [] 64 | events_mouse = [] 65 | 66 | for event in replay_data["events"]: 67 | total_event_time_mouse += event['diff'] 68 | timestamp = total_event_time_mouse - time_offset 69 | [x,y,dx,dy] = playfield_coords_to_screen(event["x"],event["y"],screen_width,screen_height,True) 70 | events_mouse.append({ 71 | "x": round(x + dx), 72 | "y": round(y + dy), 73 | "time": timestamp, 74 | }) 75 | 76 | if replay_keys_data is None: 77 | events_keys.append({ 78 | "keys": [event['k1'], event['k2']], 79 | "time": timestamp, 80 | }) 81 | 82 | if replay_keys_data is not None: 83 | total_event_time_keys = 0 84 | for event in replay_keys_data["events"]: 85 | total_event_time_keys += event['diff'] 86 | timestamp = total_event_time_mouse - time_offset 87 | events_keys.append({ 88 | "keys": 
[event['k1'], event['k2']], 89 | "time": timestamp, 90 | }) 91 | 92 | for b in replay_data['breaks']: 93 | breaks.append({ 94 | "start": b["start"] - time_offset, 95 | "end": b["end"] - time_offset 96 | }) 97 | 98 | stop_time = max(int(events_mouse[len(events_mouse) - 1]['time']), 99 | int(events_keys[len(events_keys) - 1]['time'])) 100 | 101 | iter_target = range(0, stop_time, self.frame_interval_ms) 102 | 103 | remove_breaks = False 104 | 105 | if remove_breaks: 106 | new_iter_target = [] 107 | for i in iter_target: 108 | should_add = True 109 | for item in breaks: 110 | if item['start'] <= i <= item["end"]: 111 | should_add = False 112 | break 113 | 114 | if should_add: 115 | new_iter_target.append(i) 116 | iter_target = new_iter_target 117 | 118 | save_dir = os.path.join(self.save_dir, self.project_name) 119 | 120 | if os.path.exists(save_dir): 121 | shutil.rmtree(save_dir) 122 | 123 | os.mkdir(save_dir) 124 | 125 | loading_bar = tqdm(desc="Generating Dataset", total=len(iter_target)) 126 | 127 | write_buff = Queue(maxsize=self.max_in_memory) 128 | 129 | stop_conversion = Event() 130 | 131 | def write_one_frame_func(data): 132 | sample, frame = data 133 | if frame is not None: 134 | cur_time, x, y, keys_bool = sample 135 | image_file_name = f"{self.project_name}-{round(cur_time)},{1 if keys_bool[0] else 0},{1 if keys_bool[1] else 0},{x},{y}.png" 136 | 137 | image_path = os.path.join(save_dir, image_file_name) 138 | 139 | cv2.imwrite(image_path, frame) 140 | 141 | loading_bar.update() 142 | 143 | def frame_writer_func(): 144 | to_write = write_buff.get() 145 | while to_write is not None and not stop_conversion.is_set(): 146 | write_one_frame_func(to_write) 147 | to_write = write_buff.get() 148 | 149 | def frame_reader_func(target: list): 150 | 151 | mouse_sampler = EventsSampler(events_mouse.copy()) 152 | keys_sampler = EventsSampler(events_keys.copy()) 153 | 154 | local_iter_target: collections.deque = collections.deque(target) 155 | 156 | frame_delta_ms = (1 / self.video_fps) * 1000 157 | 158 | with Cv2VideoContext(self.danser_video) as video_capture_context: 159 | 160 | video_start_time = 0 161 | 162 | target_time = local_iter_target.popleft() 163 | 164 | if self.debug: 165 | print("Set time to", target_time) 166 | 167 | total_frames_skipped = 0 168 | 169 | while (len(local_iter_target) > 0 or target_time is not None) and not stop_conversion.is_set(): 170 | 171 | total_time_delta = target_time - video_start_time 172 | 173 | target_frame = total_time_delta / frame_delta_ms if total_time_delta != 0 else 0 174 | 175 | frames_to_skip_start = (target_frame - total_frames_skipped) 176 | 177 | if self.debug: 178 | print("INFO", total_time_delta, target_time, target_frame, 179 | frame_delta_ms, total_frames_skipped, frames_to_skip_start) 180 | 181 | if frames_to_skip_start == 0 or frames_to_skip_start >= 1: 182 | 183 | frames_to_skip = frames_to_skip_start 184 | 185 | if frames_to_skip > 1: # no point in skipping frames if we cant read what we have left 186 | while frames_to_skip - 1 >= 0: 187 | video_capture_context.cap.read() 188 | frames_to_skip -= 1 189 | if self.debug: 190 | print("SKIPPED") 191 | 192 | read_success, frame = video_capture_context.cap.read() 193 | 194 | if read_success: 195 | 196 | current_time = ((total_frames_skipped + abs( 197 | frames_to_skip_start - frames_to_skip)) * frame_delta_ms) + video_start_time 198 | 199 | cur_time_mouse, x, y = mouse_sampler.sample_mouse(current_time) 200 | cur_time_keys, keys_bool = keys_sampler.sample_keys(current_time) 201 | 202 | if 
self.debug: 203 | print("Key State", keys_bool) 204 | 205 | # [x,y,dx,dy] = playfield_coords_to_screen(x,y) 206 | # debug_x, debug_y = round( 207 | # x + dx), round( 208 | # y + dy), 209 | 210 | debug_frame = frame[int(capture_dy):int( 211 | capture_dy + capture_h), int(capture_dx):int( 212 | capture_dx + capture_w)].copy() 213 | 214 | cv2.imshow("Window", 215 | cv2.circle(debug_frame, (int(x) - 5, int(y) - 5), 10, 216 | (255, 255, 255), 217 | 3)) 218 | cv2.waitKey(0) 219 | 220 | write_buff.put( 221 | ((current_time, round(x), round(y), keys_bool), frame[int(capture_dy):int(capture_dy + capture_h), int(capture_dx):int(capture_dx + capture_w)])) 222 | 223 | frames_to_skip -= 1 224 | else: 225 | loading_bar.update() 226 | 227 | total_frames_skipped += abs(frames_to_skip_start - frames_to_skip) 228 | if self.debug: 229 | print("FINAL FRAMES SKIPPED", total_frames_skipped, 230 | frames_to_skip_start, frames_to_skip) 231 | else: 232 | loading_bar.update() 233 | 234 | if len(local_iter_target) == 0: 235 | target_time = None 236 | else: 237 | target_time = local_iter_target.popleft() 238 | 239 | frame_reader_thread = Thread(target=frame_reader_func, group=None, daemon=True, args=[iter_target]) 240 | 241 | frame_writers = [Thread(target=frame_writer_func, group=None, daemon=True) for x in 242 | range(self.num_writers * 2)] # not sure if 1:2 ratio is right 243 | 244 | frame_reader_thread.start() 245 | 246 | for writer in frame_writers: 247 | writer.start() 248 | 249 | while frame_reader_thread.is_alive(): 250 | try: 251 | frame_reader_thread.join(timeout=2) 252 | except KeyboardInterrupt: 253 | stop_conversion.set() 254 | for _ in frame_writers: 255 | write_buff.put(None) 256 | break 257 | 258 | if not stop_conversion.is_set(): 259 | for _ in frame_writers: 260 | write_buff.put(None) 261 | 262 | for writer in frame_writers: 263 | while writer.is_alive() and not stop_conversion.is_set(): 264 | try: 265 | writer.join(timeout=2) 266 | except KeyboardInterrupt: 267 | stop_conversion.set() 268 | break 269 | -------------------------------------------------------------------------------- /ai/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from os import path 4 | import cv2 5 | import numpy as np 6 | import traceback 7 | from tempfile import TemporaryDirectory 8 | import torchvision.transforms as transforms 9 | from queue import Queue 10 | from threading import Thread 11 | from concurrent.futures import ThreadPoolExecutor 12 | from tqdm import tqdm 13 | from ai.constants import CURRENT_STACK_NUM, FINAL_PLAY_AREA_SIZE, PROCESSED_DATA_DIR, \ 14 | RAW_DATA_DIR, MAX_THREADS_FOR_RESIZING 15 | from collections import deque 16 | from torch.utils.data import Dataset 17 | from ai.enums import EModelType 18 | 19 | image_to_pytorch_image = transforms.ToTensor() 20 | 21 | INVALID_KEY_STATE = "An Invalid State" 22 | 23 | KEY_STATES = { 24 | "00": 0, 25 | "01": 1, 26 | "10": 2, 27 | } 28 | 29 | 30 | class OsuDataset(Dataset): 31 | """ 32 | 33 | """ 34 | FILE_REG_EXPR = r"-([0-9]+),[0-1],[0-1],[-0-9.]+,[-0-9.]+.png" 35 | 36 | def __init__(self, datasets: list[str], label_type: EModelType = EModelType.Actions, force_rebuild=False) -> None: 37 | self.datasets = datasets 38 | self.labels = [] 39 | self.images = [] 40 | self.label_type = label_type 41 | self.data_to_process = Queue() 42 | self.force_rebuild = force_rebuild 43 | self.make_training_data() 44 | 45 | @staticmethod 46 | def extract_info(frame, state, dims): 47 | width, height = dims 48 
| # print(dims) 49 | _, k1, k2, x, y = state.split(',') 50 | 51 | # greyscale 52 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 53 | 54 | # normalize 55 | frame = frame / 255 56 | 57 | x = max(0, float(x.strip())) 58 | 59 | x = x if x == 0 else x / width 60 | 61 | y = max(0, float(y.strip())) 62 | 63 | y = y if y == 0 else y / height 64 | 65 | return frame, KEY_STATES.get(f"{k1}{k2}".strip(), 0), np.array([x, y]) 66 | 67 | @staticmethod 68 | def stack_frames(previous_frames: deque, frame): 69 | prev_frames = list(previous_frames) 70 | prev_count = len(prev_frames) 71 | needed_count = CURRENT_STACK_NUM - prev_count 72 | 73 | if needed_count > 1: 74 | previous_frames.append(frame) 75 | return None 76 | else: 77 | final_frames = prev_frames[prev_count - 78 | (CURRENT_STACK_NUM - 1):prev_count] + [frame] 79 | previous_frames.append(frame) 80 | 81 | return np.stack(final_frames) 82 | 83 | def background_loader(self, dataset_dir: str, files_to_load: list[str]): 84 | try: 85 | for item in files_to_load: 86 | image_file = cv2.imread( 87 | path.join(dataset_dir, item), cv2.IMREAD_COLOR) 88 | self.data_to_process.put((image_file, item[:-4])) 89 | 90 | self.data_to_process.put(None) 91 | except: 92 | traceback.print_exc() 93 | 94 | @staticmethod 95 | def resize_dataset(temp_directory: str, dataset: str, source_path: str): 96 | files = [] 97 | 98 | files_to_load = os.listdir(source_path) 99 | loading_bar = tqdm( 100 | desc=f"Resizing Dataset [{dataset}]", total=len(files_to_load)) 101 | 102 | data_dims = None 103 | 104 | def resize_image(filename): 105 | nonlocal temp_directory 106 | nonlocal data_dims 107 | 108 | try: 109 | current_item_source_path = path.join(source_path, filename) 110 | current_item_dest_path = path.join(temp_directory, filename) 111 | 112 | frame = cv2.imread(current_item_source_path, cv2.IMREAD_COLOR) 113 | 114 | if data_dims is None: 115 | data_dims = frame.shape[:2][::-1] 116 | 117 | frame = cv2.resize( 118 | frame, FINAL_PLAY_AREA_SIZE, interpolation=cv2.INTER_LINEAR) 119 | 120 | cv2.imwrite(current_item_dest_path, frame) 121 | files.append(filename) 122 | loading_bar.update() 123 | 124 | except Exception: 125 | traceback.print_exc() 126 | 127 | try: 128 | with ThreadPoolExecutor(MAX_THREADS_FOR_RESIZING) as executor: 129 | 130 | for file in files_to_load: 131 | executor.submit(resize_image, file) 132 | 133 | executor.shutdown() 134 | 135 | except KeyboardInterrupt: 136 | pass 137 | 138 | return files, data_dims 139 | 140 | def get_or_create_dataset(self, temp_directory: str, dataset: str) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray]]: 141 | 142 | try: 143 | 144 | processed_data_path = path.join( 145 | PROCESSED_DATA_DIR, f"{CURRENT_STACK_NUM}-{FINAL_PLAY_AREA_SIZE[0]}-{dataset}.npy") 146 | raw_data_path = path.join( 147 | RAW_DATA_DIR, f'{dataset}') 148 | 149 | if not self.force_rebuild and path.exists(processed_data_path): 150 | loaded_data = np.load(processed_data_path, allow_pickle=True) 151 | return list(loaded_data[:, 0]), list(loaded_data[:, 1]), list(loaded_data[:, 2]) 152 | 153 | files, data_dims = OsuDataset.resize_dataset(temp_directory, 154 | dataset, raw_data_path) 155 | 156 | files.sort(key=lambda x: int( 157 | re.search(OsuDataset.FILE_REG_EXPR, x).groups()[0])) 158 | 159 | frame_queue = deque(maxlen=CURRENT_STACK_NUM - 1) 160 | 161 | processed = [] 162 | 163 | Thread(target=self.background_loader, 164 | daemon=True, group=None, args=[temp_directory, files]).start() 165 | 166 | loader = tqdm(total=len(files), 167 | desc=f"Processing Dataset 
[{dataset}]") 168 | 169 | data = self.data_to_process.get() 170 | 171 | while data is not None: 172 | frame, state = data 173 | 174 | frame, key_state, mouse_state = OsuDataset.extract_info( 175 | frame, state, data_dims) 176 | 177 | stacked = OsuDataset.stack_frames(frame_queue, frame) 178 | if stacked is None: 179 | loader.update() 180 | data = self.data_to_process.get() 181 | continue 182 | # transp = stacked.transpose(1, 2, 0) 183 | # cv2.imshow("Debug", transp) 184 | # cv2.waitKey(0) 185 | 186 | processed.append( 187 | np.array([stacked, key_state, mouse_state], dtype=object)) 188 | 189 | loader.update() 190 | data = self.data_to_process.get() 191 | loader.close() 192 | 193 | processed = np.stack(processed) 194 | 195 | print(f"Saving Dataset [{dataset}]") 196 | np.save(processed_data_path, processed) 197 | 198 | return list(processed[:, 0]), list(processed[:, 1]), list(processed[:, 2]) 199 | except Exception: 200 | 201 | self.data_to_process.put(None) 202 | 203 | traceback.print_exc() 204 | 205 | return [], [], [] 206 | 207 | def make_training_data(self): 208 | try: 209 | self.labels = [] 210 | self.images = [] 211 | total_images = [] 212 | total_mouse_coordinates = [] 213 | total_keys = [] 214 | 215 | with TemporaryDirectory() as temp_dir: 216 | for dataset in self.datasets: 217 | images, keys, coordinates = self.get_or_create_dataset( 218 | temp_dir, dataset) 219 | total_images.extend(images) 220 | total_mouse_coordinates.extend(coordinates) 221 | total_keys.extend(keys) 222 | 223 | print("LABEL TYPE",self.label_type,self.label_type == EModelType.Aim) 224 | if self.label_type == EModelType.Actions: 225 | 226 | self.images = total_images 227 | self.labels = total_keys 228 | 229 | unique_labels = list(set(self.labels)) 230 | counts = {} 231 | for label in unique_labels: 232 | counts[label] = 0 233 | 234 | for label in self.labels: 235 | counts[label] += 1 236 | 237 | print("Initial Data Balance", counts) 238 | target_amount = max(counts.values()) 239 | for label in counts.keys(): 240 | if counts[label] < target_amount: 241 | label_examples = [self.images[x] 242 | for x in range(len(self.labels)) if self.labels[x] == label] 243 | len_examples = len(label_examples) 244 | for i in range(target_amount - counts[label]): 245 | self.labels.append(label) 246 | self.images.append( 247 | label_examples[i % len_examples]) 248 | 249 | for label in unique_labels: 250 | counts[label] = 0 251 | 252 | for label in self.labels: 253 | counts[label] += 1 254 | 255 | print("Final Dataset Balance", counts) 256 | elif self.label_type == EModelType.Aim: 257 | 258 | self.images = total_images 259 | self.labels = total_mouse_coordinates 260 | print("Final Dataset Size", len(self.labels)) 261 | 262 | elif self.label_type == EModelType.Combined: 263 | def convert_label(a): 264 | return np.array([a[0][0], a[0][1], 1 if a[1] == 2 else 0, 1 if a[1] == 1 else 0]) 265 | 266 | self.labels = list(map(convert_label, zip(total_mouse_coordinates, total_keys))) 267 | self.images = total_images 268 | 269 | unique_labels = list(set(total_keys)) 270 | counts = {} 271 | for label in unique_labels: 272 | counts[label] = 0 273 | 274 | for label in self.labels: 275 | counts[KEY_STATES[f"{int(label[2])}{int(label[3])}"]] += 1 276 | 277 | print("Initial Data Balance", counts) 278 | target_amount = max(counts.values()) 279 | for label in counts.keys(): 280 | if counts[label] < target_amount: 281 | label_examples = [(self.images[x], x) 282 | for x in range(len(total_keys)) if total_keys[x] == label] 283 | len_examples = 
len(label_examples) 284 | for i in range(target_amount - counts[label]): 285 | target_example, target_index = label_examples[i % len_examples] 286 | self.labels.append(self.labels[target_index]) 287 | self.images.append(target_example) 288 | 289 | for label in unique_labels: 290 | counts[label] = 0 291 | 292 | for label in self.labels: 293 | counts[KEY_STATES[f"{int(label[2])}{int(label[3])}"]] += 1 294 | 295 | print("Final Dataset Balance", counts) 296 | 297 | except Exception: 298 | traceback.print_exc() 299 | 300 | def __getitem__(self, idx): 301 | return self.images[idx], self.labels[idx] 302 | 303 | def __len__(self): 304 | return len(self.labels) 305 | 306 | # np.save('test_data.npy', extract_data_from_image( 307 | # "D:\Github\osu-ai\data\\raw\meaning-of-love-4.62\\755.png")) 308 | -------------------------------------------------------------------------------- /ai/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class EModelType(Enum): 5 | Unknown = 'Unknown' 6 | Actions = 'Actions' 7 | Aim = 'Aim' 8 | Combined = 'Combined' 9 | 10 | 11 | class EPlayAreaIndices(Enum): 12 | Width = 0 13 | Height = 1 14 | OffsetX = 2 15 | OffsetY = 3 16 | -------------------------------------------------------------------------------- /ai/eval.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import time 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import keyboard 7 | from threading import Thread 8 | from torch import Tensor 9 | from ai.models import ActionsNet, AimNet, OsuAiModel, CombinedNet 10 | from ai.constants import FINAL_PLAY_AREA_SIZE, FRAME_DELAY, PYTORCH_DEVICE, MODELS_DIR 11 | from ai.utils import FixedRuntime, derive_capture_params 12 | from collections import deque 13 | from mss import mss 14 | from ai.enums import EPlayAreaIndices 15 | import mouse 16 | # 'osu!' # 17 | DEFAULT_OSU_WINDOW = 'osu!' # "osu! 
(development)" 18 | USE_WIN_32_MOUSE = False 19 | try: 20 | import win32api 21 | USE_WIN_32_MOUSE = True 22 | except: 23 | USE_WIN_32_MOUSE = False 24 | 25 | class EvalThread(Thread): 26 | 27 | def __init__(self, model_id: str, game_window_name: str = DEFAULT_OSU_WINDOW, eval_key: str = '\\'): 28 | super().__init__(group=None, daemon=True) 29 | self.game_window_name = game_window_name 30 | self.model_id = model_id 31 | self.capture_params = derive_capture_params() 32 | self.eval_key = eval_key 33 | self.eval = True 34 | self.start() 35 | 36 | 37 | def get_model(self): 38 | model = torch.jit.load(os.path.join(MODELS_DIR, self.model_id, 'model.pt')) 39 | model.load_state_dict(torch.load(os.path.join(MODELS_DIR, self.model_id, 'weights.pt'))) 40 | model.to(PYTORCH_DEVICE) 41 | model.eval() 42 | return model 43 | 44 | def on_output(self, output: Tensor): 45 | pass 46 | 47 | def on_eval_ready(self): 48 | print("Unknown Model Ready") 49 | 50 | def kill(self): 51 | self.eval = False 52 | 53 | @torch.no_grad() 54 | def run(self): 55 | eval_model = self.get_model() 56 | with torch.inference_mode(): 57 | frame_buffer = deque(maxlen=eval_model.channels) 58 | eval_this_frame = False 59 | 60 | def toggle_eval(): 61 | nonlocal eval_this_frame 62 | eval_this_frame = not eval_this_frame 63 | 64 | keyboard.add_hotkey(self.eval_key, callback=toggle_eval) 65 | 66 | self.on_eval_ready() 67 | 68 | with mss() as sct: 69 | monitor = {"top": self.capture_params[EPlayAreaIndices.OffsetY.value], 70 | "left": self.capture_params[EPlayAreaIndices.OffsetX.value], 71 | "width": self.capture_params[EPlayAreaIndices.Width.value], 72 | "height": self.capture_params[EPlayAreaIndices.Height.value]} 73 | 74 | while self.eval: 75 | with FixedRuntime(target_time=FRAME_DELAY): # limit capture to every "FRAME_DELAY" seconds 76 | if eval_this_frame: 77 | frame = np.array(sct.grab(monitor)) 78 | frame = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY), FINAL_PLAY_AREA_SIZE) 79 | 80 | needed = eval_model.channels - len(frame_buffer) 81 | 82 | if needed > 0: 83 | for i in range(needed): 84 | frame_buffer.append(frame) 85 | else: 86 | frame_buffer.append(frame) 87 | 88 | stacked = np.stack(frame_buffer) 89 | 90 | frame_buffer.append(frame) 91 | # cv2.imshow("Debug", stacked[0:3].transpose(1, 2, 0)) 92 | # cv2.waitKey(1) 93 | 94 | converted_frame = torch.from_numpy(stacked / 255).type( 95 | torch.FloatTensor).to(PYTORCH_DEVICE) 96 | 97 | inputs = converted_frame.reshape( 98 | (1, converted_frame.shape[0], converted_frame.shape[1], converted_frame.shape[2])) 99 | 100 | out: torch.Tensor = eval_model(inputs) 101 | 102 | self.on_output(out.detach()) 103 | 104 | keyboard.remove_hotkey(toggle_eval) 105 | 106 | 107 | class ActionsThread(EvalThread): 108 | KEYS_STATE_TO_STRING = { 109 | 0: "Idle ", 110 | 1: "Button 1", 111 | 2: "Button 2" 112 | } 113 | 114 | def __init__(self, model_id: str, game_window_name: str = DEFAULT_OSU_WINDOW, eval_key: str = '\\'): 115 | super().__init__(model_id, game_window_name, eval_key) 116 | 117 | 118 | def on_eval_ready(self): 119 | print(f"Actions Model Ready,Press '{self.eval_key}' To Toggle") 120 | 121 | def on_output(self, output: Tensor): 122 | _, predicated = torch.max(output, dim=1) 123 | probs = torch.softmax(output, dim=1) 124 | prob = probs[0][predicated.item()] 125 | if prob.item() > 0: # 0.7: 126 | state = predicated.item() 127 | if state == 0: 128 | keyboard.release('x') 129 | keyboard.release('z') 130 | elif state == 1: 131 | keyboard.release('z') 132 | keyboard.press('x') 133 | elif state == 
2: 134 | keyboard.release('x') 135 | keyboard.press('z') 136 | 137 | 138 | class AimThread(EvalThread): 139 | def __init__(self, model_id: str, game_window_name: str = DEFAULT_OSU_WINDOW, eval_key: str = '\\'): 140 | super().__init__(model_id, game_window_name, eval_key) 141 | 142 | # def get_model(self): 143 | # # model = torch.jit.load(os.path.join(MODELS_DIR, self.model_id, 'model.pt')) 144 | # # model.load_state_dict(torch.load(os.path.join(MODELS_DIR, self.model_id, 'weights.pt'))) 145 | # model = AimNet.load(self.model_id) 146 | # model.to(PYTORCH_DEVICE) 147 | # model.eval() 148 | # return model 149 | 150 | def on_eval_ready(self): 151 | print(f"Aim Model Ready,Press '{self.eval_key}' To Toggle") 152 | 153 | def on_output(self, output: Tensor): 154 | mouse_x_percent, mouse_y_percent = output[0] 155 | position = (int((mouse_x_percent * self.capture_params[EPlayAreaIndices.Width.value]) + self.capture_params[ 156 | EPlayAreaIndices.OffsetX.value]), int( 157 | (mouse_y_percent * self.capture_params[EPlayAreaIndices.Height.value]) + self.capture_params[ 158 | EPlayAreaIndices.OffsetY.value])) 159 | # pyautogui.moveTo(position[0], position[1]) 160 | if USE_WIN_32_MOUSE: 161 | import win32api 162 | win32api.SetCursorPos(position) 163 | else: 164 | mouse.move(position[0],position[1]) 165 | 166 | 167 | class CombinedThread(EvalThread): 168 | def __init__(self, model_id: str, game_window_name: str = DEFAULT_OSU_WINDOW, eval_key: str = '\\'): 169 | super().__init__(model_id, game_window_name, eval_key) 170 | 171 | def on_eval_ready(self): 172 | print(f"Combined Model Ready,Press '{self.eval_key}' To Toggle") 173 | 174 | def on_output(self, output: Tensor): 175 | mouse_x_percent, mouse_y_percent, k1_prob, k2_prob = output[0] 176 | position = (int((mouse_x_percent * self.capture_params[EPlayAreaIndices.Width.value]) + self.capture_params[ 177 | EPlayAreaIndices.OffsetX.value]), int( 178 | (mouse_y_percent * self.capture_params[EPlayAreaIndices.Height.value]) + self.capture_params[ 179 | EPlayAreaIndices.OffsetY.value])) 180 | 181 | if USE_WIN_32_MOUSE: 182 | import win32api 183 | win32api.SetCursorPos(position) 184 | else: 185 | mouse.move(position[0],position[1]) 186 | 187 | if k1_prob >= 0.5: 188 | keyboard.press('z') 189 | else: 190 | keyboard.release('z') 191 | 192 | if k2_prob >= 0.5: 193 | keyboard.press('x') 194 | else: 195 | keyboard.release('x') 196 | -------------------------------------------------------------------------------- /ai/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import uuid 4 | import os 5 | import json 6 | import timm 7 | from typing import Callable 8 | from datetime import datetime 9 | from ai.constants import CURRENT_STACK_NUM, FINAL_PLAY_AREA_SIZE 10 | from ai.utils import refresh_model_list 11 | from ai.enums import EModelType 12 | 13 | 14 | def get_timm_model(build_final_layer: Callable[[int], nn.Module] = lambda a: nn.Linear(a, 2), channels=3, 15 | model_name="resnet18", pretrained=False): 16 | model = timm.create_model(model_name=model_name, pretrained=pretrained, in_chans=channels, num_classes=3) 17 | # model = timm.create_model("resnet18",pretrained=True,in_chans=3,num_classes=3) 18 | classifier = model.default_cfg['classifier'] 19 | 20 | in_features = getattr(model, classifier).in_features # Get the number of input features for the final layer 21 | 22 | setattr(model, classifier, build_final_layer(in_features)) # Replace the final layer 23 | 24 | return model 25 | 
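# Minimal usage sketch for get_timm_model above: it swaps timm's default classifier
# head for the supplied final layer, so the network takes CURRENT_STACK_NUM stacked
# greyscale frames and emits whatever width build_final_layer defines. Assumes the
# default resnet18 backbone and the 80x60 play-area frames from ai/constants.py;
# the variable names below are illustrative only.
if __name__ == "__main__":
    _example_net = get_timm_model(
        build_final_layer=lambda features: nn.Linear(features, 2),  # e.g. 2 outputs for (x, y)
        channels=CURRENT_STACK_NUM,
    )
    _dummy_batch = torch.randn(1, CURRENT_STACK_NUM, FINAL_PLAY_AREA_SIZE[1], FINAL_PLAY_AREA_SIZE[0])
    print(_example_net(_dummy_batch).shape)  # expected: torch.Size([1, 2])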
26 | 27 | class OsuAiModel(torch.nn.Module): 28 | def __init__(self, channels=CURRENT_STACK_NUM, model_type: EModelType = EModelType.Unknown) -> None: 29 | super().__init__() 30 | self.channels = channels 31 | self.model_type = model_type 32 | 33 | def save(self, project_name: str, datasets: list[str], epochs: int, learning_rate: int, path: str = './models', 34 | weights=None): 35 | 36 | model_id = str(uuid.uuid4()) 37 | 38 | weights_to_save = weights if weights is not None else self.state_dict() 39 | 40 | save_dir = os.path.join(path, model_id) 41 | 42 | os.mkdir(save_dir) 43 | 44 | weights_dir = os.path.join(save_dir, 'weights.pt') 45 | 46 | model_dir = os.path.join(save_dir, 'model.pt') 47 | 48 | torch.save(weights_to_save, weights_dir) 49 | 50 | model_scripted = torch.jit.script(self) # Export to TorchScript 51 | 52 | model_scripted.save(model_dir) # Save 53 | 54 | config = { 55 | "name": project_name, 56 | "channels": self.channels, 57 | "date": str(datetime.utcnow()), 58 | "datasets": datasets, 59 | "type": self.model_type.name, 60 | "epochs": epochs, 61 | "lr": learning_rate 62 | } 63 | 64 | with open(os.path.join(save_dir, 'info.json'), 'w') as f: 65 | json.dump(config, f, indent=2) 66 | 67 | refresh_model_list() 68 | 69 | @staticmethod 70 | def load(model_id: str, model_gen=lambda *a, **b: OsuAiModel(*a, **b)): 71 | weights_path = os.path.join('./models', model_id, 'weights.pt') 72 | config_path = os.path.join('./models', model_id, 'info.json') 73 | weights = torch.load(weights_path) 74 | with open(config_path, 'r') as f: 75 | config_json = json.load(f) 76 | model = model_gen( 77 | channels=config_json['channels'], model_type=EModelType(config_json['type'])) 78 | model.load_state_dict(weights) 79 | print(model.model_type) 80 | return model 81 | 82 | 83 | class AimNet(OsuAiModel): 84 | """ 85 | Works 86 | 87 | Args: 88 | torch (_type_): _description_ 89 | """ 90 | 91 | def __init__(self, channels=CURRENT_STACK_NUM, model_type: EModelType = EModelType.Aim): 92 | super().__init__(channels, model_type) 93 | 94 | self.conv = get_timm_model(build_final_layer=lambda features: nn.Sequential( 95 | nn.Linear(features, 512), 96 | nn.ReLU(), 97 | nn.Dropout(0.4), 98 | nn.Linear(512, 256), 99 | nn.ReLU(), 100 | nn.Dropout(0.4), 101 | nn.Linear(256, 2), 102 | ), channels=channels) 103 | 104 | def forward(self, images): 105 | return self.conv(images) 106 | 107 | @staticmethod 108 | def load(model_id: str): 109 | return OsuAiModel.load(model_id, lambda *a, **b: AimNet(*a, **b)) 110 | 111 | 112 | class ActionsNet(OsuAiModel): 113 | """ 114 | Works so far 115 | 116 | Args: 117 | torch (_type_): _description_ 118 | """ 119 | 120 | def __init__(self, channels=CURRENT_STACK_NUM, model_type: EModelType = EModelType.Actions): 121 | super().__init__(channels, model_type) 122 | 123 | self.conv = get_timm_model(build_final_layer=lambda features: nn.Sequential( 124 | nn.Linear(features, 512), 125 | nn.ReLU(), 126 | nn.Dropout(0.4), 127 | nn.Linear(512, 3), 128 | ), channels=channels) 129 | 130 | def forward(self, images): 131 | return self.conv(images) 132 | 133 | @staticmethod 134 | def load(model_id: str): 135 | return OsuAiModel.load(model_id, lambda *a, **b: ActionsNet(*a, **b)) 136 | 137 | 138 | class CombinedNet(OsuAiModel): 139 | """ 140 | Works 141 | 142 | Args: 143 | torch (_type_): _description_ 144 | """ 145 | 146 | def __init__(self, channels=CURRENT_STACK_NUM, model_type: EModelType = EModelType.Combined): 147 | super().__init__(channels, model_type) 148 | 149 | self.conv = 
get_timm_model(build_final_layer=lambda features: nn.Sequential( 150 | nn.Linear(features, 512), 151 | nn.ReLU(), 152 | nn.Dropout(0.4), 153 | nn.Linear(512, 256), 154 | nn.ReLU(), 155 | nn.Dropout(0.4), 156 | nn.Linear(256, 4) 157 | ), channels=channels) 158 | 159 | def forward(self, images): 160 | return self.conv(images) 161 | 162 | @staticmethod 163 | def load(model_id: str): 164 | return OsuAiModel.load(model_id, lambda *a, **b: CombinedNet(*a, **b)) -------------------------------------------------------------------------------- /ai/play.py: -------------------------------------------------------------------------------- 1 | from ai.utils import FixedRuntime, get_models, get_validated_input, EModelType 2 | from ai.eval import ActionsThread, AimThread, CombinedThread 3 | import traceback 4 | 5 | 6 | def start_play(): 7 | try: 8 | action_models = get_models(EModelType.Actions) 9 | 10 | aim_models = get_models(EModelType.Aim) 11 | 12 | combined_models = get_models(EModelType.Combined) 13 | 14 | user_choice = get_validated_input(f"""What type of model would you like to test? 15 | [0] Aim Model | {len(aim_models)} Available 16 | [1] Actions Model | {len(action_models)} Available 17 | [2] Combined Model | {len(combined_models)} Available 18 | """, lambda a: a.strip().isnumeric() and (0 <= int(a.strip()) <= 2), lambda a: int(a.strip())) 19 | 20 | active_model = None 21 | if user_choice == 0: 22 | prompt = "What aim model would you like to use?\n" 23 | for i in range(len(aim_models)): 24 | prompt += f" [{i}] {aim_models[i]}\n" 25 | 26 | model_index = get_validated_input(prompt, lambda a: a.strip().isnumeric() and ( 27 | 0 <= int(a.strip()) < len(aim_models)), lambda a: int(a.strip())) 28 | 29 | active_model = AimThread(model_id=aim_models[model_index]['id']) 30 | 31 | elif user_choice == 1: 32 | prompt = "What actions model would you like to use?\n" 33 | for i in range(len(action_models)): 34 | prompt += f" [{i}] {action_models[i]}\n" 35 | 36 | model_index = get_validated_input(prompt, lambda a: a.strip().isnumeric() and ( 37 | 0 <= int(a.strip()) < len(action_models)), lambda a: int(a.strip())) 38 | 39 | active_model = ActionsThread( 40 | model_id=action_models[model_index]['id']) 41 | else: 42 | prompt = "What combined model would you like to use?\n" 43 | for i in range(len(combined_models)): 44 | prompt += f" [{i}] {combined_models[i]}\n" 45 | 46 | model_index = get_validated_input(prompt, lambda a: a.strip().isnumeric() and ( 47 | 0 <= int(a.strip()) < len(combined_models)), lambda a: int(a.strip())) 48 | 49 | active_model = CombinedThread( 50 | model_id=combined_models[model_index]['id']) 51 | 52 | try: 53 | while True: 54 | with FixedRuntime(2): 55 | pass 56 | 57 | except KeyboardInterrupt as e: 58 | if active_model is not None: 59 | active_model.kill() 60 | except Exception as e: 61 | traceback.print_exc() 62 | -------------------------------------------------------------------------------- /ai/train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from os import getcwd, path 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | from ai.dataset import OsuDataset 7 | import torchvision.transforms as transforms 8 | from ai.models import ActionsNet, AimNet, CombinedNet 9 | from torch.utils.data import DataLoader 10 | from ai.utils import get_datasets, get_validated_input, get_models 11 | from ai.constants import PYTORCH_DEVICE 12 | from ai.enums import EModelType 13 | 14 | transform = transforms.ToTensor() 15 | 16 
| PYTORCH_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | 18 | SAVE_PATH = path.normpath(path.join(getcwd(), 'models')) 19 | 20 | 21 | @torch.jit.script 22 | def get_acc(predicted: torch.Tensor, truth: torch.Tensor, thresh: int = 60, is_combined: bool = False): 23 | predicted = predicted.detach().clone() 24 | truth = truth.detach().clone() 25 | 26 | predicted[:, 0] *= 1920 27 | predicted[:, 1] *= 1080 28 | truth[:, 0] *= 1920 29 | truth[:, 1] *= 1080 30 | 31 | diff = (predicted[:, :-2] - truth[:, :-2]) if is_combined else predicted - truth 32 | 33 | dist = torch.sqrt((diff ** 2).sum(dim=1)) 34 | 35 | dist[dist < thresh] = 1 36 | 37 | dist[dist >= thresh] = 0 38 | 39 | if not is_combined: 40 | return dist.mean().item() 41 | 42 | predicted_keys = predicted[:, 2:] 43 | truth_keys = truth[:, 2:] 44 | 45 | predicted_keys[predicted_keys >= 0.5] = 1 46 | truth_keys[truth_keys >= 0.5] = 1 47 | predicted_keys[predicted_keys < 0.5] = 0 48 | truth_keys[truth_keys < 0.5] = 0 49 | 50 | return (dist.mean().item() + torch.all(predicted_keys == truth_keys, dim=1).float().mean().item()) / 2 51 | 52 | 53 | def train_action_net(datasets: list[str], force_rebuild=False, checkpoint_model_id=None, 54 | batch_size=64, 55 | epochs=1, learning_rate=0.0001, project_name=""): 56 | if len(project_name.strip()) == 0: 57 | project_name = f"Project with {len(datasets)} Datasets" 58 | 59 | train_set = OsuDataset( 60 | datasets=datasets, label_type=EModelType.Actions) 61 | 62 | osu_data_loader = DataLoader( 63 | train_set, 64 | batch_size=batch_size, 65 | shuffle=True 66 | ) 67 | 68 | model = None 69 | 70 | if checkpoint_model_id: 71 | model = ActionsNet.load(checkpoint_model_id).type( 72 | torch.FloatTensor).to(PYTORCH_DEVICE) 73 | else: 74 | model = ActionsNet().type(torch.FloatTensor).to(PYTORCH_DEVICE) 75 | 76 | criterion = nn.CrossEntropyLoss() 77 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) 78 | 79 | best_state = copy.deepcopy(model.state_dict()) 80 | best_loss = 99999999999 81 | best_epoch = 0 82 | patience = 20 83 | patience_count = 0 84 | 85 | try: 86 | for epoch in range(epochs): 87 | loading_bar = tqdm(total=len(osu_data_loader)) 88 | total_accu, total_count = 0, 0 89 | running_loss = 0 90 | for idx, data in enumerate(osu_data_loader): 91 | images, results = data 92 | images = images.type(torch.FloatTensor).to(PYTORCH_DEVICE) 93 | results = results.type(torch.LongTensor).to(PYTORCH_DEVICE) 94 | 95 | optimizer.zero_grad() 96 | 97 | outputs = model(images) 98 | 99 | loss = criterion(outputs, results) 100 | 101 | loss.backward() 102 | optimizer.step() 103 | total_accu += (outputs.argmax(1) == results).sum().item() 104 | total_count += results.size(0) 105 | running_loss += loss.item() * images.size(0) 106 | loading_bar.set_description_str( 107 | f'Training Actions | {project_name} | epoch {epoch + 1}/{epochs} | Accuracy {((total_accu / total_count) * 100):.4f} | loss {(running_loss / len(osu_data_loader.dataset)):.8f} | ') 108 | loading_bar.update() 109 | loading_bar.set_description_str( 110 | f'Training Actions | {project_name} | epoch {epoch + 1}/{epochs} | Accuracy {((total_accu / total_count) * 100):.4f} | loss {(running_loss / len(osu_data_loader.dataset)):.8f} | ') 111 | loading_bar.close() 112 | epoch_loss = running_loss / len(osu_data_loader.dataset) 113 | epoch_accu = (total_accu / total_count) * 100 114 | 115 | if epoch_loss < best_loss: 116 | best_loss = epoch_loss 117 | best_state = copy.deepcopy(model.state_dict()) 118 | best_epoch = epoch 119 | 
patience_count = 0 120 | else: 121 | patience_count += 1 122 | 123 | if patience_count == patience: 124 | model.save(project_name, datasets, best_epoch, 125 | learning_rate, weights=best_state) 126 | break 127 | 128 | # if running_loss / len(osu_data_loader.dataset) < best_loss: 129 | # best_loss = running_loss / len(osu_data_loader.dataset) 130 | # best_state = copy.deepcopy(model.state_dict()) 131 | # best_epoch = epoch 132 | 133 | except KeyboardInterrupt: 134 | if get_validated_input("Would you like to save the last epoch?\n", lambda a: True, 135 | lambda a: a.strip().lower()).startswith("y"): 136 | model.save(project_name, datasets, best_epoch, 137 | learning_rate, weights=best_state) 138 | return 139 | 140 | 141 | def train_aim_net(datasets: list[str], force_rebuild=False, checkpoint_model_id=None, batch_size=64, 142 | epochs=1, learning_rate=0.0001, project_name=""): 143 | if len(project_name.strip()) == 0: 144 | project_name = f"Project with {len(datasets)} Datasets" 145 | 146 | train_set = OsuDataset( 147 | datasets=datasets, label_type=EModelType.Aim) 148 | 149 | osu_data_loader = DataLoader( 150 | train_set, 151 | batch_size=batch_size, 152 | shuffle=True 153 | ) 154 | 155 | # print(train_set[0][0].shape, train_set[1000:1050][1]) 156 | 157 | model = None 158 | 159 | if checkpoint_model_id: 160 | model = AimNet.load(checkpoint_model_id).to(PYTORCH_DEVICE) 161 | else: 162 | model = AimNet().to(PYTORCH_DEVICE) 163 | 164 | criterion = nn.MSELoss() 165 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) 166 | 167 | best_state = copy.deepcopy(model.state_dict()) 168 | best_loss = 99999999999 169 | best_epoch = 0 170 | patience = 100 171 | patience_count = 0 172 | try: 173 | for epoch in range(epochs): 174 | loading_bar = tqdm(total=len(osu_data_loader)) 175 | total_accu, total_count = 0, 0 176 | running_loss = 0 177 | for idx, data in enumerate(osu_data_loader): 178 | images, expected = data 179 | images: torch.Tensor = images.type( 180 | torch.FloatTensor).to(PYTORCH_DEVICE) 181 | expected: torch.Tensor = expected.type( 182 | torch.FloatTensor).to(PYTORCH_DEVICE) 183 | 184 | outputs: torch.Tensor = model(images) 185 | 186 | loss = criterion(outputs, expected) 187 | 188 | optimizer.zero_grad() 189 | 190 | loss.backward() 191 | optimizer.step() 192 | 193 | total_accu += get_acc(outputs, expected) 194 | total_count += 1 195 | running_loss += loss.item() * images.size(0) 196 | loading_bar.set_description_str( 197 | f'Training Aim | {project_name} | epoch {epoch + 1}/{epochs} | Accuracy {((total_accu / total_count) * 100):.4f} | loss {(running_loss / len(osu_data_loader.dataset)):.10f} | ') 198 | loading_bar.update() 199 | epoch_loss = running_loss / len(osu_data_loader.dataset) 200 | epoch_accu = (total_accu / total_count) * 100 201 | loading_bar.set_description_str( 202 | f'Training Aim | {project_name} | epoch {epoch + 1}/{epochs} | Accuracy {(epoch_accu):.4f} | loss {(epoch_loss):.10f} | ') 203 | loading_bar.close() 204 | if epoch_loss < best_loss: 205 | best_loss = epoch_loss 206 | best_state = copy.deepcopy(model.state_dict()) 207 | best_epoch = epoch 208 | patience_count = 0 209 | else: 210 | patience_count += 1 211 | 212 | if patience_count == patience: 213 | model.save(project_name, datasets, best_epoch, 214 | learning_rate, weights=best_state) 215 | return 216 | except KeyboardInterrupt: 217 | if get_validated_input("Would you like to save the best epoch?\n", lambda a: True, 218 | lambda a: a.strip().lower()).startswith("y"): 219 | model.save(project_name, 
datasets, best_epoch, 220 | learning_rate, weights=best_state) 221 | 222 | return 223 | 224 | model.save(project_name, datasets, best_epoch, learning_rate, weights=best_state) 225 | 226 | 227 | def train_combined_net(datasets: list[str], force_rebuild=False, checkpoint_model_id=None, batch_size=64, 228 | epochs=1, learning_rate=0.0001, project_name=""): 229 | if len(project_name.strip()) == 0: 230 | project_name = f"Project with {len(datasets)} Datasets" 231 | 232 | train_set = OsuDataset( 233 | datasets=datasets, label_type=EModelType.Combined) 234 | 235 | osu_data_loader = DataLoader( 236 | train_set, 237 | batch_size=batch_size, 238 | shuffle=True 239 | ) 240 | 241 | # print(train_set[0][0].shape, train_set[1000:1050][1]) 242 | 243 | model = None 244 | 245 | if checkpoint_model_id: 246 | model = CombinedNet.load(checkpoint_model_id).to(PYTORCH_DEVICE) 247 | else: 248 | model = CombinedNet().to(PYTORCH_DEVICE) 249 | 250 | criterion = nn.MSELoss() 251 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) 252 | 253 | best_state = copy.deepcopy(model.state_dict()) 254 | best_loss = 99999999999 255 | best_epoch = 0 256 | patience = 100 257 | patience_count = 0 258 | try: 259 | for epoch in range(epochs): 260 | loading_bar = tqdm(total=len(osu_data_loader)) 261 | total_accu, total_count = 0, 0 262 | running_loss = 0 263 | for idx, data in enumerate(osu_data_loader): 264 | images, expected = data 265 | images: torch.Tensor = images.type( 266 | torch.FloatTensor).to(PYTORCH_DEVICE) 267 | expected: torch.Tensor = expected.type( 268 | torch.FloatTensor).to(PYTORCH_DEVICE) 269 | 270 | outputs: torch.Tensor = model(images) 271 | 272 | loss = criterion(outputs, expected) 273 | 274 | optimizer.zero_grad() 275 | 276 | loss.backward() 277 | optimizer.step() 278 | 279 | total_accu += get_acc(outputs, expected, is_combined=True) 280 | total_count += 1 281 | running_loss += loss.item() * images.size(0) 282 | loading_bar.set_description_str( 283 | f'Training Aim | {project_name} | epoch {epoch + 1}/{epochs} | Accuracy {((total_accu / total_count) * 100):.4f} | loss {(running_loss / len(osu_data_loader.dataset)):.10f} | ') 284 | loading_bar.update() 285 | epoch_loss = running_loss / len(osu_data_loader.dataset) 286 | epoch_accu = (total_accu / total_count) * 100 287 | loading_bar.set_description_str( 288 | f'Training Aim | {project_name} | epoch {epoch + 1}/{epochs} | Accuracy {(epoch_accu):.4f} | loss {(epoch_loss):.10f} | ') 289 | loading_bar.close() 290 | if epoch_loss < best_loss: 291 | best_loss = epoch_loss 292 | best_state = copy.deepcopy(model.state_dict()) 293 | best_epoch = epoch 294 | patience_count = 0 295 | else: 296 | patience_count += 1 297 | 298 | if patience_count == patience: 299 | model.save(project_name, datasets, best_epoch, 300 | learning_rate, weights=best_state) 301 | return 302 | except KeyboardInterrupt: 303 | if get_validated_input("Would you like to save the best epoch?\n", lambda a: True, 304 | lambda a: a.strip().lower()).startswith("y"): 305 | model.save(project_name, datasets, best_epoch, 306 | learning_rate, weights=best_state) 307 | 308 | return 309 | 310 | model.save(project_name, datasets, best_epoch, learning_rate, weights=best_state) 311 | 312 | 313 | def get_train_data(data_type: EModelType, datasets: list[int], datasets_prompt: str, models: list[dict], models_prompt: str): 314 | 315 | project_name = get_validated_input("What would you like to name this project ?", 316 | lambda a: True, lambda a: a.strip()) 317 | def 
validate_datasets_selection(received: str): 318 | try: 319 | items = map(int, received.strip().split(",")) 320 | for item in items: 321 | if not 0 <= item < len(datasets): 322 | return False 323 | return True 324 | except: 325 | return False 326 | 327 | selected_datasets = get_validated_input( 328 | datasets_prompt, validate_datasets_selection, lambda a: map(int, a.strip().split(","))) 329 | checkpoint = None 330 | print('\nConfig for', data_type.value, '\n') 331 | epochs = get_validated_input("Max epochs to train for ?\n", 332 | lambda a: a.strip().isnumeric() and 0 <= int(a.strip()), lambda a: int(a.strip())) 333 | 334 | if len(models) > 0: 335 | if get_validated_input("Would you like to use a checkpoint?\n", lambda a: True, 336 | lambda a: a.strip().lower()).startswith("y"): 337 | checkpoint_index = get_validated_input(models_prompt, lambda a: a.strip().isnumeric() and ( 338 | 0 <= int(a.strip()) < len(models_prompt)), lambda a: int(a.strip())) 339 | checkpoint = models[checkpoint_index] 340 | 341 | return data_type,project_name, list(map(lambda a: datasets[a], selected_datasets)), checkpoint, epochs 342 | 343 | 344 | def start_train(): 345 | datasets = get_datasets() 346 | prompt = """What type of training would you like to do? 347 | [0] Train Aim 348 | [1] Train Actions 349 | [2] Train Combined 350 | """ 351 | 352 | dataset_prompt = "Please select datasets from below seperated by a comma:\n" 353 | 354 | for i in range(len(datasets)): 355 | dataset_prompt += f" [{i}] {datasets[i]}\n" 356 | 357 | user_choice = get_validated_input(prompt, lambda a: a.strip().isnumeric() and ( 358 | 0 <= int(a.strip()) <= 2), lambda a: int(a.strip())) 359 | 360 | training_tasks = [] 361 | 362 | models_prompt = "Please select a model from below:\n" 363 | 364 | if user_choice == 0: 365 | 366 | models = get_models(EModelType.Aim) 367 | 368 | for i in range(len(models)): 369 | models_prompt += f" [{i}] {models[i]}\n" 370 | 371 | training_tasks.append(get_train_data( 372 | EModelType.Aim, datasets, dataset_prompt, models, models_prompt)) 373 | 374 | elif user_choice == 1: 375 | 376 | models = get_models(EModelType.Actions) 377 | 378 | for i in range(len(models)): 379 | models_prompt += f" [{i}] {models[i]}\n" 380 | 381 | training_tasks.append(get_train_data( 382 | EModelType.Actions, datasets, dataset_prompt, models, models_prompt)) 383 | 384 | else: 385 | 386 | models = get_models(EModelType.Combined) 387 | 388 | for i in range(len(models)): 389 | models_prompt += f" [{i}] {models[i]}\n" 390 | 391 | training_tasks.append(get_train_data( 392 | EModelType.Combined, datasets, dataset_prompt, models, models_prompt)) 393 | 394 | for task in training_tasks: 395 | task_type,project_name, dataset, checkpoint, epochs = task 396 | if task_type == EModelType.Aim: 397 | train_aim_net( 398 | project_name=project_name,datasets=dataset,checkpoint_model_id=checkpoint['id'] if checkpoint is not None else None, epochs=epochs) 399 | elif task_type == EModelType.Actions: 400 | train_action_net(project_name=project_name,datasets=dataset,checkpoint_model_id=checkpoint['id'] if checkpoint is not None else None, epochs=epochs) 401 | else: 402 | train_combined_net(project_name=project_name,datasets=dataset,checkpoint_model_id=checkpoint['id'] if checkpoint is not None else None, epochs=epochs) 403 | -------------------------------------------------------------------------------- /ai/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | import 
socket 5 | import time 6 | import traceback 7 | import cv2 8 | import numpy as np 9 | from os import listdir, path 10 | from socket import SHUT_RDWR 11 | from tempfile import TemporaryDirectory 12 | from threading import Thread, Timer, Event 13 | from datetime import datetime 14 | from ai.constants import RAW_DATA_DIR, MODELS_DIR,CAPTURE_HEIGHT_PERCENT 15 | from mss import mss 16 | from queue import Queue 17 | from tqdm import tqdm 18 | from ai.enums import EModelType 19 | from typing import TypeVar, Callable, Union, TypeVar 20 | import math 21 | import sys 22 | import subprocess 23 | 24 | 25 | class Cv2VideoContext: 26 | 27 | def __init__(self, file_path: str): 28 | # example file or database connection 29 | self.file_path = file_path 30 | self.cap = cv2.VideoCapture(file_path, cv2.CAP_FFMPEG) 31 | 32 | def __enter__(self): 33 | if not self.cap.isOpened(): 34 | raise BaseException(f"Error opening video stream or file {self.file_path}") 35 | return self 36 | 37 | def __exit__(self, exc_type, exc_value, traceback): 38 | self.cap.release() 39 | 40 | 41 | T = TypeVar("T") 42 | Z = TypeVar("Z") 43 | 44 | 45 | class EventsSamplerEventTypes: 46 | MOUSE = 0 47 | KEYS = 1 48 | 49 | 50 | class EventsSampler: 51 | def __init__(self, events: list[T]) -> None: 52 | self.events = sorted(events, key=lambda a: a['time']) 53 | self.events_num = len(self.events) 54 | self.last_sampled_time = 0 55 | self.last_sampled_index = 0 56 | 57 | def get(self, idx: int): 58 | return self.events[idx]['time'], self.events[idx]['x'], self.events[idx]['y'], self.events[idx]['keys'] 59 | 60 | def sample_mouse(self, target_time_ms: float = 0) -> Z: 61 | if target_time_ms <= self.events[0]['time']: 62 | return self.events[0] 63 | 64 | if target_time_ms >= self.events[self.events_num - 1]['time']: 65 | return self.events[self.events_num - 1] 66 | 67 | # search_range = range(self.last_sampled_index,self.events_num - 1) if self.last_sampled_time <= target_time_ms else reversed(range(self.last_sampled_index + 1,0)) 68 | # for i in search_range: 69 | search_range = range(0, self.events_num - 1) 70 | for i in search_range: 71 | cur_time = self.events[i]['time'] 72 | next_time = self.events[i + 1]['time'] 73 | if cur_time <= target_time_ms <= next_time: 74 | events_dist = (next_time - cur_time) 75 | 76 | target_time_dist = (target_time_ms - cur_time) 77 | alpha = target_time_dist / events_dist 78 | a = self.events[i] 79 | b = self.events[i + 1] 80 | self.last_sampled_index = i 81 | self.last_sampled_time = target_time_ms 82 | return cur_time, a["x"] + ((b['x'] - a['x']) * alpha), a["y"] + ((b['y'] - a['y']) * alpha) 83 | 84 | raise BaseException("NO SAMPLE FOUND") 85 | 86 | def sample_keys(self, target_time_ms: float = 0) -> Z: 87 | if target_time_ms <= self.events[0]['time']: 88 | return self.events[0] 89 | 90 | if target_time_ms >= self.events[self.events_num - 1]['time']: 91 | return self.events[self.events_num - 1] 92 | 93 | # search_range = range(self.last_sampled_index,self.events_num - 1) if self.last_sampled_time <= target_time_ms else reversed(range(self.last_sampled_index + 1,0)) 94 | # for i in search_range: 95 | search_range = range(0, self.events_num - 1) 96 | for i in search_range: 97 | cur_time = self.events[i]['time'] 98 | next_time = self.events[i + 1]['time'] 99 | if cur_time <= target_time_ms <= next_time: 100 | events_dist = (next_time - cur_time) 101 | 102 | target_time_dist = (target_time_ms - cur_time) 103 | alpha = target_time_dist / events_dist 104 | a = self.events[i] 105 | b = self.events[i + 1] 106 | 
self.last_sampled_index = i 107 | self.last_sampled_time = target_time_ms 108 | return cur_time, b["keys"] if alpha >= 0.5 else a['keys'] 109 | 110 | raise BaseException("NO SAMPLE FOUND") 111 | 112 | 113 | class KeysSampler: 114 | def __init__(self, keys_events: list) -> None: 115 | self.events = sorted(keys_events, key=lambda a: a['time']) 116 | self.events_num = len(self.events) 117 | self.last_sampled_time = 0 118 | self.last_sampled_index = 0 119 | 120 | def get(self, idx: int): 121 | return self.events[idx]['time'], self.events[idx]['keys'] 122 | 123 | def sample(self, target_time_ms: float = 0, key_press_allowance_ms=6) -> list[float, tuple]: 124 | if target_time_ms <= self.events[0]['time']: 125 | return self.events[0] 126 | 127 | if target_time_ms >= self.events[self.events_num - 1]['time']: 128 | return self.events[self.events_num - 1] 129 | 130 | # search_range = range(self.last_sampled_index,self.events_num - 1) if self.last_sampled_time <= target_time_ms else reversed(range(self.last_sampled_index + 1,0)) 131 | # for i in search_range: 132 | search_range = range(0, self.events_num - 1) 133 | last_idx_with_press = -1 134 | for i in search_range: 135 | event_time, event_keys = self.get(i) 136 | next_event_time, next_event_keys = self.get(i + 1) 137 | if event_keys[0] or event_keys[1]: 138 | last_idx_with_press = i 139 | if event_time <= target_time_ms <= next_event_time: 140 | events_dist = (next_event_time - event_time) 141 | 142 | target_time_dist = (target_time_ms - event_time) 143 | alpha = target_time_dist / events_dist 144 | self.last_sampled_index = i 145 | self.last_sampled_time = target_time_ms 146 | 147 | keys_result = (False, False) 148 | 149 | next_idx_with_press = -1 150 | for j in range(i, self.events_num - 1): 151 | cur = self.get(j) 152 | if cur[1][0] or cur[1][1]: 153 | next_idx_with_press = j 154 | break 155 | 156 | dist_last_press = target_time_ms - self.get(last_idx_with_press)[0] 157 | if next_idx_with_press != -1: 158 | 159 | dist_next_press = self.get(next_idx_with_press)[0] - target_time_ms 160 | 161 | if dist_last_press < dist_next_press: 162 | if dist_last_press <= key_press_allowance_ms: 163 | keys_result = self.get(last_idx_with_press)[1] 164 | else: 165 | if dist_next_press <= key_press_allowance_ms: 166 | keys_result = self.get(next_idx_with_press)[1] 167 | 168 | elif dist_last_press <= key_press_allowance_ms: 169 | keys_result = self.get(last_idx_with_press)[1] 170 | 171 | return [event_time + (events_dist * alpha), keys_result] 172 | 173 | raise BaseException("NO SAMPLE FOUND") 174 | 175 | 176 | def run_file(file_path: str): 177 | process = subprocess.Popen(f"{sys.executable} {file_path}", shell=True) 178 | process.communicate() 179 | return process.returncode 180 | 181 | def derive_capture_params(window_width=1920, window_height=1080): 182 | osu_play_field_ratio = 3 / 4 183 | capture_height= int(window_height * CAPTURE_HEIGHT_PERCENT) 184 | capture_width = int(capture_height / osu_play_field_ratio) 185 | capture_params = [capture_width, capture_height, 186 | int((window_width - capture_width) / 2), int((window_height - capture_height) / 2)] 187 | 188 | return capture_params 189 | 190 | def playfield_coords_to_screen(playfield_x,playfield_y,screen_w=1920,screen_h=1080,account_for_capture_params = False): 191 | 192 | 193 | play_field_ratio = 4 / 3 194 | screen_ratio = screen_w / screen_h 195 | 196 | play_field_factory_width = 512 197 | play_field_factory_height = play_field_factory_width / play_field_ratio 198 | factory_h = 
play_field_factory_height * 1.2 199 | factory_w = factory_h * screen_ratio 200 | factory_dx = (factory_w - play_field_factory_width) / 2 201 | factory_dy = (factory_h - play_field_factory_height) / 2 202 | screen_dx = factory_dx * (screen_w / factory_w) 203 | screen_dy = factory_dy * (screen_h / factory_h) 204 | screen_x = playfield_x * (screen_w / factory_w) 205 | screen_y = playfield_y * (screen_h / factory_h) 206 | 207 | if account_for_capture_params: 208 | cap_x,cap_y,cap_dx,cap_dy = derive_capture_params(screen_w,screen_h) 209 | screen_dx = screen_dx - cap_dx 210 | screen_dy = screen_dy - cap_dy 211 | 212 | return [screen_x,screen_y,screen_dx,screen_dy] 213 | 214 | 215 | """ 216 | Ensures this context runs for the given fixed time or more 217 | 218 | Returns: 219 | _type_: _description_ 220 | 221 | """ 222 | 223 | 224 | class FixedRuntime: 225 | def __init__(self, target_time: float = 1.0, debug=None): 226 | self.start = 0 227 | self.delay = target_time 228 | self.debug = debug 229 | 230 | def __enter__(self): 231 | self.start = time.time() 232 | return self 233 | 234 | def __exit__(self, exc_type, exc_value, exc_traceback): 235 | elapsed = time.time() - self.start 236 | wait_time = self.delay - elapsed 237 | if wait_time > 0: 238 | time.sleep(wait_time) 239 | if self.debug is not None: 240 | print(f"Context [{self.debug}] elapsed {wait_time * -1:.4f}s") 241 | else: 242 | if self.debug is not None: 243 | print(f"Context [{self.debug}] elapsed {wait_time * -1:.4f}s") 244 | 245 | 246 | MESSAGES_SENT = 0 247 | 248 | AIM_MODELS = [] 249 | CLICKS_MODELS = [] 250 | COMBINED_MODELS = [] 251 | 252 | 253 | def refresh_model_list(): 254 | global AIM_MODELS 255 | global CLICKS_MODELS 256 | global COMBINED_MODELS 257 | AIM_MODELS = [] 258 | CLICKS_MODELS = [] 259 | COMBINED_MODELS = [] 260 | if not os.path.exists(MODELS_DIR): 261 | os.makedirs(MODELS_DIR) 262 | return 263 | 264 | for model_id in os.listdir(MODELS_DIR): 265 | 266 | model_path = os.path.join(MODELS_DIR, model_id) 267 | 268 | with open(os.path.join(model_path, 'info.json'), 'r') as f: 269 | data = json.load(f) 270 | 271 | payload = { 272 | 'id': model_id, 273 | 'name': data['name'], 274 | 'date': datetime.strptime(data['date'], "%Y-%m-%d %H:%M:%S.%f"), 275 | 'channels': data['channels'], 276 | 'datasets': data['datasets'] 277 | } 278 | 279 | if data['type'] == EModelType.Aim.value: 280 | AIM_MODELS.append(payload) 281 | elif data['type'] == EModelType.Actions.value: 282 | CLICKS_MODELS.append(payload) 283 | elif data['type'] == EModelType.Combined.value: 284 | COMBINED_MODELS.append(payload) 285 | 286 | AIM_MODELS = sorted(AIM_MODELS, key=lambda a: 0 - a['date'].timestamp()) 287 | CLICKS_MODELS = sorted(CLICKS_MODELS, key=lambda a: 0 - a['date'].timestamp()) 288 | COMBINED_MODELS = sorted(COMBINED_MODELS, 289 | key=lambda a: 0 - a['date'].timestamp()) 290 | 291 | 292 | refresh_model_list() 293 | 294 | 295 | def get_models(model_type: EModelType) -> list[dict]: 296 | global AIM_MODELS 297 | global CLICKS_MODELS 298 | global COMBINED_MODELS 299 | if model_type == EModelType.Aim: 300 | return AIM_MODELS 301 | elif model_type == EModelType.Actions: 302 | return CLICKS_MODELS 303 | else: 304 | return COMBINED_MODELS 305 | 306 | 307 | def get_datasets() -> list[str]: 308 | return listdir(RAW_DATA_DIR) 309 | 310 | 311 | def get_validated_input(prompt="You forgot to put your own prompt", 312 | validate_fn: Callable[[str], bool] = lambda a: len(a.strip()) != 0, 313 | conversion_fn: Callable[[str], T] = lambda a: a.strip(), 314 | 
on_validation_error: Callable[[str], None] = lambda 315 | a: print("Invalid input, please try again.")) -> T: 316 | input_as_str = input(prompt) 317 | 318 | if not validate_fn(input_as_str): 319 | on_validation_error(input_as_str) 320 | return get_validated_input(prompt, validate_fn, conversion_fn) 321 | 322 | return conversion_fn(input_as_str) 323 | 324 | 325 | class FileWatcher(Thread): 326 | def __init__(self, file_path, callback, poll_frequency=0.05): 327 | super().__init__(group=None, daemon=True) 328 | self.file_path = file_path 329 | self.callback = callback 330 | self.freq = poll_frequency 331 | self.callback(open(self.file_path).readlines()) 332 | self.buff = Queue() 333 | self.start() 334 | 335 | def kill(self): 336 | if self.is_alive(): 337 | self.buff.put("Shinu") 338 | self.join() 339 | 340 | def run(self): 341 | modified_on = os.path.getmtime(self.file_path) 342 | try: 343 | while True: 344 | if not self.buff.empty(): 345 | break 346 | time.sleep(self.freq) 347 | modified = os.path.getmtime(self.file_path) 348 | if modified != modified_on: 349 | modified_on = modified 350 | self.callback(open(self.file_path).readlines()) 351 | except Exception as e: 352 | print(traceback.format_exc()) 353 | 354 | 355 | class OsuSocketServer: 356 | def __init__(self, on_state_updated) -> None: 357 | self.active_thread = None 358 | self.osu_game = None 359 | self.active = False 360 | self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 361 | self.on_state_updated = on_state_updated 362 | self.pending_messages = {} 363 | 364 | def connect(self): 365 | self.active = True 366 | self.sock.bind(("127.0.0.1", 11000)) 367 | self.osu_game = ("127.0.0.1", 12000) 368 | self.active_thread = Thread(group=None, target=self.receive_messages, daemon=True) 369 | self.active_thread.start() 370 | 371 | def __enter__(self): 372 | self.connect() 373 | return self 374 | 375 | def __exit__(self, exc_type, exc_value, exc_traceback): 376 | self.kill() 377 | 378 | def on_message_internal(self, message): 379 | message_id, content = message.split('|') 380 | if content == "MAP_BEGIN" or content == "MAP_END": 381 | self.on_state_updated(content) 382 | return 383 | # print("<<", content) 384 | if message_id in self.pending_messages.keys(): 385 | task, loop, timr = self.pending_messages[message_id] 386 | loop.call_soon_threadsafe(task.set_result, content) 387 | del self.pending_messages[message_id] 388 | timr.cancel() 389 | 390 | def receive_messages(self): 391 | while self.active: 392 | try: 393 | if self.sock is not None: 394 | message, address = self.sock.recvfrom(1024) 395 | message = message.decode("utf-8") 396 | self.on_message_internal(message) 397 | except socket.timeout: 398 | break 399 | 400 | def send(self, message: str): 401 | self.sock.sendto( 402 | f"NONE|{message}".encode("utf-8"), self.osu_game) 403 | 404 | def cancel_send_and_wait(self, m_id, value): 405 | if m_id in self.pending_messages.keys(): 406 | task, loop, timr = self.pending_messages[m_id] 407 | loop.call_soon_threadsafe(task.set_result, value) 408 | del self.pending_messages[m_id] 409 | 410 | async def send_and_wait(self, message: str, timeout_value="", timeout=10): 411 | global MESSAGES_SENT 412 | loop = asyncio.get_event_loop() 413 | task = asyncio.Future() 414 | message_id = f"{MESSAGES_SENT}" 415 | MESSAGES_SENT += 1 416 | self.pending_messages[message_id] = task, loop, Timer(timeout, self.cancel_send_and_wait, [ 417 | message_id, timeout_value]) 418 | self.pending_messages[message_id][2].start() 419 | self.sock.sendto( 420 | 
f"{message_id}|{message}".encode("utf-8"), self.osu_game) 421 | result = await task 422 | return result 423 | 424 | def kill(self): 425 | if self.active: 426 | target = self.sock 427 | self.sock = None 428 | target.settimeout(1) 429 | target.shutdown(SHUT_RDWR) 430 | target.close() 431 | self.active = False 432 | 433 | 434 | class ScreenRecorder(Thread): 435 | def __init__(self, fps: int = 30): 436 | super().__init__(group=None, daemon=True) 437 | self.fps = fps 438 | self.stop_event = Event() 439 | self.start() 440 | 441 | def stop(self): 442 | self.stop_event.set() 443 | 444 | def run(self): 445 | filename = f"{int(time.time() * 1000)}.avi" 446 | write_buff = Queue() 447 | with TemporaryDirectory() as record_dir: 448 | 449 | def write_frames(): 450 | frames_saved = 0 451 | 452 | frame = write_buff.get() 453 | 454 | while frame is not None: 455 | frame = np.array(frame) 456 | cv2.imwrite( 457 | path.join(record_dir, f'{frames_saved}.png'), frame) 458 | frames_saved += 1 459 | frame = write_buff.get() 460 | 461 | write_thread = Thread(target=write_frames, group=None, daemon=True) 462 | write_thread.start() 463 | 464 | with mss() as sct: 465 | while True: 466 | with FixedRuntime(1 / self.fps): 467 | write_buff.put(sct.grab(sct.monitors[1])) 468 | if self.stop_event.is_set(): 469 | break 470 | 471 | write_buff.put(None) 472 | write_thread.join() 473 | 474 | files = os.listdir(record_dir) 475 | files.sort(key=lambda a: int(a.split('.')[0])) 476 | 477 | source = cv2.VideoWriter_fourcc(*"MJPG") 478 | 479 | writer = cv2.VideoWriter( 480 | filename, source, float(self.fps), (1920, 1080)) 481 | 482 | for file in tqdm(files, desc=f"Generating Video from {len(files)} frames."): 483 | writer.write(cv2.imread(path.join(record_dir, file))) 484 | 485 | writer.release() 486 | -------------------------------------------------------------------------------- /assets/Oka_Custom.osk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TareHimself/osu-ai/63037c60369b7b8d3757472049a4463dd3fc8801/assets/Oka_Custom.osk -------------------------------------------------------------------------------- /assets/danser-settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "General": { 3 | "OsuSongsDir": "C:\\Users\\Taree\\AppData\\Local\\osu!\\Songs", 4 | "OsuSkinsDir": "C:\\Users\\Taree\\AppData\\Local\\osu!\\Skins", 5 | "OsuReplaysDir": "C:\\Users\\Taree\\AppData\\Local\\osu!\\Replays", 6 | "DiscordPresenceOn": true, 7 | "UnpackOszFiles": true, 8 | "VerboseImportLogs": false 9 | }, 10 | "Graphics": { 11 | "Width": 1920, 12 | "Height": 1080, 13 | "WindowWidth": 1440, 14 | "WindowHeight": 810, 15 | "Fullscreen": true, 16 | "VSync": false, 17 | "FPSCap": 0, 18 | "MSAA": 0, 19 | "ShowFPS": true, 20 | "Experimental": { 21 | "UsePersistentBuffers": false 22 | } 23 | }, 24 | "Audio": { 25 | "GeneralVolume": 0.5, 26 | "MusicVolume": 0.5, 27 | "SampleVolume": 0.5, 28 | "Offset": 0, 29 | "HitsoundPositionMultiplier": 1, 30 | "IgnoreBeatmapSamples": true, 31 | "IgnoreBeatmapSampleVolume": false, 32 | "PlayNightcoreSamples": true, 33 | "BeatScale": 1.2, 34 | "BeatUseTimingPoints": false, 35 | "Linux/Unix": { 36 | "BassPlaybackBufferLength": 100, 37 | "BassDeviceBufferLength": 10, 38 | "BassUpdatePeriod": 5, 39 | "BassDeviceUpdatePeriod": 10 40 | } 41 | }, 42 | "Input": { 43 | "LeftKey": "Z", 44 | "RightKey": "X", 45 | "RestartKey": "`", 46 | "SmokeKey": "C", 47 | "ScreenshotKey": "F2", 48 | "MouseButtonsDisabled": 
true, 49 | "MouseHighPrecision": false, 50 | "MouseSensitivity": 1 51 | }, 52 | "Gameplay": { 53 | "HitErrorMeter": { 54 | "Show": false, 55 | "Scale": 1, 56 | "Opacity": 1, 57 | "XOffset": 0, 58 | "YOffset": 0, 59 | "PointFadeOutTime": 10, 60 | "ShowPositionalMisses": true, 61 | "PositionalMissScale": 1.5, 62 | "ShowUnstableRate": true, 63 | "UnstableRateDecimals": 0, 64 | "UnstableRateScale": 1, 65 | "StaticUnstableRate": false, 66 | "ScaleWithSpeed": false 67 | }, 68 | "AimErrorMeter": { 69 | "Show": false, 70 | "Scale": 1, 71 | "Opacity": 1, 72 | "XPosition": 1350, 73 | "YPosition": 650, 74 | "PointFadeOutTime": 10, 75 | "DotScale": 1, 76 | "Align": "Right", 77 | "ShowUnstableRate": false, 78 | "UnstableRateScale": 1, 79 | "UnstableRateDecimals": 0, 80 | "StaticUnstableRate": false, 81 | "CapPositionalMisses": true, 82 | "AngleNormalized": false 83 | }, 84 | "Score": { 85 | "Show": true, 86 | "Scale": 1, 87 | "Opacity": 1, 88 | "XOffset": 0, 89 | "YOffset": 0, 90 | "ProgressBar": "Pie", 91 | "ShowGradeAlways": false, 92 | "StaticScore": false, 93 | "StaticAccuracy": false 94 | }, 95 | "HpBar": { 96 | "Show": true, 97 | "Scale": 1, 98 | "Opacity": 1, 99 | "XOffset": 0, 100 | "YOffset": 0 101 | }, 102 | "ComboCounter": { 103 | "Show": true, 104 | "Scale": 1, 105 | "Opacity": 1, 106 | "XOffset": 0, 107 | "YOffset": 0, 108 | "Static": false 109 | }, 110 | "PPCounter": { 111 | "Show": false, 112 | "Scale": 1, 113 | "Opacity": 1, 114 | "XPosition": 5, 115 | "YPosition": 150, 116 | "Color": { 117 | "Hue": 0, 118 | "Saturation": 0, 119 | "Value": 1 120 | }, 121 | "Decimals": 0, 122 | "Align": "CentreLeft", 123 | "ShowInResults": true, 124 | "ShowPPComponents": false, 125 | "Static": false 126 | }, 127 | "HitCounter": { 128 | "Show": false, 129 | "Scale": 1, 130 | "Opacity": 1, 131 | "XPosition": 5, 132 | "YPosition": 190, 133 | "Color300": { 134 | "Hue": 0, 135 | "Saturation": 0, 136 | "Value": 1 137 | }, 138 | "Color100": { 139 | "Hue": 0, 140 | "Saturation": 0, 141 | "Value": 1 142 | }, 143 | "Color50": { 144 | "Hue": 0, 145 | "Saturation": 0, 146 | "Value": 1 147 | }, 148 | "ColorMiss": { 149 | "Hue": 0, 150 | "Saturation": 0, 151 | "Value": 1 152 | }, 153 | "ColorSB": { 154 | "Hue": 0, 155 | "Saturation": 0, 156 | "Value": 1 157 | }, 158 | "Spacing": 48, 159 | "FontScale": 1, 160 | "Align": "Left", 161 | "ValueAlign": "Left", 162 | "Vertical": false, 163 | "Show300": false, 164 | "ShowSliderBreaks": false 165 | }, 166 | "StrainGraph": { 167 | "Show": false, 168 | "Opacity": 1, 169 | "XPosition": 5, 170 | "YPosition": 310, 171 | "Align": "BottomLeft", 172 | "Width": 130, 173 | "Height": 70, 174 | "BgColor": { 175 | "Hue": 0, 176 | "Saturation": 0, 177 | "Value": 0.2 178 | }, 179 | "FgColor": { 180 | "Hue": 297, 181 | "Saturation": 0.4, 182 | "Value": 0.92 183 | }, 184 | "Outline": { 185 | "Show": false, 186 | "Width": 2, 187 | "InnerDarkness": 0.5, 188 | "InnerOpacity": 0.5 189 | } 190 | }, 191 | "KeyOverlay": { 192 | "Show": true, 193 | "Scale": 1, 194 | "Opacity": 1, 195 | "XOffset": 0, 196 | "YOffset": 0 197 | }, 198 | "ScoreBoard": { 199 | "Show": false, 200 | "Scale": 1, 201 | "Opacity": 1, 202 | "XOffset": 0, 203 | "YOffset": 0, 204 | "ModsOnly": false, 205 | "AlignRight": false, 206 | "HideOthers": false, 207 | "ShowAvatars": false, 208 | "ExplosionScale": 1 209 | }, 210 | "Mods": { 211 | "Show": false, 212 | "Scale": 1, 213 | "Opacity": 1, 214 | "XOffset": 0, 215 | "YOffset": 0, 216 | "HideInReplays": false, 217 | "FoldInReplays": false, 218 | "AdditionalSpacing": 0 219 | }, 220 | 
"Boundaries": { 221 | "Enabled": false, 222 | "BorderThickness": 1, 223 | "BorderFill": 1, 224 | "BorderColor": { 225 | "Hue": 0, 226 | "Saturation": 0, 227 | "Value": 1 228 | }, 229 | "BorderOpacity": 1, 230 | "BackgroundColor": { 231 | "Hue": 0, 232 | "Saturation": 1, 233 | "Value": 0 234 | }, 235 | "BackgroundOpacity": 0.5 236 | }, 237 | "Underlay": { 238 | "Path": "", 239 | "AboveHpBar": false 240 | }, 241 | "HUDFont": "", 242 | "ShowResultsScreen": true, 243 | "ResultsScreenTime": 5, 244 | "ResultsUseLocalTimeZone": false, 245 | "ShowWarningArrows": true, 246 | "ShowHitLighting": false, 247 | "FlashlightDim": 1, 248 | "PlayUsername": "Guest", 249 | "IgnoreFailsInReplays": false, 250 | "UseLazerPP": false 251 | }, 252 | "Skin": { 253 | "CurrentSkin": "Moonshine 2.0 [DT]", 254 | "FallbackSkin": "Moonshine 2.0 [DT]", 255 | "UseColorsFromSkin": true, 256 | "UseBeatmapColors": false, 257 | "Cursor": { 258 | "UseSkinCursor": true, 259 | "Scale": 0, 260 | "TrailScale": 0, 261 | "ForceLongTrail": false, 262 | "LongTrailLength": 2048, 263 | "LongTrailDensity": 1 264 | } 265 | }, 266 | "Cursor": { 267 | "TrailStyle": 1, 268 | "Style23Speed": 0.18, 269 | "Style4Shift": 0.5, 270 | "Colors": { 271 | "EnableRainbow": false, 272 | "RainbowSpeed": 8, 273 | "BaseColor": { 274 | "Hue": 0, 275 | "Saturation": 1, 276 | "Value": 1 277 | }, 278 | "EnableCustomHueOffset": false, 279 | "HueOffset": 0, 280 | "FlashToTheBeat": false, 281 | "FlashAmplitude": 0 282 | }, 283 | "EnableCustomTagColorOffset": true, 284 | "TagColorOffset": -36, 285 | "EnableTrailGlow": true, 286 | "EnableCustomTrailGlowOffset": true, 287 | "TrailGlowOffset": -36, 288 | "ScaleToCS": false, 289 | "CursorSize": 12, 290 | "CursorExpand": false, 291 | "ScaleToTheBeat": false, 292 | "ShowCursorsOnBreaks": true, 293 | "BounceOnEdges": false, 294 | "TrailScale": 1, 295 | "TrailEndScale": 0.4, 296 | "TrailDensity": 1, 297 | "TrailMaxLength": 2000, 298 | "TrailRemoveSpeed": 1, 299 | "GlowEndScale": 0.4, 300 | "InnerLengthMult": 0.9, 301 | "AdditiveBlending": true, 302 | "CursorRipples": true, 303 | "SmokeEnabled": true 304 | }, 305 | "Objects": { 306 | "DrawApproachCircles": true, 307 | "DrawComboNumbers": true, 308 | "DrawFollowPoints": true, 309 | "LoadSpinners": true, 310 | "ScaleToTheBeat": false, 311 | "StackEnabled": true, 312 | "Sliders": { 313 | "ForceSliderBallTexture": true, 314 | "DrawEndCircles": true, 315 | "DrawSliderFollowCircle": true, 316 | "DrawScorePoints": true, 317 | "SliderMerge": false, 318 | "BorderWidth": 1, 319 | "Distortions": { 320 | "Enabled": false, 321 | "ViewportSize": 0, 322 | "UseCustomResolution": false, 323 | "CustomResolutionX": 1920, 324 | "CustomResolutionY": 1080 325 | }, 326 | "Snaking": { 327 | "In": false, 328 | "Out": false, 329 | "OutFadeInstant": true, 330 | "DurationMultiplier": 0, 331 | "FadeMultiplier": 0 332 | } 333 | }, 334 | "Colors": { 335 | "MandalaTexturesTrigger": 5, 336 | "MandalaTexturesAlpha": 0.3, 337 | "Color": { 338 | "EnableRainbow": false, 339 | "RainbowSpeed": 8, 340 | "BaseColor": { 341 | "Hue": 0, 342 | "Saturation": 1, 343 | "Value": 1 344 | }, 345 | "EnableCustomHueOffset": false, 346 | "HueOffset": 0, 347 | "FlashToTheBeat": false, 348 | "FlashAmplitude": 100 349 | }, 350 | "UseComboColors": false, 351 | "ComboColors": [ 352 | { 353 | "Hue": 0, 354 | "Saturation": 1, 355 | "Value": 1 356 | } 357 | ], 358 | "UseSkinComboColors": false, 359 | "UseBeatmapComboColors": false, 360 | "Sliders": { 361 | "WhiteScorePoints": false, 362 | "ScorePointColorOffset": 0, 363 | 
"SliderBallTint": false, 364 | "Border": { 365 | "UseHitCircleColor": false, 366 | "Color": { 367 | "EnableRainbow": false, 368 | "RainbowSpeed": 8, 369 | "BaseColor": { 370 | "Hue": 0, 371 | "Saturation": 0, 372 | "Value": 1 373 | }, 374 | "EnableCustomHueOffset": false, 375 | "HueOffset": 0, 376 | "FlashToTheBeat": false, 377 | "FlashAmplitude": 100 378 | }, 379 | "EnableCustomGradientOffset": false, 380 | "CustomGradientOffset": 0 381 | }, 382 | "Body": { 383 | "UseHitCircleColor": false, 384 | "Color": { 385 | "EnableRainbow": false, 386 | "RainbowSpeed": 8, 387 | "BaseColor": { 388 | "Hue": 0, 389 | "Saturation": 1, 390 | "Value": 0 391 | }, 392 | "EnableCustomHueOffset": false, 393 | "HueOffset": 0, 394 | "FlashToTheBeat": false, 395 | "FlashAmplitude": 100 396 | }, 397 | "InnerOffset": -0.5, 398 | "OuterOffset": -0.05, 399 | "InnerAlpha": 0.8, 400 | "OuterAlpha": 0.8 401 | } 402 | } 403 | } 404 | }, 405 | "Playfield": { 406 | "DrawObjects": true, 407 | "DrawCursors": true, 408 | "Scale": 1, 409 | "OsuShift": false, 410 | "ShiftX": 0, 411 | "ShiftY": 0, 412 | "ScaleStoryboardWithPlayfield": false, 413 | "MoveStoryboardWithPlayfield": false, 414 | "LeadInTime": 0, 415 | "LeadInHold": 0, 416 | "FadeOutTime": 0, 417 | "SeizureWarning": { 418 | "Enabled": false, 419 | "Duration": 5 420 | }, 421 | "Background": { 422 | "LoadStoryboards": true, 423 | "LoadVideos": false, 424 | "FlashToTheBeat": false, 425 | "Dim": { 426 | "Intro": 1, 427 | "Normal": 1, 428 | "Breaks": 0.5 429 | }, 430 | "Parallax": { 431 | "Enabled": false, 432 | "Amount": 0.1, 433 | "Speed": 0.5 434 | }, 435 | "Blur": { 436 | "Enabled": false, 437 | "Values": { 438 | "Intro": 0, 439 | "Normal": 0.6, 440 | "Breaks": 0.3 441 | } 442 | }, 443 | "Triangles": { 444 | "Enabled": false, 445 | "Shadowed": true, 446 | "DrawOverBlur": true, 447 | "ParallaxMultiplier": 0.5, 448 | "Density": 1, 449 | "Scale": 1, 450 | "Speed": 1 451 | } 452 | }, 453 | "Logo": { 454 | "Enabled": false, 455 | "DrawSpectrum": false, 456 | "Dim": { 457 | "Intro": 0, 458 | "Normal": 1, 459 | "Breaks": 1 460 | } 461 | }, 462 | "Bloom": { 463 | "Enabled": false, 464 | "BloomToTheBeat": true, 465 | "BloomBeatAddition": 0.3, 466 | "Threshold": 0, 467 | "Blur": 0.6, 468 | "Power": 0.7 469 | } 470 | }, 471 | "CursorDance": { 472 | "Movers": [ 473 | { 474 | "Mover": "linear", 475 | "SliderDance": false, 476 | "RandomSliderDance": false 477 | } 478 | ], 479 | "Spinners": [ 480 | { 481 | "Mover": "circle", 482 | "CenterOffsetX": 0, 483 | "CenterOffsetY": 0, 484 | "Radius": 100 485 | } 486 | ], 487 | "ComboTag": false, 488 | "Battle": false, 489 | "DoSpinnersTogether": true, 490 | "TAGSliderDance": false, 491 | "MoverSettings": { 492 | "Bezier": [ 493 | { 494 | "Aggressiveness": 60, 495 | "SliderAggressiveness": 3 496 | } 497 | ], 498 | "Flower": [ 499 | { 500 | "AngleOffset": 90, 501 | "DistanceMult": 0.666, 502 | "StreamAngleOffset": 90, 503 | "LongJump": -1, 504 | "LongJumpMult": 0.7, 505 | "LongJumpOnEqualPos": false 506 | } 507 | ], 508 | "HalfCircle": [ 509 | { 510 | "RadiusMultiplier": 1, 511 | "StreamTrigger": 130 512 | } 513 | ], 514 | "Spline": [ 515 | { 516 | "RotationalForce": false, 517 | "StreamHalfCircle": true, 518 | "StreamWobble": true, 519 | "WobbleScale": 0.67 520 | } 521 | ], 522 | "Momentum": [ 523 | { 524 | "SkipStackAngles": false, 525 | "StreamRestrict": true, 526 | "DurationMult": 2, 527 | "DurationTrigger": 500, 528 | "StreamMult": 0.7, 529 | "RestrictAngle": 90, 530 | "RestrictArea": 40, 531 | "RestrictInvert": true, 532 | 
"DistanceMult": 0.6, 533 | "DistanceMultOut": 0.45 534 | } 535 | ], 536 | "ExGon": [ 537 | { 538 | "Delay": 50 539 | } 540 | ], 541 | "Linear": [ 542 | { 543 | "WaitForPreempt": true, 544 | "ReactionTime": 100, 545 | "ChoppyLongObjects": false 546 | } 547 | ], 548 | "Pippi": [ 549 | { 550 | "RotationSpeed": 1.6, 551 | "RadiusMultiplier": 0.98, 552 | "SpinnerRadius": 100 553 | } 554 | ] 555 | } 556 | }, 557 | "Knockout": { 558 | "Mode": 0, 559 | "GraceEndTime": -10, 560 | "BubbleMinimumCombo": 200, 561 | "ExcludeMods": "", 562 | "HideMods": "", 563 | "MaxPlayers": 50, 564 | "MinPlayers": 1, 565 | "RevivePlayersAtEnd": false, 566 | "LiveSort": true, 567 | "SortBy": "Score", 568 | "HideOverlayOnBreaks": false, 569 | "MinCursorSize": 3, 570 | "MaxCursorSize": 7, 571 | "AddDanser": false, 572 | "DanserName": "danser" 573 | }, 574 | "Recording": { 575 | "FrameWidth": 1920, 576 | "FrameHeight": 1080, 577 | "FPS": 100, 578 | "EncodingFPSCap": 0, 579 | "Encoder": "hevc_nvenc", 580 | "libx264": { 581 | "RateControl": "crf", 582 | "Bitrate": "10M", 583 | "CRF": 14, 584 | "Profile": "high", 585 | "Preset": "faster", 586 | "AdditionalOptions": "" 587 | }, 588 | "libx265": { 589 | "RateControl": "crf", 590 | "Bitrate": "10M", 591 | "CRF": 18, 592 | "Preset": "fast", 593 | "AdditionalOptions": "" 594 | }, 595 | "h264_nvenc": { 596 | "RateControl": "cq", 597 | "Bitrate": "10M", 598 | "CQ": 22, 599 | "Profile": "high", 600 | "Preset": "p7", 601 | "AdditionalOptions": "" 602 | }, 603 | "hevc_nvenc": { 604 | "RateControl": "cq", 605 | "Bitrate": "10M", 606 | "CQ": 24, 607 | "Preset": "p1", 608 | "AdditionalOptions": "" 609 | }, 610 | "h264_qsv": { 611 | "RateControl": "icq", 612 | "Bitrate": "10M", 613 | "Quality": 15, 614 | "Profile": "high", 615 | "Preset": "slow", 616 | "AdditionalOptions": "" 617 | }, 618 | "hevc_qsv": { 619 | "RateControl": "icq", 620 | "Bitrate": "10M", 621 | "Quality": 20, 622 | "Preset": "slow", 623 | "AdditionalOptions": "" 624 | }, 625 | "custom": { 626 | "CustomOptions": "" 627 | }, 628 | "PixelFormat": "yuv420p", 629 | "Filters": "", 630 | "AudioCodec": "aac", 631 | "aac": { 632 | "Bitrate": "192k", 633 | "AdditionalOptions": "" 634 | }, 635 | "libmp3lame": { 636 | "RateControl": "abr", 637 | "TargetBitrate": "192k", 638 | "AdditionalOptions": "" 639 | }, 640 | "libopus": { 641 | "RateControl": "vbr", 642 | "TargetBitrate": "192k", 643 | "AdditionalOptions": "" 644 | }, 645 | "flac": { 646 | "CompressionLevel": 12, 647 | "AdditionalOptions": "" 648 | }, 649 | "customAudio": { 650 | "CustomOptions": "" 651 | }, 652 | "AudioFilters": "", 653 | "OutputDir": "videos", 654 | "Container": "mkv", 655 | "ShowFFmpegLogs": true, 656 | "MotionBlur": { 657 | "Enabled": false, 658 | "OversampleMultiplier": 16, 659 | "BlendFrames": 24, 660 | "BlendFunctionID": 27, 661 | "GaussWeightsMult": 1.5 662 | } 663 | } 664 | } -------------------------------------------------------------------------------- /assets/good-play-autopilot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TareHimself/osu-ai/63037c60369b7b8d3757472049a4463dd3fc8801/assets/good-play-autopilot.png -------------------------------------------------------------------------------- /assets/good-play-relax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TareHimself/osu-ai/63037c60369b7b8d3757472049a4463dd3fc8801/assets/good-play-relax.png 
-------------------------------------------------------------------------------- /assets/skins/AngeLMegumin (RemakeNeuroTest) [AI] (Unknown).osk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TareHimself/osu-ai/63037c60369b7b8d3757472049a4463dd3fc8801/assets/skins/AngeLMegumin (RemakeNeuroTest) [AI] (Unknown).osk -------------------------------------------------------------------------------- /assets/skins/AngeLMegumin (RemakeNeuroTest)(With Cursor) [MegumiWithCursor] (Unknown).osk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TareHimself/osu-ai/63037c60369b7b8d3757472049a4463dd3fc8801/assets/skins/AngeLMegumin (RemakeNeuroTest)(With Cursor) [MegumiWithCursor] (Unknown).osk -------------------------------------------------------------------------------- /assets/skins/Moonshine 2.0 [Eclipse].osk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TareHimself/osu-ai/63037c60369b7b8d3757472049a4463dd3fc8801/assets/skins/Moonshine 2.0 [Eclipse].osk -------------------------------------------------------------------------------- /assets/skins/Oka Custom [Oka Custom (No Cursor)] .osk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TareHimself/osu-ai/63037c60369b7b8d3757472049a4463dd3fc8801/assets/skins/Oka Custom [Oka Custom (No Cursor)] .osk -------------------------------------------------------------------------------- /assets/skins/Oka Custom [Oka_Custom] .osk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TareHimself/osu-ai/63037c60369b7b8d3757472049a4463dd3fc8801/assets/skins/Oka Custom [Oka_Custom] .osk -------------------------------------------------------------------------------- /experiments/rl_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | import gymnasium 4 | from rl.agent import OsuAgent 5 | from rl.env import OsuEnviroment 6 | 7 | env = OsuEnviroment() 8 | state = env.reset() 9 | try: 10 | while True: 11 | done_with_episode = False 12 | while not done_with_episode: 13 | done_with_episode = env.step() 14 | time.sleep(0.1) 15 | env.reset() 16 | except KeyboardInterrupt: 17 | env.agent.kill() 18 | -------------------------------------------------------------------------------- /experiments/rt.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import gymnasium as gym 3 | env = gym.make("LunarLander-v2", render_mode="human") 4 | observation, info = env.reset() 5 | 6 | for _ in range(1000): 7 | # agent policy that uses the observation and info 8 | action = env.action_space.sample_mouse() 9 | observation, reward, terminated, truncated, info = env.step(action) 10 | 11 | if terminated or truncated: 12 | observation, info = env.reset() 13 | 14 | env.close() 15 | -------------------------------------------------------------------------------- /experiments/s_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import traceback 4 | from threading import Thread 5 | 6 | 7 | class WatchFileContent(Thread): 8 | def __init__(self, file_path, callback,poll_frequency=0.05): 9 | super().__init__(group=None, daemon=True) 10 | self.file_path = file_path 11 
| self.callback = callback 12 | self.freq = poll_frequency 13 | self.callback(open(self.file_path).readlines()) 14 | 15 | def run(self): 16 | modifiedOn = os.path.getmtime(self.file_path) 17 | try: 18 | while True: 19 | time.sleep(self.freq) 20 | modified = os.path.getmtime(self.file_path) 21 | if modified != modifiedOn: 22 | modifiedOn = modified 23 | self.callback(open(self.file_path).readlines()) 24 | except Exception as e: 25 | print(traceback.format_exc()) 26 | 27 | 28 | # def on_file_modified(lines): 29 | # print(lines) 30 | 31 | def on_left_state_modified(lines): 32 | print("Left Button:","DOWN" if lines[0] == "1" else "UP") 33 | 34 | def on_right_state_modified(lines): 35 | print("Right Button:", "DOWN" if lines[0] == "1" else "UP") 36 | 37 | # WatchFileContent(r'C:\Users\Taree\Pictures\accuracy.txt', 38 | # on_file_modified).start() 39 | 40 | WatchFileContent(r'C:\Users\Taree\Pictures\Action RightButton.txt', 41 | on_left_state_modified,0.01).start() 42 | 43 | WatchFileContent(r'C:\Users\Taree\Pictures\Action LeftButton.txt', 44 | on_right_state_modified, 0.01).start() 45 | 46 | while True: 47 | time.sleep(10) 48 | -------------------------------------------------------------------------------- /experiments/socket_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import asyncio 3 | 4 | from ai.utils import OsuSocketServer 5 | 6 | server = OsuSocketServer() 7 | while True: 8 | time.sleep(0.01) 9 | print("\tGame Time: ", asyncio.run( 10 | server.send_and_wait("time")), " ", end="\r") if server.osu_game is not None else print("\tClient hasnt connected ", end="\r") 11 | -------------------------------------------------------------------------------- /experiments/test_1.py: -------------------------------------------------------------------------------- 1 | from ai.constants import PLAY_AREA_CAPTURE_PARAMS 2 | import numpy as np 3 | import win32gui 4 | import win32ui 5 | import win32con 6 | import time 7 | import cv2 8 | 9 | hwnd = win32gui.GetDesktopWindow() 10 | 11 | stack_num = 10 12 | stack_interval = 0.001 13 | 14 | width = int(PLAY_AREA_CAPTURE_PARAMS[0]) 15 | height = int(PLAY_AREA_CAPTURE_PARAMS[1]) 16 | wDC = win32gui.GetWindowDC(hwnd) 17 | dcObj = win32ui.CreateDCFromHandle(wDC) 18 | cDC = dcObj.CreateCompatibleDC() 19 | dataBitMap = win32ui.CreateBitmap() 20 | dataBitMap.CreateCompatibleBitmap(dcObj, width, height) 21 | frames = [] 22 | 23 | for i in range(stack_num): 24 | cDC.SelectObject(dataBitMap) 25 | cDC.BitBlt((0, 0), (width, height), dcObj, 26 | (PLAY_AREA_CAPTURE_PARAMS[2], PLAY_AREA_CAPTURE_PARAMS[3]), win32con.SRCCOPY) 27 | time.sleep(stack_interval) 28 | # convert the raw data into a format opencv can read 29 | #dataBitMap.SaveBitmapFile(cDC, 'debug.bmp') 30 | signedIntsArray = dataBitMap.GetBitmapBits(True) 31 | frame = np.frombuffer(signedIntsArray, dtype='uint8') 32 | frame.shape = (height, width, 4) 33 | frame = np.ascontiguousarray(frame[..., :3]) 34 | print(frame.shape, frame.dtype) 35 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 36 | print(frame.shape) 37 | 38 | frames.append(frame) 39 | 40 | # free resources 41 | dcObj.DeleteDC() 42 | cDC.DeleteDC() 43 | win32gui.ReleaseDC(hwnd, wDC) 44 | win32gui.DeleteObject(dataBitMap.GetHandle()) 45 | print("Before stack", frames[0].shape) 46 | 47 | final = np.stack(frames) 48 | print("After stack", final[0].shape) 49 | 50 | cv2.imshow('Debug', final[len(final) - 3: len(final)].transpose(1, 2, 0)) 51 | cv2.waitKey(10000) 52 | # cv2.imwrite('debug.png', ) 
53 | print(final.shape, final[0].shape, width, height) 54 | 55 | a = np.array([[1, 2, 3], [4, 5, 6]]) 56 | b = np.array([[7, 8, 9], [10, 11, 12]]) 57 | c = np.stack([a, a, b, b]) 58 | print(c.shape, c[0].shape, a.shape, b.shape) 59 | -------------------------------------------------------------------------------- /experiments/test_a.py: -------------------------------------------------------------------------------- 1 | from ai.eval import AimThread 2 | from ai.utils import FixedRuntime 3 | 4 | AI = AimThread( 5 | model_id=f'D:\Github\osu-ai\models\model_aim_body floating_20-02-23-19-58-01.pt') 6 | # AC = ActionsThread( 7 | # model_path=f'D:\Github\osu-ai\models\model_action_body floating_20-02-23-21-45-00.pt') 8 | # socket_server = OsuSocketServer(on_state_updated=lambda a: a) 9 | # socket_server.send('save,test,start,0.01') 10 | try: 11 | while True: 12 | with FixedRuntime(0.01): 13 | pass 14 | 15 | except KeyboardInterrupt: 16 | # socket_server.send('save,test,stop,0.01') 17 | pass 18 | -------------------------------------------------------------------------------- /experiments/test_b.py: -------------------------------------------------------------------------------- 1 | import time 2 | from ai.play_buggy import start_play 3 | 4 | start_play() 5 | -------------------------------------------------------------------------------- /experiments/test_c.py: -------------------------------------------------------------------------------- 1 | from windows import WindowCapture 2 | import cv2 3 | from ai.utils import FixedRuntime 4 | from collections import deque 5 | cap = WindowCapture() 6 | previous = deque(maxlen=3) 7 | result = cv2.VideoWriter('filename.avi', 8 | cv2.VideoWriter_fourcc(*'MJPG'), 9 | 20, (1920, 1080)) 10 | try: 11 | while True: 12 | with FixedRuntime(0.0167): 13 | frame, stacked = cap.capture( 14 | prev_frames=list(previous), stack_num=3) 15 | stacked = stacked.transpose(1, 2, 0) 16 | result.write(stacked) 17 | previous.append(frame) 18 | cv2.imshow("Debug", stacked) 19 | cv2.waitKey(1) 20 | 21 | except KeyboardInterrupt: 22 | result.release() 23 | -------------------------------------------------------------------------------- /experiments/test_d.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | 4 | # try: 5 | # with ScreenRecorder(20) as s: 6 | # while True: 7 | # time.sleep(5) 8 | 9 | # except Exception as e: 10 | # print(e) 11 | # print(traceback.print_exc()) 12 | 13 | 14 | # def load_video(location): 15 | # vidcap = cv2.VideoCapture(location) 16 | # return (vidcap, int(os.path.getctime(location) * 1000)) 17 | 18 | 19 | # video, created = load_video('1678631402353.avi') 20 | 21 | 22 | # def get_frame(time_seconds): 23 | # global video 24 | # video.set(cv2.CAP_PROP_POS_MSEC, time_seconds) 25 | # hasFrames, image = video.read() 26 | # if hasFrames: 27 | # return image 28 | # return None 29 | 30 | 31 | # record_time = 1678630643467 32 | # cur_time = 1678630643467 33 | # try: 34 | # while True: 35 | # cv2.imshow("Debug", get_frame((cur_time - record_time))) 36 | # cv2.waitKey(100) 37 | # cur_time += 1 38 | 39 | # except Exception as e: 40 | # print(e) 41 | # print(traceback.print_exc()) 42 | 43 | 44 | # with open('sample.txt', 'r') as f: 45 | # for line in f.readlines(): 46 | # timestamp, data = line.split('|') 47 | 48 | # k1, k2, mx, my = data.split(',') 49 | 50 | # mx = int((int(mx) / 1920) * 1280) 51 | 52 | # my = int((int(my) / 1080) * 720) 53 | 54 | # timestamp = int(timestamp) 55 | 56 | # target_secs = 
(1678631412546 - 1678631402353) 57 | # print(target_secs) 58 | # frame = get_frame(target_secs) 59 | 60 | # frame_with_circle = cv2.circle( 61 | # frame, (mx, my), 20, (255, 0, 0), 10) 62 | # cv2.imshow("Debug", frame_with_circle) 63 | # cv2.waitKey(0) 64 | 65 | 66 | # # sample = get_frame(1*1000) 67 | # # print(created) 68 | # # cv2.imshow("Image", sample) 69 | # # cv2.waitKey(0) 70 | 71 | FILES_PATH = os.path.join(os.getcwd(), 'pending-capture') 72 | for file in os.listdir(FILES_PATH): 73 | data, ext = file.split('.') 74 | projId, data = data.split('-') 75 | 76 | timestamp, k1, k2, mx, my = data.split(',') 77 | 78 | mx = int(mx) 79 | 80 | my = int(my) 81 | 82 | timestamp = int(timestamp) 83 | 84 | frame = cv2.imread(os.path.join(FILES_PATH, file)) 85 | 86 | frame_with_circle = cv2.circle( 87 | frame, (mx, my), 20, (255, 0, 0), 10) 88 | cv2.imshow("Debug", frame_with_circle) 89 | cv2.waitKey(1) 90 | -------------------------------------------------------------------------------- /experiments/test_e.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | from tqdm import tqdm 4 | 5 | # try: 6 | # with ScreenRecorder(20) as s: 7 | # while True: 8 | # time.sleep(5) 9 | 10 | # except Exception as e: 11 | # print(e) 12 | # print(traceback.print_exc()) 13 | 14 | 15 | # def load_video(location): 16 | # vidcap = cv2.VideoCapture(location) 17 | # return (vidcap, int(os.path.getctime(location) * 1000)) 18 | 19 | 20 | # video, created = load_video('1678631402353.avi') 21 | 22 | 23 | # def get_frame(time_seconds): 24 | # global video 25 | # video.set(cv2.CAP_PROP_POS_MSEC, time_seconds) 26 | # hasFrames, image = video.read() 27 | # if hasFrames: 28 | # return image 29 | # return None 30 | 31 | 32 | # record_time = 1678630643467 33 | # cur_time = 1678630643467 34 | # try: 35 | # while True: 36 | # cv2.imshow("Debug", get_frame((cur_time - record_time))) 37 | # cv2.waitKey(100) 38 | # cur_time += 1 39 | 40 | # except Exception as e: 41 | # print(e) 42 | # print(traceback.print_exc()) 43 | 44 | 45 | # with open('sample.txt', 'r') as f: 46 | # for line in f.readlines(): 47 | # timestamp, data = line.split('|') 48 | 49 | # k1, k2, mx, my = data.split(',') 50 | 51 | # mx = int((int(mx) / 1920) * 1280) 52 | 53 | # my = int((int(my) / 1080) * 720) 54 | 55 | # timestamp = int(timestamp) 56 | 57 | # target_secs = (1678631412546 - 1678631402353) 58 | # print(target_secs) 59 | # frame = get_frame(target_secs) 60 | 61 | # frame_with_circle = cv2.circle( 62 | # frame, (mx, my), 20, (255, 0, 0), 10) 63 | # cv2.imshow("Debug", frame_with_circle) 64 | # cv2.waitKey(0) 65 | 66 | 67 | # # sample = get_frame(1*1000) 68 | # # print(created) 69 | # # cv2.imshow("Image", sample) 70 | # # cv2.waitKey(0) 71 | 72 | FILES_PATH = os.path.join(os.getcwd(), 'pending-capture') 73 | COMPRESSED_PATH = os.path.join(os.getcwd(), 'source.zip') 74 | zip = zipfile.ZipFile(COMPRESSED_PATH, "w", zipfile.ZIP_DEFLATED) 75 | 76 | for file in tqdm(os.listdir(FILES_PATH), desc="Zipping up source files"): 77 | zip.write(os.path.join(FILES_PATH, file), file) 78 | # data, ext = file.split('.') 79 | # projId, data = data.split('-') 80 | 81 | # timestamp, k1, k2, mx, my = data.split(',') 82 | 83 | # mx = int(mx) 84 | 85 | # my = int(my) 86 | 87 | # timestamp = int(timestamp) 88 | 89 | # frame = cv2.imread(os.path.join(FILES_PATH, file)) 90 | 91 | # frame_with_circle = cv2.circle( 92 | # frame, (mx, my), 20, (255, 0, 0), 10) 93 | # cv2.imshow("Debug", frame_with_circle) 94 | # 
cv2.waitKey(1) 95 | 96 | zip.close() 97 | -------------------------------------------------------------------------------- /experiments/test_f.py: -------------------------------------------------------------------------------- 1 | from ai.dataset import OsuDataset 2 | 3 | 4 | dataset = OsuDataset(['test', 'test', 's']) 5 | -------------------------------------------------------------------------------- /experiments/torch_t.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | a = torch.Tensor([[5, 2, 3], [1, 6, 3], [1, 2, 7]]) 5 | 6 | # torch.Tensor([0, 1, 2]).type(torch.LongTensor) 7 | b = np.arange(3, dtype=np.int32) 8 | 9 | c = torch.Tensor([0, 1, 2]).type(torch.LongTensor) 10 | 11 | print(b, c, a[b, c]) 12 | -------------------------------------------------------------------------------- /experiments/vel_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import cv2 3 | from collections import deque 4 | from windows import WindowCapture 5 | 6 | cap = WindowCapture("osu!") 7 | STACK_NUM = 3 8 | frame_history = deque(maxlen=3) 9 | while True: 10 | frame, stacked = cap.capture(list(frame_history), STACK_NUM) 11 | frame_history.append(frame) 12 | cv2.imshow("Debug", stacked) 13 | cv2.waitKey(1) 14 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | if __name__ == '__main__': 3 | 4 | from ai.utils import get_validated_input 5 | 6 | USER_MAIN_MENU = """What would you like to do ? 7 | [0] Train or finetune a model 8 | [1] Convert a video and json into a dataset 9 | [2] Test a model 10 | [3] Quit 11 | """ 12 | 13 | QUIT_CHOICE = 3 14 | 15 | 16 | def get_input(): 17 | input_as_str = input(USER_MAIN_MENU).strip() 18 | if not input_as_str.isnumeric(): 19 | return get_input() 20 | 21 | return int(input_as_str) 22 | 23 | 24 | def run(): 25 | get_input_params = [USER_MAIN_MENU, lambda a: a.strip().isnumeric() and ( 26 | 0 <= int(a.strip()) <= 3), lambda a: int(a.strip())] 27 | 28 | user_choice = get_validated_input(*get_input_params) 29 | 30 | while user_choice != QUIT_CHOICE: 31 | if user_choice == 0: 32 | from ai.train import start_train 33 | start_train() 34 | elif user_choice == 1: 35 | from ai.convert import start_convert 36 | start_convert() 37 | elif user_choice == 2: 38 | from ai.play import start_play 39 | start_play() 40 | 41 | user_choice = get_validated_input(*get_input_params) 42 | 43 | run() 44 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "osu-ai" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Tare Ebelo <75279482+TareHimself@users.noreply.github.com>"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.9" 10 | 11 | keyboard= "^0.13.5" 12 | numpy= "^1.24.3" 13 | torch= "^2.0.1" 14 | torchvision= "^0.15.2" 15 | tqdm= "^4.64.1" 16 | opencv-python= "^4.7.0.72" 17 | mss= "^9.0.1" 18 | timm= "^0.9.2" 19 | mouse = "^0.7.1" 20 | 21 | [tool.poetry.dev-dependencies] 22 | poethepoet = "^0.20.0" 23 | 24 | [tool.poe.tasks] 25 | uninstall-torch = "pip uninstall -y torch torchvision" 26 | install-torch-cuda = "pip install torch torchvision --index-url https://download.pytorch.org/whl/cu117" 27 | force-cuda = 
["uninstall-torch","install-torch-cuda"] 28 | use-win32 = "pip install pywin32" 29 | 30 | 31 | [build-system] 32 | requires = ["poetry-core"] 33 | build-backend = "poetry.core.masonry.api" 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gymnasium==0.28.0 2 | keyboard==0.13.5 3 | numpy~=1.24.3 4 | opencv_python==4.7.0.72 5 | torch~=2.0.1+cu117 6 | torchvision~=0.15.2+cu117 7 | tqdm==4.64.1 8 | opencv-python~=4.7.0.72 9 | mss~=9.0.1 10 | timm~=0.9.2 -------------------------------------------------------------------------------- /rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TareHimself/osu-ai/63037c60369b7b8d3757472049a4463dd3fc8801/rl/__init__.py -------------------------------------------------------------------------------- /rl/agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import win32gui 3 | import win32ui 4 | import win32con 5 | import asyncio 6 | import cv2 7 | import keyboard 8 | import time 9 | from queue import Queue 10 | from torchvision import transforms 11 | from ai.utils import OsuSocketServer 12 | from threading import Thread 13 | from ai.constants import PLAY_AREA_CAPTURE_PARAMS, FINAL_RESIZE_PERCENT 14 | 15 | ConvertToTensor = transforms.ToTensor() 16 | 17 | MAX_MEMORY = 100_000 18 | BATCH_SIZE = 32 19 | LR = 0.001 20 | 21 | 22 | class OsuAgentState: 23 | WAITING_FOR_MAP = "[WAITING FOR MAP]" 24 | WAITING_FOR_OPENING = "[WAITING FOR OP]" 25 | PLAYING_MAP = "[PLAYING]" 26 | 27 | 28 | class OsuAgent: 29 | 30 | def __init__(self, stacks): 31 | self.stacks = stacks 32 | self.sock = OsuSocketServer(self.on_map_state_updated) 33 | self.hwnd = win32gui.FindWindow(None, "osu! 
(development)") 34 | self.buff = Queue() 35 | self.state = OsuAgentState.WAITING_FOR_MAP 36 | Thread(target=self.draw, group=None, daemon=True).start() 37 | 38 | def on_map_state_updated(self, state: str): 39 | 40 | if state == "MAP_BEGIN": 41 | self.update_state(OsuAgentState.WAITING_FOR_OPENING) 42 | 43 | if state == "MAP_END": 44 | self.update_state(OsuAgentState.WAITING_FOR_MAP) 45 | 46 | def update_state(self, newState: str): 47 | initial = self.state 48 | self.state = newState 49 | print("State changed from", initial, "To", self.state) 50 | 51 | def draw(self): 52 | # while True: 53 | # stacked = self.buff.get() 54 | # if stacked is None: 55 | # break 56 | # cv2.imshow("Debug", stacked.transpose(1, 2, 0)) 57 | # cv2.waitKey(100) 58 | pass 59 | 60 | def capture_frames(self, stack_num=3, stack_interval=0.01, resize: tuple[int, int] = (1920, 1080)): 61 | width = int(PLAY_AREA_CAPTURE_PARAMS[0]) 62 | height = int(PLAY_AREA_CAPTURE_PARAMS[1]) 63 | wDC = win32gui.GetWindowDC(self.hwnd) 64 | dcObj = win32ui.CreateDCFromHandle(wDC) 65 | cDC = dcObj.CreateCompatibleDC() 66 | dataBitMap = win32ui.CreateBitmap() 67 | dataBitMap.CreateCompatibleBitmap(dcObj, width, height) 68 | frames = [] 69 | display = [] 70 | for i in range(stack_num): 71 | cDC.SelectObject(dataBitMap) 72 | cDC.BitBlt((0, 0), (width, height), dcObj, 73 | (PLAY_AREA_CAPTURE_PARAMS[2], PLAY_AREA_CAPTURE_PARAMS[3]), win32con.SRCCOPY) 74 | time.sleep(stack_interval) 75 | # convert the raw data into a format opencv can read 76 | #dataBitMap.SaveBitmapFile(cDC, 'debug.bmp') 77 | signedIntsArray = dataBitMap.GetBitmapBits(True) 78 | frame = np.frombuffer(signedIntsArray, dtype='uint8') 79 | frame.shape = (height, width, 4) 80 | frame = np.ascontiguousarray(frame[..., :3]) 81 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 82 | display.append(frame) 83 | frame = cv2.resize(frame, resize, interpolation=cv2.INTER_LINEAR) 84 | frames.append(frame) 85 | 86 | frames = np.stack(frames) 87 | display = np.stack(display) 88 | # free resources 89 | dcObj.DeleteDC() 90 | cDC.DeleteDC() 91 | win32gui.ReleaseDC(self.hwnd, wDC) 92 | win32gui.DeleteObject(dataBitMap.GetHandle()) 93 | 94 | return frames, display 95 | 96 | def get_state(self): 97 | if self.state == OsuAgentState.WAITING_FOR_MAP: 98 | while self.state == OsuAgentState.WAITING_FOR_MAP: 99 | time.sleep(0.01) 100 | 101 | while self.state == OsuAgentState.WAITING_FOR_OPENING: 102 | score, acc, game_time = asyncio.run( 103 | self.sock.send_and_wait("state", "0.0,0.0,-10.0")).split(',') 104 | if float(game_time) >= 0: 105 | self.update_state(OsuAgentState.PLAYING_MAP) 106 | break 107 | time.sleep(0.01) 108 | 109 | stacked, display_frame = self.capture_frames(stack_interval=0.01, stack_num=self.stacks, resize=(int(PLAY_AREA_CAPTURE_PARAMS[0] * FINAL_RESIZE_PERCENT), int( 110 | PLAY_AREA_CAPTURE_PARAMS[1] * FINAL_RESIZE_PERCENT))) 111 | score, acc, game_time = asyncio.run( 112 | self.sock.send_and_wait("state", "0.0,0.0,0.0")).split(',') 113 | 114 | state_arr = [float(score), float(acc), float(game_time)] 115 | 116 | # drop the alpha channel, or cv.matchTemplate() will throw an error like: 117 | # error: (-215:Assertion failed) (depth == CV_8U || depth == CV_32F) && type == _templ.type() 118 | # && _img.dims() <= 2 in function 'cv::matchTemplate' 119 | 120 | # make image C_CONTIGUOUS to avoid errors that look like: 121 | # File ... 
in draw_rectangles 122 | # TypeError: an integer is required (got type tuple) 123 | # see the discussion here: 124 | # https://github.com/opencv/opencv/issues/14866#issuecomment-580207109 125 | 126 | self.buff.put(display_frame) 127 | return [stacked / 255] + state_arr 128 | 129 | def do_action(self, action): 130 | if action < 0.5: 131 | keyboard.release("z") 132 | else: 133 | keyboard.press("z") 134 | 135 | def reset(self): 136 | keyboard.press_and_release("`") 137 | time.sleep(0.01) 138 | print("State at reset", self.state) 139 | while self.state == OsuAgentState.WAITING_FOR_MAP: 140 | time.sleep(0.01) 141 | 142 | def kill(self): 143 | self.buff.put(None) 144 | self.sock.kill() 145 | -------------------------------------------------------------------------------- /rl/dqn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DQN(torch.nn.Module): 6 | """ 7 | Works so far 8 | 9 | Args: 10 | torch (_type_): _description_ 11 | """ 12 | 13 | def __init__(self, action_space=1, stacks=4): 14 | super().__init__() 15 | self.conv = nn.Sequential( 16 | nn.Conv2d(stacks, 32, 8, stride=4), 17 | nn.ReLU(), 18 | nn.Conv2d(32, 64, 4, stride=2), 19 | nn.ReLU(), 20 | nn.Conv2d(64, 64, 3, stride=1), 21 | nn.ReLU(), 22 | nn.Flatten(), 23 | nn.Linear(7488, 512), 24 | nn.ReLU(), 25 | nn.Linear(512, action_space), 26 | ) 27 | 28 | def forward(self, images): 29 | return self.conv(images) 30 | 31 | def save_model(self, path="test.pt"): 32 | torch.save(self.state_dict(), path) 33 | 34 | def load_model(self, path="test.pt"): 35 | self.load_state_dict(torch.load(path)) 36 | -------------------------------------------------------------------------------- /rl/env.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | import random 5 | from torch import optim, nn 6 | from ai.constants import PLAY_AREA_CAPTURE_PARAMS, FINAL_RESIZE_PERCENT, PYTORCH_DEVICE 7 | from rl.agent import OsuAgent 8 | from collections import deque 9 | from rl.dqn import DQN 10 | 11 | IMAGE_SHAPE = (int(PLAY_AREA_CAPTURE_PARAMS[0] * FINAL_RESIZE_PERCENT), 12 | int(PLAY_AREA_CAPTURE_PARAMS[1] * FINAL_RESIZE_PERCENT), 3) 13 | 14 | CAPACITY_MAX = 32 15 | 16 | 17 | class OsuEnviroment(): 18 | 19 | def __init__(self) -> None: 20 | super().__init__() 21 | self.stacks = 5 22 | self.agent = OsuAgent(self.stacks) 23 | self.memory = deque([], maxlen=CAPACITY_MAX) 24 | self.lr = 1e-4 25 | self.gamma = 0.99 26 | self.tau = 1.0 27 | self.model = DQN(stacks=self.stacks).to(PYTORCH_DEVICE) 28 | self.target = DQN(stacks=self.stacks).to(PYTORCH_DEVICE) 29 | self.target.load_state_dict(self.model.state_dict()) 30 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 31 | self.criterion = nn.MSELoss() 32 | self.epsilon = 0 33 | self.plays = 0 34 | 35 | def remember(self, mem): 36 | self.memory.append(mem) 37 | 38 | def step(self): 39 | 40 | prev_state = self.agent.get_state() 41 | 42 | if prev_state is None: 43 | return True 44 | 45 | last_playfield, last_score, last_accuracy, last_game_time = prev_state 46 | 47 | action = self.sample(last_playfield) 48 | 49 | self.agent.do_action(action) 50 | 51 | time.sleep(0.05) 52 | 53 | new_state = self.agent.get_state() 54 | 55 | if new_state is None: 56 | return True 57 | 58 | playfield, score, accuracy, game_time = new_state 59 | 60 | reward = 0 61 | 62 | if accuracy < last_accuracy: 63 | if action == 0: 64 | reward = -800 65 | 
/rl/env.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import numpy as np
4 | import random
5 | from torch import optim, nn
6 | from ai.constants import PLAY_AREA_CAPTURE_PARAMS, FINAL_RESIZE_PERCENT, PYTORCH_DEVICE
7 | from rl.agent import OsuAgent
8 | from collections import deque
9 | from rl.dqn import DQN
10 |
11 | IMAGE_SHAPE = (int(PLAY_AREA_CAPTURE_PARAMS[0] * FINAL_RESIZE_PERCENT),
12 |                int(PLAY_AREA_CAPTURE_PARAMS[1] * FINAL_RESIZE_PERCENT), 3)
13 |
14 | CAPACITY_MAX = 32
15 |
16 |
17 | class OsuEnviroment():
18 |
19 |     def __init__(self) -> None:
20 |         super().__init__()
21 |         self.stacks = 5
22 |         self.agent = OsuAgent(self.stacks)
23 |         self.memory = deque([], maxlen=CAPACITY_MAX)
24 |         self.lr = 1e-4
25 |         self.gamma = 0.99
26 |         self.tau = 1.0
27 |         self.model = DQN(stacks=self.stacks).to(PYTORCH_DEVICE)
28 |         self.target = DQN(stacks=self.stacks).to(PYTORCH_DEVICE)
29 |         self.target.load_state_dict(self.model.state_dict())
30 |         self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
31 |         self.criterion = nn.MSELoss()
32 |         self.epsilon = 0
33 |         self.plays = 0
34 |
35 |     def remember(self, mem):
36 |         self.memory.append(mem)
37 |
38 |     def step(self):
39 |
40 |         prev_state = self.agent.get_state()
41 |
42 |         if prev_state is None:
43 |             return True
44 |
45 |         last_playfield, last_score, last_accuracy, last_game_time = prev_state
46 |
47 |         action = self.sample(last_playfield)
48 |
49 |         self.agent.do_action(action)
50 |
51 |         time.sleep(0.05)
52 |
53 |         new_state = self.agent.get_state()
54 |
55 |         if new_state is None:
56 |             return True
57 |
58 |         playfield, score, accuracy, game_time = new_state
59 |
60 |         reward = 0
61 |
62 |         if accuracy < last_accuracy:
63 |             if action == 0:
64 |                 reward = -800
65 |         else:
66 |             reward = 400
67 |
68 |         is_over = accuracy < 5 and game_time > 2
69 |
70 |         self.remember(
71 |             np.array([last_playfield, action, reward, playfield, int(is_over)], dtype=object))
72 |
73 |         is_over = is_over if not self.train() else True
74 |         return is_over
75 |
76 |     def reset(self):
77 |         self.plays += 1
78 |         self.agent.reset()
79 |         print("Attempt", self.plays)
80 |
81 |     def predict_one(self, img):
82 |         model_in = torch.from_numpy(img)
83 |
84 |         return self.predict_one_tensor(model_in)
85 |
86 |     def predict_one_tensor(self, model_in):
87 |         model_in = model_in.reshape(
88 |             (1, model_in.shape[0], model_in.shape[1], model_in.shape[2])).type(
89 |             torch.FloatTensor).to(PYTORCH_DEVICE)
90 |
91 |         output = self.model(model_in)
92 |
93 |         _, predicated = torch.max(output, dim=1)
94 |         return predicated
95 |
96 |     def sample(self, state):
97 |         if random.randint(0, 200) > self.epsilon:
98 |             return random.randint(0, 1)
99 |         else:
100 |             with torch.no_grad():
101 |                 return int(self.predict_one(state).item())
102 |
103 |     def train(self):
104 |         if len(self.memory) < CAPACITY_MAX:
105 |             return False
106 |
107 |         states, actions, rewards, next_states, dones = zip(
108 |             *self.memory)
109 |
110 |         state: torch.Tensor = torch.from_numpy(
111 |             np.stack(states, axis=0)).type(torch.FloatTensor).to(PYTORCH_DEVICE)
112 |         next_state: torch.Tensor = torch.from_numpy(
113 |             np.stack(next_states, axis=0)).type(torch.FloatTensor).to(PYTORCH_DEVICE)
114 |         reward: torch.Tensor = torch.from_numpy(
115 |             np.stack(rewards, axis=0)).type(torch.LongTensor).to(PYTORCH_DEVICE)
116 |         done: torch.Tensor = torch.from_numpy(
117 |             np.stack(dones, axis=0)).to(PYTORCH_DEVICE)
118 |
119 |         action = torch.from_numpy(
120 |             np.stack(actions, axis=0)).type(torch.LongTensor).to(PYTORCH_DEVICE)
121 |
122 |         with torch.no_grad():
123 |             target_max, _ = self.target(next_state).max(dim=1)
124 |             td_target = reward.flatten() + self.gamma * target_max * (1 - done.flatten())
125 |
126 |         old_val = self.model(state).squeeze(1)[action]
127 |
128 |         loss = self.criterion(td_target, old_val)
129 |
130 |         self.optimizer.zero_grad()
131 |         loss.backward()
132 |         self.optimizer.step()
133 |
134 |         print(f"Optimized with loss {loss.item()}")
135 |
136 |         for target_network_param, q_network_param in zip(self.target.parameters(), self.model.parameters()):
137 |             target_network_param.data.copy_(
138 |                 self.tau * q_network_param.data +
139 |                 (1.0 - self.tau) * target_network_param.data
140 |             )
141 |
142 |         self.memory.clear()
143 |         self.epsilon += 1
144 |         return True
145 |
--------------------------------------------------------------------------------
/rl/memory.py:
--------------------------------------------------------------------------------
1 | from collections import deque, namedtuple
2 | import random
3 |
4 |
5 | Transition = namedtuple('Transition',
6 |                         ('state', 'action', 'next_state', 'reward'))
7 |
8 |
9 | class ReplayMemory(object):
10 |
11 |     def __init__(self, capacity):
12 |         self.memory = deque([], maxlen=capacity)
13 |
14 |     def push(self, *args):
15 |         """Save a transition"""
16 |         self.memory.append(Transition(*args))
17 |
18 |     def sample(self, batch_size):
19 |         return random.sample(self.memory, batch_size)
20 |
21 |     def __len__(self):
22 |         return len(self.memory)
23 |
--------------------------------------------------------------------------------
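ReplayMemory and Transition follow the standard push/sample replay-buffer pattern, although rl/env.py above keeps its own plain deque instead of using this class. A short usage sketch with hypothetical placeholder values (not taken from the repo):

```python
from rl.memory import ReplayMemory, Transition

memory = ReplayMemory(capacity=1000)

# push(*args) wraps its arguments into Transition(state, action, next_state, reward)
for step in range(64):
    memory.push(step, step % 2, step + 1, 0.0)

batch = memory.sample(32)                      # 32 random Transition namedtuples
states, actions, next_states, rewards = zip(*batch)
print(len(memory), len(batch), states[:4])
```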
/test.py:
--------------------------------------------------------------------------------
1 | from ai.converter import ReplayConverter
2 | import os
3 |
4 | # for dataset in os.listdir("./to_convert"):
5 | #     ReplayConverter(dataset,danser_video=os.path.join('./to_convert',dataset,f"{dataset}.mkv"),replay_json=os.path.join('./to_convert',dataset,f"{dataset}.json"),save_dir="./",num_readers=3,debug=False)
6 | # # ReplayConverter("test",danser_video="sawai miku - colorful (wo cursor).mkv",replay_json="sawai miku - colorful (wo cursor).json",save_dir="./",num_readers=1,video_fps=60,frame_interval_ms=50)
7 |
8 | # import torch
9 |
10 | # a = torch.Tensor([[1,1,1,0]])
11 | # b = torch.Tensor([[1,1,1,1]])
12 |
13 | # def get_acc(pred: torch.Tensor,truth: torch.Tensor,thresh: int = 60,is_combined=False):
14 | #     pred = pred.detach().clone()
15 | #     truth = truth.detach().clone()
16 |
17 | #     pred[:,0] *= 1920
18 | #     pred[:,1] *= 1080
19 | #     truth[:,0] *= 1920
20 | #     truth[:,1] *= 1080
21 |
22 | #     diff = (pred[:,:-2] - truth[:,:-2]) if is_combined else pred - truth
23 |
24 | #     dist = torch.sqrt((diff ** 2).sum(dim=1))
25 |
26 | #     dist[dist < thresh] = 1
27 |
28 | #     dist[dist >= thresh] = 0
29 |
30 | #     if not is_combined:
31 | #         return dist.mean().item()
32 |
33 | #     pred_keys = pred[:,2:]
34 | #     truth_keys = truth[:,2:]
35 |
36 | #     pred_keys[pred_keys >= 0.5] = 1
37 | #     truth_keys[truth_keys >= 0.5] = 1
38 | #     pred_keys[pred_keys < 0.5] = 0
39 | #     truth_keys[truth_keys < 0.5] = 0
40 |
41 | #     return (dist.mean().item() + torch.all(pred_keys == truth_keys,dim=1).float().mean().item()) / 2
42 |
43 |
44 | # print(a,'\n',b,'\n',get_acc(a,b,is_combined=True))
45 |
46 |
47 | ReplayConverter("Rightfully 8", "Rightfully 8.mkv",
48 |                 "Rightfully 8.json", max_in_memory=5000, save_dir="./",
49 |                 num_writers=1, debug=True)
50 |
--------------------------------------------------------------------------------
/windows.py:
--------------------------------------------------------------------------------
1 | from threading import Thread
2 | import numpy as np
3 | from queue import Queue
4 |
5 |
6 | # class WindowCapture:
7 | #     def __init__(self, window_name=None) -> None:
8 | #         if window_name is None:
9 | #             self.hwnd = win32gui.GetDesktopWindow()
10 | #         else:
11 | #             self.hwnd = win32gui.FindWindow(None, window_name)
12 | #         if not self.hwnd:
13 | #             WindowCapture.list_window_names()
14 | #             raise Exception(
15 | #                 f"Window '{window_name}' Not Found Select window name from above")
16 | #         self.num_captured = 0
17 | #         self.wDC = None
18 | #         self.dcObj = None
19 | #         self.cDC = None
20 | #         self.dataBitMap = None
21 |
22 | #     def ensure_resources(self, width: int, height: int):
23 | #         if self.num_captured > 50:
24 | #             self.dcObj.DeleteDC()
25 | #             self.cDC.DeleteDC()
26 | #             win32gui.ReleaseDC(self.hwnd, self.wDC)
27 | #             win32gui.DeleteObject(self.dataBitMap.GetHandle())
28 | #             self.wDC = None
29 | #             self.dcObj = None
30 | #             self.cDC = None
31 | #             self.dataBitMap = None
32 | #             self.num_captured = 0
33 |
34 | #         if self.wDC is None:
35 | #             self.wDC = win32gui.GetWindowDC(self.hwnd)
36 | #             self.dcObj = win32ui.CreateDCFromHandle(self.wDC)
37 | #             self.cDC = self.dcObj.CreateCompatibleDC()
38 | #             self.dataBitMap = win32ui.CreateBitmap()
39 | #             self.dataBitMap.CreateCompatibleBitmap(self.dcObj, width, height)
40 | #         self.num_captured += 1
41 |
42 | #     def get_frame(self, resize: tuple[int, int], width: int, height: int, dx: int, dy: int, is_stacking=False):
43 | #         self.ensure_resources(width, height)
44 |
45 | #         self.cDC.SelectObject(self.dataBitMap)
46 | #         self.cDC.BitBlt((0, 0), (width, height), self.dcObj,
47 | #                         (dx, dy), win32con.SRCCOPY)
48 |
49 | #         # convert the raw data into a format opencv can read
50 | #         #dataBitMap.SaveBitmapFile(cDC, 'debug.bmp')
51 | #         signedIntsArray = self.dataBitMap.GetBitmapBits(True)
52 | #         frame = np.frombuffer(signedIntsArray, dtype='uint8')
53 | #         frame.shape = (height, width, 4)
54 | #         frame = np.ascontiguousarray(frame[..., :3])
55 | #         frame = cv2.resize(frame, resize, interpolation=cv2.INTER_LINEAR)
56 | #         if not is_stacking:
57 | #             return frame
58 |
59 | #         return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
60 |
61 | #     def capture(self, prev_frames=[], stack_num=0, resize: tuple[int, int] = (1920, 1080), width=1920, height=1080, dx=0, dy=0) -> np.ndarray:
62 | #         width = int(width)
63 | #         height = int(height)
64 |
65 | #         self.ensure_resources(width, height)
66 |
67 | #         is_stacking = stack_num != 0
68 | #         frame = self.get_frame(resize, width, height, dx, dy, is_stacking)
69 |
70 | #         if not is_stacking:
71 | #             return frame
72 |
73 | #         prev_count = len(prev_frames)
74 | #         needed_count = stack_num - prev_count
75 | #         final_frames = []
76 | #         if needed_count > 1:
77 | #             final_frames = prev_frames + [frame for _ in range(needed_count)]
78 | #         else:
79 | #             final_frames = prev_frames[prev_count -
80 | #                                        (stack_num - 1):prev_count] + [frame]
81 |
82 | #         return [frame, np.stack(final_frames)]
83 |
84 | #     # Deleting (Calling destructor)
85 | #     def __del__(self):
86 | #         if self.wDC is not None:
87 | #             self.dcObj.DeleteDC()
88 | #             self.cDC.DeleteDC()
89 | #             win32gui.ReleaseDC(self.hwnd, self.wDC)
90 | #             win32gui.DeleteObject(self.dataBitMap.GetHandle())
91 | #             self.wDC = None
92 | #             self.dcObj = None
93 | #             self.cDC = None
94 | #             self.dataBitMap = None
95 | #             self.num_captured = 0
96 |
97 | #     def list_window_names():
98 | #         def winEnumHandler(hwnd, ctx):
99 | #             if win32gui.IsWindowVisible(hwnd):
100 | #                 print(hex(hwnd), win32gui.GetWindowText(hwnd))
101 | #         win32gui.EnumWindows(winEnumHandler, None)
102 |
103 |
104 | # class WindowStream:
105 | #     def __init__(self, window_name=None, width=1920, height=1080, dx=0, dy=0) -> None:
106 | #         self.window_name = window_name
107 | #         self.width = width
108 | #         self.height = height
109 | #         self.dx = dx
110 | #         self.dy = dy
111 | #         self.frame_buffer = Queue()
112 | #         self.frames = 0
113 | #         Thread(daemon=True, group=None, target=self.capture).start()
114 | #         Thread(daemon=True, group=None, target=self.do_frame_rate).start()
115 | #         Thread(daemon=True, group=None, target=self.stream).start()
116 |
117 | #     def do_frame_rate(self):
118 | #         while True:
119 | #             time.sleep(1)
120 | #             print(f"Capture FPS {self.frames:.0f} ", end="\r")
121 | #             self.frames = 0
122 |
123 | #     def capture(self):
124 | #         window_capture = WindowCapture(self.window_name)
125 | #         while True:
126 | #             frame = window_capture.capture(
127 | #                 self.width, self.height, self.dx, self.dy)
128 | #             self.frame_buffer.put(frame)
129 | #             self.frames += 1
130 |
131 | #     def stream(self):
132 | #         while True:
133 | #             frame = self.frame_buffer.get()
134 | #             if frame is not None:
135 | #                 cv2.imshow(
136 | #                     f"Stream of window {self.window_name if self.window_name is not None else 'Desktop'}", frame)
137 | #                 cv2.waitKey(1)
138 |
139 |
140 |
141 |
--------------------------------------------------------------------------------