├── tools ├── __init__.py ├── utils │ ├── __init__.py │ ├── break_test.py │ ├── str_compare.py │ ├── load_and_save.py │ ├── huggingface.py │ ├── numpy_shorts.py │ ├── index_find_str.py │ ├── ask_parameter.py │ └── song_metadata.py ├── config │ ├── __init__.py │ ├── exclusion.py │ ├── check_folder_structure.py │ ├── mapper_selection.py │ └── paths.py ├── fail_list │ ├── __init__.py │ └── black_list.py └── PowerBeats_extension │ ├── __init__.py │ ├── music_shift.py │ ├── PowerBeats_shift.py │ └── update_artist_name.py ├── app_helper ├── __init__.py ├── cover.jpg ├── update_dir_path.py ├── set_app_paths.py └── check_input.py ├── bs_shift ├── __init__.py ├── bps_find_songs.py ├── copyfavorites.py ├── hashtest.py ├── cleanup_n_format.py ├── export_map.py └── shift.py ├── training ├── __init__.py ├── eval_autoenc_music.py ├── plot_model.py ├── helpers.py ├── eval_bs_automapper.py ├── train_autoenc_music.py ├── train_bs_automapper.py └── tensorflow_models.py ├── beat_prediction ├── __init__.py ├── validate_find_beats.py ├── beat_to_lstm.py ├── beat_prop.py ├── find_beats.py └── ai_beat_gen.py ├── map_creation ├── __init__.py ├── note_postprocessing.py ├── class_helpers.py ├── bpm_optimizer.py ├── artificial_mod.py ├── find_bpm.py ├── gen_sliders.py └── gen_obstacles.py ├── preprocessing ├── __init__.py ├── map_info_processing.py ├── load_dic_dif_casting.py ├── bs_mapper_pre.py ├── beat_data_helper.py └── music_processing.py ├── .gitignore ├── lighting_prediction ├── __init__.py ├── tf_lighting.py ├── generate_lighting.py └── train_lighting.py ├── pytest.ini ├── requirements.txt ├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── LICENSE ├── tests └── test_song_metadata.py ├── main_training.py ├── countlines.py ├── Readme.md ├── How to - Linux install for training and app.md ├── evaluation └── evaluate_beat_algorithms.py └── main.py /tools/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /app_helper/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bs_shift/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /beat_prediction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /map_creation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/fail_list/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /*.pclprof 
-------------------------------------------------------------------------------- /lighting_prediction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/PowerBeats_extension/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | python_files = tests/test_*.py 3 | addopts = -ra 4 | -------------------------------------------------------------------------------- /app_helper/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fred-brenner/InfernoSaber---BeatSaber-Automapper/HEAD/app_helper/cover.jpg -------------------------------------------------------------------------------- /tools/config/exclusion.py: -------------------------------------------------------------------------------- 1 | ################################################### 2 | # exclusions for shifting from beatsaber to project 3 | # all in lower letters! 
4 | ################################################### 5 | 6 | exclusion = list(["6 lane", "6lane", "8 lane", "8lane", "onesaber", 7 | "one saber", "360°", "90°", "1234"]) 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy~=1.13.1 2 | # numpy~=2.0.2 3 | tensorflow~=2.15.1 4 | ffmpy~=0.5.0 5 | librosa~=0.11.0 6 | # aubio==0.4.7 7 | # pydub~=0.25.1 8 | tabulate~=0.9.0 9 | pillow~=10.4.0 10 | joblib~=1.4.2 11 | scikit-learn~=1.3.2 12 | keras~=2.15.0 13 | keras-tcn==3.5.4 14 | progressbar2~=4.5.0 15 | huggingface-hub~=0.29.3 16 | requests~=2.32.3 17 | gradio~=5.29.1 18 | yt-dlp 19 | mutagen 20 | -------------------------------------------------------------------------------- /tools/utils/break_test.py: -------------------------------------------------------------------------------- 1 | # import matplotlib.pyplot as plt 2 | # from scipy.signal import savgol_filter 3 | 4 | from tools.config import config, paths 5 | from map_creation.sanity_check import add_breaks 6 | from tools.utils.load_and_save import load_pkl 7 | 8 | data = load_pkl(paths.temp_path) 9 | [notes_l, notes_r, timings] = data 10 | 11 | notes_l = add_breaks(notes_l, timings) 12 | notes_r = add_breaks(notes_r, timings) 13 | 14 | notes_l 15 | -------------------------------------------------------------------------------- /tools/utils/str_compare.py: -------------------------------------------------------------------------------- 1 | ####################################### 2 | # compare a string to a list of strings 3 | # may return the matched string 4 | ####################################### 5 | 6 | def str_compare(str, str_list, return_str, silent): 7 | str = str.lower() 8 | for s in str_list: 9 | if s in str: 10 | if not silent: 11 | print("Exclude: " + str + " | Found " + s) 12 | 13 | if return_str: 14 | return s 15 | else: 16 | return True 17 | # no 
similarity found 18 | return False 19 | -------------------------------------------------------------------------------- /tools/utils/load_and_save.py: -------------------------------------------------------------------------------- 1 | """ 2 | Import and Export of numpy arrays into npy files 3 | """ 4 | 5 | import numpy as np 6 | import pickle 7 | 8 | 9 | def load_npy(data_path: str) -> np.array: 10 | ar = np.load(data_path, allow_pickle=True) 11 | return ar 12 | 13 | 14 | def save_npy(ar: np.array, data_path: str): 15 | ar.dump(data_path) 16 | 17 | 18 | def load_pkl(data_path: str) -> list: 19 | with open(data_path + 'debug_variables.pkl', 'rb') as f: 20 | data = pickle.load(f) 21 | return data 22 | 23 | 24 | def save_pkl(data: list, data_path: str): 25 | with open(data_path + 'debug_variables.pkl', 'wb') as f: 26 | pickle.dump(data, f) 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help me improve 4 | title: '' 5 | labels: '' 6 | assignees: fred-brenner 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 
32 | -------------------------------------------------------------------------------- /beat_prediction/validate_find_beats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import matplotlib.pyplot as plt 3 | import os 4 | 5 | # from preprocessing.bs_mapper_pre import load_beat_data 6 | # from beat_prediction.find_beats import find_beats 7 | # from tools.config import paths 8 | 9 | 10 | def plot_beat_vs_real(beat_pred, beat_real): 11 | plt.figure() 12 | plt.vlines(beat_pred, 0, 1, colors='k', linestyles='solid', linewidth=0.2) 13 | plt.scatter(beat_real, [0.5] * len(beat_real)) 14 | plt.show() 15 | 16 | # 17 | # if __name__ == '__main__': 18 | # name_ar = os.listdir(paths.songs_pred) 19 | # pitch_list = find_beats(name_ar, train_data=False) 20 | # 21 | # name_ar = ['Born This Way', 'Dizzy'] 22 | # _, real_beats = load_beat_data(name_ar) 23 | # 24 | # idx = 0 25 | # plot_beat_vs_real(pitch_list[idx], real_beats[idx]) 26 | # 27 | # print("") 28 | -------------------------------------------------------------------------------- /tools/PowerBeats_extension/music_shift.py: -------------------------------------------------------------------------------- 1 | # This script shifts the songs (.egg) to another folder (.mp3) 2 | 3 | 4 | import os 5 | import shutil 6 | 7 | import tools.config.paths as paths 8 | from bs_shift.export_map import convert_music_file 9 | 10 | 11 | # paths 12 | copy_path_new = "C:/Users/frede/Music/" 13 | copy_path_origin = paths.songs_pred 14 | 15 | # folder check 16 | if not os.path.isdir(copy_path_new): 17 | print("Could not find new song folder! Exit") 18 | exit() 19 | if not os.path.isdir(copy_path_origin): 20 | print("Could not find song origin folder! 
Exit") 21 | exit() 22 | 23 | counter = 0 24 | for root, _, files in os.walk(copy_path_origin): 25 | for song_file in files: 26 | if song_file.endswith(".egg"): 27 | counter += 1 28 | print(f"Converting song: {song_file}") 29 | # copy in new directory and rename to .mp3 30 | output_file = f"{song_file[:-4]}.mp3" 31 | convert_music_file(os.path.join(root, song_file), 32 | copy_path_new + output_file) 33 | 34 | print(f"Finished shift to mp3 of {counter} song files") 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Frederic Brenner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tools/utils/huggingface.py: -------------------------------------------------------------------------------- 1 | import os 2 | from huggingface_hub import snapshot_download 3 | 4 | from tools.config import config, paths 5 | from tools.config.mapper_selection import get_full_model_path 6 | 7 | 8 | def model_download(model_branch=None): 9 | model_name = "BierHerr/InfernoSaber" 10 | model_folder = paths.model_path 11 | if model_branch is None: 12 | model_branch = config.use_mapper_selection 13 | 14 | # Create folder if it doesn't exist 15 | if not os.path.exists(model_folder): 16 | os.makedirs(model_folder) 17 | 18 | # Check if model already exists 19 | if len(os.listdir(model_folder)) > 5: 20 | print("Model is already setup. Skipping download") 21 | else: 22 | print(f"Download model: {model_branch} from huggingface...") 23 | snapshot_download(repo_id=model_name, revision=model_branch, 24 | local_dir=model_folder, local_dir_use_symlinks=False, 25 | ignore_patterns=["Readme.md", ".git*"]) 26 | 27 | # check that model exists on the example of event generator 28 | _ = get_full_model_path(config.event_gen_version) 29 | -------------------------------------------------------------------------------- /tools/PowerBeats_extension/PowerBeats_shift.py: -------------------------------------------------------------------------------- 1 | # This script shifts the songs (.egg) to another folder (.ogg) 2 | # shift.py needs to be run first 3 | import os, sys, inspect 4 | 5 | current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 6 | parent_dir = os.path.dirname(current_dir) 7 | sys.path.insert(0, parent_dir) 8 | 9 | import config.paths as paths 10 | # import os 11 | import shutil 12 | 13 | # paths 14 | copy_path_new = "C:/Users/frede/Music/" 15 | copy_path_origin = paths.copy_path_song 16 | 17 | # folder check 18 | if not os.path.isdir(copy_path_new): 19 | 
print("Could not find new song folder! Exit") 20 | exit() 21 | if not os.path.isdir(copy_path_origin): 22 | print("Could not find song origin folder! Exit") 23 | exit() 24 | 25 | counter = 0 26 | for song_file in os.listdir(copy_path_origin): 27 | counter += 1 28 | if not song_file.endswith(".egg"): 29 | print(f"Warning: unknown file type: {song_file}") 30 | else: 31 | # copy in new directory and rename to .ogg 32 | new_name = song_file[:-4] + ".ogg" 33 | shutil.copyfile(copy_path_origin + song_file, copy_path_new + new_name) 34 | # print(song_file) 35 | 36 | print(f"Finished shift to ogg of {counter} song files") -------------------------------------------------------------------------------- /bs_shift/bps_find_songs.py: -------------------------------------------------------------------------------- 1 | # saves numpy arrays of difficulty per song 2 | import os 3 | import numpy as np 4 | import tools.config.paths as paths 5 | 6 | 7 | def bps_find_songs(info_flag=True) -> None: 8 | diff_array = [] 9 | name_array = [] 10 | 11 | for i in os.listdir(paths.dict_all_path): 12 | if i.endswith("_notes.dat"): 13 | # notes file 14 | map_dict_notes = np.load(paths.dict_all_path + i, allow_pickle=True) 15 | if len(map_dict_notes[0]) < 1: 16 | print(f"No notes found in {i}") 17 | 18 | # get song time 19 | diff_time = max(map_dict_notes[0]) - min(map_dict_notes[0]) 20 | 21 | # get beats per second 22 | bps = round(len(map_dict_notes[0]) / diff_time, 2) 23 | 24 | # append to array with name in front 25 | diff_array.append(bps) 26 | name_array.append(i[:-10]) 27 | 28 | if info_flag: 29 | print(f"\nInfo: Highest avg cut_per_second found in one song: {max(diff_array)}") 30 | 31 | # save arrays 32 | diff_array = np.asarray(diff_array) 33 | name_array = np.asarray(name_array) 34 | if len(diff_array) != len(name_array): 35 | print("Error in bps_find_songs.py") 36 | exit() 37 | 38 | np.save(paths.diff_ar_file, diff_array) 39 | np.save(paths.name_ar_file, name_array) 40 | 41 | # 
Finished 42 | # print("Finished notes per second calculation") 43 | -------------------------------------------------------------------------------- /map_creation/note_postprocessing.py: -------------------------------------------------------------------------------- 1 | from tools.config import config 2 | import random 3 | 4 | 5 | def remove_double_notes(notes_all): 6 | for idx, notes in enumerate(notes_all): 7 | # each "notes" is one 8 | type_covered = [] 9 | remove_index = [] 10 | for i in range(int(len(notes) / 4)): 11 | n_type = notes[i * 4 + 2] 12 | if n_type == 3: 13 | # ignore bombs? 14 | continue 15 | if n_type not in type_covered: 16 | type_covered.append(n_type) 17 | else: 18 | remove_index.append(i) 19 | if len(remove_index) > 0: 20 | remove_index.reverse() 21 | for r_index in remove_index: 22 | for _ in range(4): 23 | notes.pop(r_index * 4) 24 | notes_all[idx] = notes 25 | 26 | if config.single_notes_only_strict_flag: 27 | for idx, notes in enumerate(notes_all): 28 | if len(notes) > 4: 29 | target_note = 0 if random.random() > config.single_notes_remove_lr else 1 30 | for i in range(int(len(notes) / 4)): 31 | if notes[i*4+2] == target_note: 32 | notes = notes[i*4:i*4+4] 33 | break 34 | # else 35 | if len(notes) > 4: 36 | notes = notes[0:4] 37 | notes_all[idx] = notes 38 | 39 | return notes_all 40 | -------------------------------------------------------------------------------- /tools/utils/numpy_shorts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from random import shuffle 3 | 4 | 5 | def np_append(old, new, axis=0): 6 | if old is None: 7 | out = new 8 | else: 9 | out = np.concatenate((old, new), axis=axis) 10 | return out 11 | 12 | 13 | def minmax_3d(ar: np.array) -> np.array: 14 | ar -= ar.min() 15 | ar /= ar.max() 16 | return ar 17 | 18 | 19 | def reduce_number_of_songs(name_ar, hard_limit=50): 20 | if len(name_ar) > hard_limit: 21 | shuffle(name_ar) 22 | print(f"Info: Loading 
reduced song number into generator to not overload the RAM (from {len(name_ar)})") 23 | name_ar = name_ar[:hard_limit] 24 | print(f"Importing {len(name_ar)} songs") 25 | return name_ar 26 | 27 | 28 | def get_factor_from_max_speed(max_speed, lb=0.5, ub=1.5): 29 | max_speed = max_speed / 4 30 | ls = 0 31 | us = 15 32 | 33 | if max_speed <= ls: 34 | return lb 35 | elif max_speed >= us: 36 | return ub 37 | else: 38 | factor = lb + (ub - lb) * (max_speed / us) 39 | return factor 40 | 41 | 42 | def add_onset_half_times(times, min_time=0.1, max_time=1.5): 43 | diff = times[1:] - times[:-1] 44 | new_times = list(times) 45 | for idx, d in enumerate(diff): 46 | if min_time <= d <= max_time: 47 | new_times.append(times[idx] + (d / 2)) 48 | 49 | new_times = np.asarray(new_times) 50 | new_times = np.sort(new_times) 51 | return new_times 52 | -------------------------------------------------------------------------------- /bs_shift/copyfavorites.py: -------------------------------------------------------------------------------- 1 | """ 2 | Thanks to @trendy_ideology for creating this script! 
3 | Second script (2/2) 4 | """ 5 | 6 | import csv 7 | import os 8 | import shutil 9 | 10 | 11 | def read_ids_from_csv(file_path): 12 | ids = [] 13 | with open(file_path, newline='', encoding='utf-8') as csvfile: 14 | reader = csv.DictReader(csvfile) 15 | for row in reader: 16 | if row['ID'].strip(): # Ensuring the ID is not empty 17 | ids.append(row['ID'].strip()) 18 | return ids 19 | 20 | 21 | def copy_matching_folders(ids, source_dir, target_dir): 22 | os.makedirs(target_dir, exist_ok=True) 23 | folders = [folder for folder in os.listdir(source_dir) if os.path.isdir(os.path.join(source_dir, folder))] 24 | 25 | for folder in folders: 26 | folder_id = folder.split(' ')[0] # Get the part before the first space 27 | folder_id = folder_id.split('(')[0] # Or before the first parenthesis if needed 28 | if folder_id in ids: 29 | source_folder_path = os.path.join(source_dir, folder) 30 | target_folder_path = os.path.join(target_dir, folder) 31 | 32 | # Check if the folder already exists in the target to avoid overwriting 33 | if not os.path.exists(target_folder_path): 34 | shutil.copytree(source_folder_path, target_folder_path) 35 | print(f"Copied {folder} to {target_dir}") 36 | else: 37 | print(f"Skipped {folder} as it already exists in {target_dir}") 38 | 39 | 40 | def main(): 41 | csv_file = 'output.csv' 42 | source_directory = "E:/SteamLibrary/steamapps/common/Beat Saber/Beat Saber_Data/CustomLevels/" 43 | target_directory = "C:/Users/frede/Desktop/BS_Automapper/Data/training/favorites_bs_input/" 44 | 45 | ids = read_ids_from_csv(csv_file) 46 | copy_matching_folders(ids, source_directory, target_directory) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /map_creation/class_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.preprocessing import OneHotEncoder 3 | import joblib 4 | 5 | from 
tools.config import config, paths 6 | 7 | 8 | def update_out_class(in_class_l, y_class, idx): 9 | if y_class is None: 10 | return in_class_l 11 | 12 | last_class_lstm = in_class_l[idx-1] 13 | new_class_lstm = np.concatenate((last_class_lstm[1:], y_class), axis=0) 14 | in_class_l[idx] = new_class_lstm 15 | 16 | return in_class_l 17 | 18 | 19 | def get_class_size(file): 20 | enc = joblib.load(file) 21 | # if type(enc) == list: 22 | # size = len(enc) 23 | # else: 24 | size = len(enc.categories_[0]) 25 | return size 26 | 27 | 28 | def cast_y_class(y_class): 29 | max_idx = np.argmax(y_class) 30 | y_class = np.zeros_like(y_class) 31 | y_class[0, max_idx] = 1 32 | 33 | return y_class 34 | 35 | 36 | def decode_onehot_class(y_class_map, file): 37 | # test = np.argmax(y_class_map, axis=-1).reshape(-1) 38 | enc = joblib.load(file) 39 | 40 | y = np.asarray(y_class_map).reshape((len(y_class_map), -1)) 41 | y_class_num = enc.inverse_transform(y) 42 | 43 | return y_class_num 44 | 45 | 46 | def add_favor_factor_next_class(y_class, y_class_last): 47 | if y_class_last is None: 48 | return y_class 49 | 50 | next_class = np.argmax(y_class_last) + 1 51 | 52 | start_idx = next_class - 10 53 | end_idx = next_class + 12 54 | while start_idx < 0: 55 | start_idx += 2 56 | while end_idx > y_class.shape[1]: 57 | end_idx -= 2 58 | 59 | flc = config.favor_last_class * np.random.rand(int((end_idx - start_idx) / 2)) 60 | y_class[:, start_idx:end_idx:2] += flc 61 | 62 | # # check that next class does not exceed array length 63 | # if next_class < y_class_last.shape[-1]: 64 | # flc = config.favor_last_class * (np.random.random() + 0.5) / 1.5 65 | # y_class[:, next_class] += flc 66 | 67 | return y_class 68 | -------------------------------------------------------------------------------- /bs_shift/hashtest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Thanks to @trendy_ideology for creating this script! 
3 | First script (1/2) 4 | """ 5 | 6 | import requests 7 | import csv 8 | import json 9 | 10 | 11 | def query_beatsaver_api(hash_value): 12 | url = f"https://api.beatsaver.com/maps/hash/{hash_value}" 13 | response = requests.get(url, headers={"accept": "application/json"}) 14 | if response.status_code == 200: 15 | data = response.json() 16 | return { 17 | "ID": data.get("id", ""), 18 | "SongName": data["metadata"]["songName"], 19 | "LevelAuthorName": data["metadata"]["levelAuthorName"] 20 | } 21 | else: 22 | return None 23 | 24 | 25 | def main(player_dat_file): 26 | # Open the output CSV file 27 | with open('output.csv', mode='w', newline='', encoding='utf-8') as file: 28 | writer = csv.writer(file) 29 | # Write the header row 30 | writer.writerow(['QueriedHash', 'ID', 'SongName', 'LevelAuthorName']) 31 | 32 | # Read the hash values from hashes.txt 33 | player_data_all = [] 34 | with open(player_dat_file, 'r') as f: 35 | player_data_all = json.load(f) 36 | 37 | hashes_file = player_data_all['localPlayers'][0]['favoritesLevelIds'] 38 | hashes_file = [f.strip("custom_level_") for f in hashes_file] 39 | for line in hashes_file: 40 | hash_value = line.strip() 41 | result = query_beatsaver_api(hash_value) 42 | if result: 43 | # Write the extracted information to the CSV file 44 | writer.writerow([hash_value, result["ID"], result["SongName"], result["LevelAuthorName"]]) 45 | else: 46 | # If no result, write the hash with empty fields for the other columns 47 | writer.writerow([hash_value, "", "", ""]) 48 | 49 | 50 | if __name__ == "__main__": 51 | player_data_file = r"C:\Users\frede\AppData\LocalLow\Hyperbolic Magnetism\Beat Saber\PlayerData.dat" 52 | main(player_data_file) 53 | -------------------------------------------------------------------------------- /app_helper/update_dir_path.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def update_dir_path(file_path, keyword='dir_path', new_value=''): 4 | """ 5 | 
Updates the line containing 'dir_path' in the specified file to the new new_value. 6 | 7 | Args: 8 | file_path (str): Path to the file to be updated. 9 | keyword (str): The keyword to search for in the file. 10 | new_value (str, int, float, bool): The new content to set. 11 | """ 12 | try: 13 | # Read the file 14 | with open(file_path, 'r') as file: 15 | lines = file.readlines() 16 | 17 | found_it = False 18 | # Update the specific line containing 'dir_path' 19 | with open(file_path, 'w') as file: 20 | for line in lines: 21 | # Check if the line contains 'dir_path' 22 | if re.match(rf"^\s*{keyword}\s*=", line) and not found_it: 23 | # Determine the format based on the type of new_value 24 | if isinstance(new_value, str): 25 | formatted_value = f'"{new_value}"' 26 | else: 27 | formatted_value = str(new_value) 28 | # Replace with the new value and ensure a newline is added 29 | file.write(f'{keyword} = {formatted_value}\n') 30 | print(f"Updated {keyword} in {file_path} to: {new_value}") 31 | found_it = True 32 | else: 33 | file.write(line) 34 | if not found_it: 35 | print(f"Error: Could not find keyword {keyword} in configuration file.") 36 | 37 | except FileNotFoundError: 38 | print(f"Error: The file {file_path} does not exist.") 39 | except Exception as e: 40 | print(f"An error occurred: {e}") 41 | return 42 | 43 | # Example usage 44 | if __name__ == "__main__": 45 | file_path = "tools/config/paths.py" 46 | keyword = "dir_path" 47 | new_value = "/new/directory/path" 48 | update_dir_path(file_path, keyword, new_value) -------------------------------------------------------------------------------- /tools/fail_list/black_list.py: -------------------------------------------------------------------------------- 1 | from tools.config import paths 2 | import os 3 | 4 | 5 | # Append name of failed title without ending (e.g. 
no .dat) 6 | def append_fail(name): 7 | # check if already on black list 8 | on_black_list = False 9 | try: 10 | with open(paths.black_list_file, 'r') as f: 11 | # black_list exists 12 | black_list = f.readlines() 13 | if name + '\n' in black_list: 14 | on_black_list = True 15 | except: 16 | pass 17 | 18 | if not on_black_list: 19 | # open black_list 20 | with open(paths.black_list_file, 'a') as f: 21 | # save title name 22 | f.writelines(name + '\n') 23 | 24 | 25 | def delete_fails(): 26 | if not os.path.isfile(paths.black_list_file): 27 | print("No fails noted so far.") 28 | return 29 | 30 | with open(paths.black_list_file, 'r') as f: 31 | black_list = f.readlines() 32 | 33 | # print("Attention: Removing is irreversible. Check file names. ") 34 | # remove \n from black list elements 35 | for idx, el in enumerate(black_list): 36 | if "\n" in el: 37 | black_list[idx] = black_list[idx][:-1] 38 | 39 | check = [paths.copy_path_map, paths.dict_all_path, paths.copy_path_song] 40 | endings = ['.egg', '.dat', '_info.dat', '_events.dat', '_obstacles.dat', '_notes.dat'] 41 | 42 | # stack list for endings 43 | black_list_end = [] 44 | for el in black_list: 45 | for ending in endings: 46 | black_list_end.append(el + ending) 47 | 48 | for check_path in check: 49 | for file_name in os.listdir(check_path): 50 | if file_name in black_list_end: 51 | print(f"Delete {file_name}") 52 | os.remove(check_path + file_name) 53 | 54 | 55 | if __name__ == '__main__': 56 | q = input("Reset black list? [y or n]") 57 | if q == 'y': 58 | os.remove(paths.black_list_file) 59 | print("Black list deleted.") 60 | -------------------------------------------------------------------------------- /map_creation/bpm_optimizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def align_beats_on_bpm(timings_in_s, bpm): 5 | """ 6 | Aligns the timings to the nearest beat of the given bpm. 
7 | :param timings_in_s: list of timings in seconds 8 | :param bpm: beats per minute 9 | :return: list of timings in seconds aligned to the nearest beat of the given bpm 10 | """ 11 | average_bpm = (timings_in_s[-1] - timings_in_s[0]) / len(timings_in_s) * 60 12 | timings_bs = timings_in_s * bpm / 60 13 | 14 | # try to first match all elements on integer numbers 15 | timings_bs_new = np.round(timings_bs) 16 | timing_helper = timings_bs_new[1:] - timings_bs_new[:-1] 17 | already_done = False 18 | for idx in range(len(timing_helper)): 19 | if timing_helper[idx] != 0 and already_done: 20 | already_done = False 21 | if timing_helper[idx] == 0 and not already_done: 22 | already_done = True 23 | # found a clash, need to check first how many elements 24 | n_clash = 1 25 | for j in range(idx + 1, len(timing_helper)): 26 | if timing_helper[j] != 0: 27 | break 28 | n_clash += 1 29 | # in case the last element is a zero, the previous range would be skipped 30 | if idx + 2 >= len(timings_bs_new): 31 | j = len(timing_helper) - 1 32 | # define start and end time to put clash in between 33 | start_time = timings_bs_new[idx] 34 | if (j + 2) >= len(timings_bs_new): 35 | # use the last element as end time 36 | end_time = timings_bs_new[-1] + 1 37 | else: 38 | end_time = timings_bs_new[j + 1] 39 | 40 | # define new entries 41 | new_entries = np.linspace(start_time, end_time, n_clash + 2) 42 | for _j in range(n_clash): 43 | timings_bs_new[idx + _j + 1] = new_entries[_j + 1] 44 | 45 | # timings_aligned = np.asarray(timings_bs_new, dtype=float) * 60 / bpm 46 | timings_aligned = timings_bs_new * 60 / bpm 47 | return timings_aligned -------------------------------------------------------------------------------- /training/eval_autoenc_music.py: -------------------------------------------------------------------------------- 1 | from helpers import * 2 | from plot_model import run_plot_autoenc 3 | from tensorflow_models import * 4 | from preprocessing.music_processing import 
run_music_preprocessing 5 | from tools.config import config, paths 6 | from tools.config.mapper_selection import get_full_model_path 7 | 8 | # Setup configuration 9 | ##################### 10 | min_bps_limit = config.min_bps_limit 11 | max_bps_limit = config.max_bps_limit 12 | test_samples = config.test_samples 13 | np.random.seed(3) 14 | 15 | # Data Preprocessing 16 | #################### 17 | # get name array 18 | name_ar, _ = filter_by_bps(min_bps_limit, max_bps_limit) 19 | name_ar = [name_ar[0]] 20 | print(f"Importing {len(name_ar)} song") 21 | 22 | # load song input 23 | song_ar, _ = run_music_preprocessing(name_ar, save_file=False, 24 | song_combined=True, channels_last=True) 25 | 26 | # sample into train/val/test 27 | ds_test = song_ar[:test_samples] 28 | 29 | # Model Building 30 | ################ 31 | 32 | auto_encoder, _ = load_keras_model(get_full_model_path(config.autoenc_version)) 33 | 34 | encoder, _ = load_keras_model(get_full_model_path(config.enc_version)) 35 | 36 | # create model 37 | # if auto_encoder is None: 38 | # encoder = create_keras_model('enc1', learning_rate) 39 | # decoder = create_keras_model('dec1', learning_rate) 40 | # auto_input = Input(shape=(24, 20, 1)) 41 | # encoded = encoder(auto_input) 42 | # decoded = decoder(encoded) 43 | # auto_encoder = Model(auto_input, decoded) 44 | # 45 | # adam = Adam(learning_rate=learning_rate, weight_decay=learning_rate/n_epochs) 46 | # auto_encoder.compile(loss='mean_squared_error', optimizer=adam, metrics=['accuracy']) 47 | # encoder.compile(loss='mean_squared_error', optimizer=adam, metrics=['accuracy']) 48 | 49 | 50 | # Model Evaluation 51 | ################## 52 | print("\nEvaluating test data...") 53 | eval = auto_encoder.evaluate(ds_test, ds_test) 54 | # print(f"Test loss: {eval[0]:.4f}, test accuracy: {eval[1]:.4f}") 55 | 56 | run_plot_autoenc(encoder, auto_encoder, ds_test, save=False) 57 | 58 | print("\nFinished Evaluation") 59 | 
def plot_autoenc_results(img_in, img_repr, img_out, n_samples, scale_repr=True, save=False):
    """Plot autoencoder inputs, bottleneck representations and reconstructions.

    Renders a 3 x n_samples grid: row 1 original images, row 2 bottleneck
    activations, row 3 reconstructed outputs.

    :param img_in: array of input images (first axis: samples)
    :param img_repr: bottleneck/encoder output for the same samples
    :param img_out: autoencoder reconstructions for the same samples
    :param n_samples: number of columns (samples) to plot
    :param scale_repr: normalize the bottleneck values to [0, 1] in place
    :param save: additionally save the figure into the model folder
    """
    # Fix: the module-level `import matplotlib.pyplot as plt` is commented
    # out, so `plt` was undefined and every call raised NameError.  A local
    # import restores plotting while keeping the module importable on
    # headless training machines that lack matplotlib.
    import matplotlib.pyplot as plt

    bneck_reduction = len(img_repr.flatten()) / len(img_in.flatten()) * 100
    print(f"Bottleneck shape: {img_repr.shape}. Reduction to {bneck_reduction:.1f}%")
    print("Plot original images vs. reconstruction")
    # NOTE(review): subplots() already creates an axes grid which the
    # add_subplot calls below overlay; left as-is to preserve the layout.
    fig, axes = plt.subplots(nrows=3, ncols=n_samples, figsize=(12, 8))
    fig.suptitle(f"Reduction to {bneck_reduction:.1f}%")
    if scale_repr:
        # min/max scale the bottleneck so imshow uses the full color range
        img_repr -= img_repr.min()
        img_repr /= img_repr.max()

    # plot original image
    for idx in np.arange(n_samples):
        fig.add_subplot(3, n_samples, idx + 1)
        plt.imshow(np.transpose(img_in[idx], (0, 1, 2)), cmap='hot')
        # plt.imshow(np.transpose(img_in[idx], (1, 2, 0)), cmap='hot')

    # plot bottleneck distribution (reshape flat vectors into an image-like grid)
    if len(img_repr.shape) < 4:
        square_bottleneck = int(img_repr.shape[1]/4) if int(img_repr.shape[1]/4) > 0 else 1
        img_repr = img_repr.reshape((img_repr.shape[0]), 1, -1, square_bottleneck)
    for idx in np.arange(n_samples):
        fig.add_subplot(3, n_samples, idx + n_samples + 1)
        plt.imshow(np.transpose(img_repr[idx], (1, 2, 0)), cmap='hot')

    # plot output image
    for idx in np.arange(n_samples):
        fig.add_subplot(3, n_samples, idx + 2*n_samples + 1)
        plt.imshow(np.transpose(img_out[idx], (0, 1, 2)), cmap='hot')

    plt.axis('off')
    if save:
        save_path = f"{paths.model_path}bneck{config.bottleneck_len}_encoder_decoder_example.png"
        fig.savefig(save_path)

    plt.show()
def mirror_notes(n):
    """Mirror a flat note section left<->right.

    `n` holds four values per note (lineIndex, lineLayer, type,
    cutDirection) back to back.  When every note in the section shares a
    single type, that type is color-swapped (red 0 <-> blue 1; bombs,
    type 3, are kept).  The column (lineIndex, 0..3) is always flipped
    horizontally.  The list is modified in place and also returned.
    """
    n_notes = len(n) // 4
    colours = n[2::4]

    # swap the color only when the whole section is single-colored
    if len(set(colours)) == 1:
        swapped = {0: 1, 1: 0}.get(colours[0], colours[0])
        for note_idx in range(n_notes):
            n[2 + note_idx * 4] = swapped

    # flip the column: 0<->3, 1<->2; anything outside 0..3 stays untouched
    for note_idx in range(n_notes):
        col = n[note_idx * 4]
        if col in (0, 1, 2, 3):
            n[note_idx * 4] = 3 - col
    return n
| section_mirror = mirror_notes(section.copy()) 60 | notes[i].extend(section_mirror) 61 | 62 | return notes 63 | -------------------------------------------------------------------------------- /tools/PowerBeats_extension/update_artist_name.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | music_folder_path = r"C:\Users\frede\Desktop\BS_Automapper\Data\prediction\songs_predict\y_old" 4 | list_of_artists = [] 5 | 6 | for i, song in enumerate(os.listdir(music_folder_path)): 7 | src = os.path.join(music_folder_path, song) 8 | ending = song[-4:] 9 | song = song[:-4] 10 | song_split = song.split(" - ") 11 | if len(song_split) == 2: 12 | artist = song_split[0] 13 | song = song_split[1] 14 | print(f"{i} | Artist: {artist}, Song: {song}") 15 | 16 | if artist in list_of_artists: 17 | print(f"Found artist {artist}. Continue.") 18 | continue 19 | if song in list_of_artists: 20 | print(f"Found wrong order. Correcting.") 21 | new_song = artist 22 | new_artist = song 23 | else: 24 | 25 | inp1 = input("Correct? (empty for yes, else specify artist)") 26 | if inp1 == "": 27 | list_of_artists.append(artist) 28 | continue 29 | else: 30 | new_artist = inp1 31 | inp1 = input("Song? (empty for old, else specify song name)") 32 | if inp1 != "": 33 | new_song = inp1 34 | else: 35 | new_song = song_split[1] 36 | else: 37 | print(f"{i} | Song: {song_split[0]}") 38 | new_artist = input("Update Artist name (optional)") 39 | new_song = input("Update song name (optional)") 40 | if new_song == "": 41 | if new_artist == "": 42 | continue 43 | else: 44 | new_song = song_split[0] 45 | 46 | if new_artist != "": 47 | if new_artist not in list_of_artists: 48 | list_of_artists.append(new_artist) 49 | new_song_name = f"{new_artist} - {new_song}{ending}" 50 | else: 51 | new_song_name = f"{new_song}{ending}" 52 | if new_song == "": 53 | print("Warning. 
Skipping unspecified song name") 54 | continue 55 | dst = os.path.join(music_folder_path, new_song_name) 56 | print(dst) 57 | os.rename(src, dst) 58 | -------------------------------------------------------------------------------- /tests/test_song_metadata.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | ROOT_DIR = Path(__file__).resolve().parents[1] 5 | if str(ROOT_DIR) not in sys.path: 6 | sys.path.insert(0, str(ROOT_DIR)) 7 | 8 | import importlib 9 | import pytest 10 | 11 | from tools.config import config 12 | 13 | 14 | def reload_module(): 15 | module = importlib.import_module("tools.utils.song_metadata") 16 | return importlib.reload(module) 17 | 18 | 19 | def test_extract_metadata_prefers_audio_tags(monkeypatch): 20 | module = reload_module() 21 | 22 | class DummyAudio: 23 | tags = { 24 | "title": ["Some Title"], 25 | "artist": ["Some Artist"], 26 | "genre": ["Electronic"], 27 | } 28 | 29 | monkeypatch.setattr(module, "MutagenFile", lambda *_args, **_kwargs: DummyAudio()) 30 | 31 | metadata = module.extract_metadata("ignored.mp3") 32 | assert metadata == { 33 | "title": "Some Title", 34 | "artist": "Some Artist", 35 | "genre": "Electronic", 36 | } 37 | 38 | 39 | def test_extract_metadata_fallbacks_to_filename(monkeypatch): 40 | module = reload_module() 41 | 42 | class DummyAudio: 43 | tags = {} 44 | 45 | monkeypatch.setattr(module, "MutagenFile", lambda *_args, **_kwargs: DummyAudio()) 46 | 47 | metadata = module.extract_metadata("Artist Name - Song Title.ogg") 48 | assert metadata == {"artist": "Artist Name", "title": "Song Title"} 49 | 50 | 51 | def test_extract_metadata_without_mutagen(monkeypatch): 52 | module = reload_module() 53 | monkeypatch.setattr(module, "MutagenFile", None) 54 | 55 | metadata = module.extract_metadata("Other Artist - Other Song.flac") 56 | assert metadata == {"artist": "Other Artist", "title": "Other Song"} 57 | 58 | 59 | def 
def normalize_song_name(song_name: str, check_out: bool) -> str:
    """Strip characters that are unsafe in file/folder names from a song name.

    :param song_name: raw value extracted from a BeatSaber string dictionary
    :param check_out: additionally remove '.' and '-' (used when the raw
        value may contain float formatting that must be flattened)
    :return: sanitized name; at most one leading space is removed
    """
    if check_out:
        # uncheck points for float values
        song_name = song_name.replace(".", "").replace("-", "")

    # quoting, separators, folder placeholders and characters that are
    # illegal/problematic in paths -- removed in one C-level pass instead
    # of a chain of .replace() calls
    song_name = song_name.translate(str.maketrans('', '', ',"\'/\\?<>!&%:*'))

    # drop a single leading space; guarded, because the previous direct
    # index access (song_name[0]) raised IndexError once every character
    # of the input had been stripped away
    if song_name.startswith(' '):
        song_name = song_name[1:]

    return song_name
paths.copy_path_song = paths.train_path + "songs_egg/" 32 | paths.copy_path_map = paths.train_path + "maps/" 33 | 34 | paths.dict_all_path = paths.train_path + "maps_dict_all/" 35 | 36 | paths.songs_pred = paths.pred_path + "songs_predict/" 37 | 38 | paths.new_map_path = paths.pred_path + "new_map/" 39 | paths.fail_path = paths.train_path + "fail_list/" 40 | paths.diff_path = paths.train_path + "songs_diff/" 41 | paths.song_data = paths.train_path + "song_data/" 42 | 43 | paths.ml_input_path = paths.train_path + "ml_input/" 44 | 45 | paths.diff_ar_file = paths.diff_path + "diff_ar.npy" 46 | paths.name_ar_file = paths.diff_path + "name_ar.npy" 47 | 48 | paths.ml_input_beat_file = paths.ml_input_path + "beat_ar.npy" 49 | paths.ml_input_song_file = paths.ml_input_path + "song_ar.npy" 50 | 51 | paths.black_list_file = paths.fail_path + "black_list.txt" 52 | 53 | paths.notes_classify_dict_file = paths.model_path + "notes_class_dict.pkl" 54 | paths.beats_classify_encoder_file = paths.model_path + "onehot_encoder_beats.pkl" 55 | paths.events_classify_encoder_file = paths.model_path + "onehot_encoder_events.pkl" 56 | -------------------------------------------------------------------------------- /beat_prediction/beat_to_lstm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # from progressbar import ProgressBar 3 | 4 | from tools.config import config 5 | from tools.utils import numpy_shorts 6 | 7 | 8 | def beat_to_lstm(song_input, beat_resampled): 9 | tcn_len = config.tcn_len 10 | 11 | x_tcn_all = None 12 | y_tcn_all = None 13 | 14 | # bar = ProgressBar(max_value=len(song_input)) 15 | 16 | for i_song in range(len(song_input)): 17 | # bar.update(i_song+1) 18 | x = song_input[i_song] 19 | if beat_resampled is not None: 20 | y = beat_resampled[i_song] 21 | 22 | x_tcn = [] 23 | y_tcn = [] 24 | for i in range(x.shape[1] - tcn_len): 25 | x_tcn.append(x[:, i:i+tcn_len]) 26 | if beat_resampled is None: 27 | 
y_tcn.append(0) 28 | else: 29 | y_tcn.append(y[i+tcn_len-1]) 30 | 31 | x_tcn = np.asarray(x_tcn) 32 | # 3D tensor with shape (batch_size, time_steps, seq_len) 33 | x_tcn = x_tcn.reshape(x_tcn.shape[0], tcn_len, -1) 34 | 35 | y_tcn = np.asarray(y_tcn) 36 | 37 | x_tcn_all = numpy_shorts.np_append(x_tcn_all, x_tcn, 0) 38 | y_tcn_all = numpy_shorts.np_append(y_tcn_all, y_tcn, 0) 39 | 40 | return x_tcn_all, y_tcn_all 41 | 42 | 43 | def beat_to_tcn(song_input, beat_resampled): 44 | tcn_len = config.tcn_len 45 | 46 | # beta: use first sample only 47 | x = song_input[0] 48 | y = beat_resampled[0] 49 | 50 | x_tcn = [] 51 | y_tcn = [] 52 | for i in range(x.shape[1] - tcn_len): 53 | x_tcn.append(x[:, i:i+tcn_len]) 54 | y_tcn.append(y[i+tcn_len-1]) 55 | 56 | x_tcn = np.asarray(x_tcn) 57 | # 3D tensor with shape (batch_size, timesteps, input_dim) 58 | x_tcn = x_tcn.reshape(x_tcn.shape[0], tcn_len, -1) 59 | 60 | y_tcn = np.asarray(y_tcn) 61 | 62 | return x_tcn, y_tcn 63 | 64 | 65 | def last_beats_to_lstm(beats): 66 | tcn_len = config.tcn_len 67 | x_beats = [] 68 | for idx in range(len(beats)): 69 | if idx < tcn_len: 70 | b = np.zeros(tcn_len) 71 | if idx > 0: 72 | b[-idx:] = beats[:idx] 73 | else: 74 | b = beats[idx-tcn_len:idx] 75 | x_beats.append(b) 76 | 77 | x_beats = np.asarray(x_beats) 78 | x_beats = x_beats.reshape(x_beats.shape[0], x_beats.shape[1], 1) 79 | return x_beats 80 | -------------------------------------------------------------------------------- /tools/config/check_folder_structure.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script checks the folder structure. 3 | Missing folders can be automatically created. 
4 | """ 5 | 6 | import os 7 | import shutil 8 | 9 | import tools.config.paths as paths 10 | from tools.utils.ask_parameter import ask_parameter 11 | 12 | 13 | def check_folder_structure(): 14 | # generic check if folder exists 15 | def check_exists(data_path, create=True): 16 | if data_path.endswith('/'): 17 | # check if folder exists 18 | exist_flag = os.path.isdir(data_path) 19 | else: 20 | print(f'Missing "/" at the end of folder: {data_path}. Exit') 21 | exit() 22 | if not exist_flag: 23 | if create: 24 | # if ask_parameter(f'Create folder <{data_path}> ? [y or n]', param_type='bool'): 25 | # create missing folder 26 | print(f"Creating folder: {data_path}") 27 | os.makedirs(data_path) 28 | else: 29 | print(f"Missing folder: {data_path}") 30 | return False 31 | return True 32 | 33 | # iterate through folder structure 34 | if not check_exists(paths.dir_path, create=False): 35 | print("Adjust input directory path (dir_path) in config") 36 | exit() 37 | # if not check_exists(paths.bs_song_path, create=False): 38 | # print("Adjust BeatSaber path in config") 39 | # exit() 40 | 41 | check_exists(paths.model_path) 42 | check_exists(paths.pred_path) 43 | check_exists(paths.train_path) 44 | check_exists(paths.temp_path) 45 | 46 | check_exists(paths.copy_path_song) 47 | check_exists(paths.copy_path_map) 48 | 49 | check_exists(paths.dict_all_path) 50 | 51 | check_exists(paths.songs_pred) 52 | 53 | # check_exists(paths.pred_input_path) 54 | check_exists(paths.new_map_path) 55 | if not os.path.isfile(paths.new_map_path + "cover.jpg"): 56 | src = f"{paths.main_path}app_helper/cover.jpg" 57 | dst = f"{paths.new_map_path}cover.jpg" 58 | shutil.copy(src, dst) 59 | 60 | check_exists(paths.fail_path) 61 | check_exists(paths.diff_path) 62 | check_exists(paths.song_data) 63 | 64 | # check_exists(paths.class_maps) 65 | check_exists(paths.ml_input_path) 66 | 67 | print(f"Finished folder setup: {paths.dir_path}") 68 | 69 | 70 | if __name__ == "__main__": 71 | check_folder_structure() 
72 | -------------------------------------------------------------------------------- /main_training.py: -------------------------------------------------------------------------------- 1 | # Run all training scripts for a new model 2 | import os 3 | import shutil 4 | import sys 5 | import subprocess 6 | 7 | from tools.config import paths, config 8 | 9 | 10 | from training.helpers import test_gpu_tf 11 | # import map_creation.gen_beats as beat_generator 12 | # from bs_shift.export_map import * 13 | 14 | # Check Cuda compatible GPU 15 | if not test_gpu_tf(): 16 | exit() 17 | 18 | print(f"use_mapper_selection value: {config.use_mapper_selection}") 19 | print(f"use_bpm_selection value: {config.use_bpm_selection}") 20 | if config.use_bpm_selection: 21 | print(f"with bpm limits [{config.min_bps_limit}, {config.max_bps_limit}]") 22 | input("Adapted the mapper_selection and use_bpm_selection in the config file?\n" 23 | "Press enter to continue...") 24 | 25 | # create folder if required 26 | if not os.path.isdir(paths.model_path): 27 | print(f"Creating model folder: {config.use_mapper_selection}") 28 | os.makedirs(paths.model_path) 29 | 30 | else: 31 | if len(os.listdir(paths.model_path)) > 1: 32 | print("Model folder already available. Exit manually to change folder in config.") 33 | input("Continue with same model folder?") 34 | 35 | # TRAINING 36 | ########## 37 | print("Which trainings do you want to start? Reply with y or n for each model.") 38 | run_list = input("1. shift music | 2. music autoencoder | 3. song mapper | 4. beat generator | 5. lights generator | ") 39 | if len(run_list) != 5: 40 | print("Wrong input format. 
Exit") 41 | exit() 42 | 43 | # run bs_shift / shift.py 44 | if run_list[0].lower() == 'y': 45 | print("Hint 1: To import favorite maps manually start the scripts " 46 | "bs_shift/hashtest.py and bs_shift/copyfavorites.py first.") 47 | print("Hint 2: Before running shift.py, make sure all maps are in " 48 | "the correct format with bs_shift/cleanup_n_format.py first.") 49 | print(f"Analyzing BS music files from folder: {paths.bs_input_path}") 50 | subprocess.call(['python3', './bs_shift/shift.py']) 51 | 52 | # run training / train_autoenc_music.py 53 | # os.system("training/train_autoenc_music.py") 54 | if run_list[1].lower() == 'y': 55 | subprocess.call(['python3', './training/train_autoenc_music.py']) 56 | 57 | # run training / train_bs_automapper.py 58 | if run_list[2].lower() == 'y': 59 | subprocess.call(['python3', './training/train_bs_automapper.py']) 60 | 61 | # run beat_prediction / ai_beat_gen.py 62 | if run_list[3].lower() == 'y': 63 | subprocess.call(['python3', './beat_prediction/ai_beat_gen.py']) 64 | 65 | # run lighting_prediction / train_lighting.py 66 | if run_list[4].lower() == 'y': 67 | subprocess.call(['python3', './lighting_prediction/train_lighting.py']) 68 | -------------------------------------------------------------------------------- /preprocessing/map_info_processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from tools.config import config, paths 5 | 6 | 7 | def get_mapper_name(name_ar): 8 | if not isinstance(name_ar, list): 9 | name_ar = [name_ar] 10 | 11 | folder = paths.copy_path_map 12 | for name in name_ar: 13 | name = f"{name}_info.dat" 14 | if name not in os.listdir(folder): 15 | print(f"Error: Could not find file: {name}") 16 | exit() 17 | 18 | with open(folder + name, 'r') as f: 19 | info_dict = json.load(f) 20 | level_author = info_dict['_levelAuthorName'] 21 | 22 | return level_author 23 | 24 | 25 | def get_maps_from_mapper(mapper_name_list, 
ignore_capital=True, full_match=False): 26 | music_list = [] 27 | folder = paths.copy_path_map 28 | for file in os.listdir(folder): 29 | if file.endswith("_info.dat"): 30 | song_name = file[:-9] 31 | level_author = get_mapper_name(song_name) 32 | if isinstance(mapper_name_list, str): 33 | mapper_name_list = [mapper_name_list] 34 | 35 | for mapper_name in mapper_name_list: 36 | if ignore_capital: 37 | if full_match: 38 | if level_author.lower() == mapper_name.lower(): 39 | music_list.append(song_name) 40 | else: 41 | words1 = set(level_author.lower().split()) 42 | words2 = set(mapper_name.lower().split()) 43 | if words2.issubset(words1): 44 | music_list.append(song_name) 45 | else: 46 | if level_author == mapper_name: 47 | music_list.append(song_name) 48 | print(f"Found {len(music_list)} songs from mapper {mapper_name_list}.") 49 | return music_list 50 | 51 | 52 | if __name__ == '__main__': 53 | mapper_name_list = [] 54 | folder = paths.copy_path_map 55 | for file in os.listdir(folder): 56 | if file.endswith("_info.dat"): 57 | song_name = file[:-9] 58 | level_author = get_mapper_name(song_name) 59 | # print(level_author) 60 | mapper_name_list.append(level_author) 61 | from collections import Counter 62 | 63 | string_counts = Counter(mapper_name_list) 64 | sorted_strings = sorted(string_counts.items(), key=lambda x: x[1]) 65 | for string, count in sorted_strings: 66 | print(f"{string}: {count} times") 67 | # mapper_name = get_mapper_name('#ThatPOWER') 68 | # print(mapper_name) 69 | 70 | mapper_name = 'Skyler' 71 | map_list = get_maps_from_mapper(mapper_name) 72 | print(map_list) 73 | -------------------------------------------------------------------------------- /map_creation/find_bpm.py: -------------------------------------------------------------------------------- 1 | from aubio import source, tempo 2 | from numpy import median, diff 3 | 4 | 5 | def get_file_bpm(path, params=None): 6 | """ Calculate the beats per minute (bpm) of a given file. 
7 | path: path to the file 8 | param: dictionary of parameters 9 | output also song length 10 | """ 11 | if params is None: 12 | params = {} 13 | # default: 14 | samplerate, win_s, hop_s = 44100, 1024, 512 15 | if 'mode' in params: 16 | if params.mode in ['super-fast']: 17 | # super fast 18 | samplerate, win_s, hop_s = 4000, 128, 64 19 | elif params.mode in ['fast']: 20 | # fast 21 | samplerate, win_s, hop_s = 8000, 512, 128 22 | elif params.mode in ['default']: 23 | pass 24 | else: 25 | raise ValueError("unknown mode {:s}".format(params.mode)) 26 | # manual settings 27 | if 'samplerate' in params: 28 | samplerate = params.samplerate 29 | if 'win_s' in params: 30 | win_s = params.win_s 31 | if 'hop_s' in params: 32 | hop_s = params.hop_s 33 | 34 | s = source(path, samplerate, hop_s) 35 | samplerate = s.samplerate 36 | o = tempo("specdiff", win_s, hop_s, samplerate) 37 | # List of beats, in samples 38 | beats = [] 39 | # Total number of frames read 40 | total_frames = 0 41 | 42 | while True: 43 | samples, read = s() 44 | is_beat = o(samples) 45 | if is_beat: 46 | this_beat = o.get_last_s() 47 | beats.append(this_beat) 48 | # if o.get_confidence() > .2 and len(beats) > 2.: 49 | # break 50 | total_frames += read 51 | if read < hop_s: 52 | break 53 | 54 | def beats_to_bpm(beats, path): 55 | # if enough beats are found, convert to periods then to bpm 56 | if len(beats) > 1: 57 | if len(beats) < 4: 58 | print("few beats found in {:s}".format(path)) 59 | bpms = 60. 
def return_mapper_list(mapper_shortcut):
    """Resolve a mapper-selection shortcut to a list of mapper names.

    :param mapper_shortcut: 'curated1' or 'curated2' for the predefined
        lists; any other value (name or list of names) is passed through
        unchanged
    :return: list of mapper names (or the untouched input)
    """
    if mapper_shortcut == 'curated1':
        # Fixed: a missing comma between 'puds' and 'Moriik' silently
        # concatenated them into the non-existent mapper 'pudsMoriik'
        # (implicit adjacent string literal concatenation).
        mapper_list = ['Nuketime', 'Heisenberg', 'Ruckus',
                       'Joetastic', 'BennyDaBeast', 'Hexagonial',
                       'Ryger', 'Skyler Wallace', 'Uninstaller',
                       'Teuflum', 'GreatYazer', 'puds',
                       'Moriik', 'Ab', 'DE125', 'Skeelie',
                       'Psyc0pathic', 'Hexagonial', 'Electrostats',
                       'DankruptMemer', 'StyngMe', 'Rustic',
                       'Souk', 'Oddloop', 'Chroma', 'Pendulum',
                       'Excession', 'Jez', 'Kry', 't+pazolite']

    elif mapper_shortcut == 'curated2':
        mapper_list = ['Nuketime', 'Ruckus', 'Joetastic',
                       'BennyDaBeast', 'Hexagonial', 'Ryger',
                       'Skyler Wallace', 'Uninstaller', 'Teuflum',
                       'Oddloop', 'Souk', 'Pendulum', 't+pazolite']

    else:
        # unknown shortcut: assume the caller already passed mapper name(s)
        mapper_list = mapper_shortcut
    return mapper_list
def update_model_file_paths(check_model_exists=True):
    """Re-derive all model file paths from the current mapper selection.

    Mutates the module-level defaults in tools.config.paths so they point
    at the model folder matching config.use_mapper_selection; call this
    whenever the selection changes at runtime.

    :param check_model_exists: when True, verify (using the event generator
        model as a representative file) that the selected folder actually
        contains trained models; get_full_model_path raises
        FileNotFoundError otherwise.
    """
    # update the file paths for the models if the folder is changed
    paths.model_path = paths.dir_path + "model/"
    if config.use_mapper_selection == '' or config.use_mapper_selection is None:
        # no explicit selection: fall back to the default/general model set
        paths.model_path += "general_new/"
    else:
        paths.model_path += f"{config.use_mapper_selection.lower()}/"
    # classifier artifacts live inside the selected model folder
    paths.notes_classify_dict_file = os.path.join(paths.model_path, "notes_class_dict.pkl")
    paths.beats_classify_encoder_file = os.path.join(paths.model_path, "onehot_encoder_beats.pkl")
    paths.events_classify_encoder_file = os.path.join(paths.model_path, "onehot_encoder_events.pkl")

    if check_model_exists:
        # check that model exists on the example of event generator
        _ = get_full_model_path(config.event_gen_version)
        print(f"Using model: {config.use_mapper_selection}")
Always end with "/" 7 | ########################################## 8 | 9 | # /mnt/c/Users/frede/Desktop/BS_Automapper/InfernoSaber---BeatSaber-Automapper 10 | 11 | import os 12 | from tools.config import config 13 | 14 | 15 | ################################# (change this for your pc) 16 | # setup folder for input data (automatically determined if inside this project) 17 | dir_path = "" 18 | 19 | bs_song_path = "" 20 | bs_input_path = "" 21 | 22 | ############################# (no need to change) 23 | # main workspace path 24 | main_path = os.path.abspath(os.getcwd()) 25 | max_tries = 3 26 | for i in range(0, max_tries): 27 | if not os.path.isfile(main_path + '/main.py'): 28 | # not found, search root folder 29 | main_path = os.path.dirname(main_path) 30 | else: 31 | # found main folder 32 | break 33 | 34 | if not os.path.isfile(main_path + '/main.py'): 35 | print(f"dir_path={dir_path}") 36 | print(f"main_path={main_path}") 37 | print("Could not find root directory. Exit") 38 | exit() 39 | main_path += '/' 40 | 41 | ######################## 42 | # input folder structure 43 | model_path = dir_path + "model/" 44 | if config.use_mapper_selection == '' or config.use_mapper_selection is None: 45 | model_path += "general_new/" 46 | else: 47 | model_path += f"{config.use_mapper_selection.lower()}/" 48 | pred_path = dir_path + "prediction/" 49 | train_path = dir_path + "training/" 50 | temp_path = dir_path + "temp/" 51 | 52 | ############################ 53 | # input subfolder structure 54 | copy_path_song = train_path + "songs_egg/" 55 | copy_path_map = train_path + "maps/" 56 | 57 | dict_all_path = train_path + "maps_dict_all/" 58 | # pic_path = train_path + "songs_pic/" 59 | 60 | songs_pred = pred_path + "songs_predict/" 61 | # pic_path_pred = pred_path + "songs_pic_predict/" 62 | 63 | # pred_path = pred_path + "np_pred/" 64 | # pred_input_path = pred_path + "input/" 65 | new_map_path = pred_path + "new_map/" 66 | 67 | fail_path = train_path + "fail_list/" 68 | 
def create_tf_model(model_type, dim_in, dim_out, nr=128):
    """Build one of the keras lighting-prediction architectures.

    model_type selects the topology ('lstm_light', 'lstm_full' or
    'lstm_half'); dim_in / dim_out carry the per-input sample shapes.
    `nr` is kept for interface compatibility and is currently unused.
    Returns the uncompiled keras Model (None for unknown model_type).
    """
    print("Setup keras model")

    if model_type == 'lstm_light':
        # in_song (lin), in_time (rec)
        song_in = Input(shape=(dim_in[0][1]), name='input_song_enc')
        time_in = Input(shape=(dim_in[1][1:]), name='input_time_lstm')

        recurrent = LSTM(32, return_sequences=False)(time_in)

        merged = concatenate([song_in, recurrent])
        merged = Dense(512, activation='relu')(merged)
        merged = Dropout(0.05)(merged)
        merged = Dense(256, activation='sigmoid')(merged)

        classifier = Dense(dim_out[1], activation='softmax', name='output')(merged)

        return Model(inputs=[song_in, time_in], outputs=classifier)

    if model_type == 'lstm_full':
        # in_song (lin), in_time (rec), in_class (rec)
        song_in = Input(shape=(dim_in[0][1]), name='input_song_enc')
        time_in = Input(shape=(dim_in[1][1:]), name='input_time_lstm')
        class_in = Input(shape=(dim_in[2][1:]), name='input_class_lstm')

        time_seq = LSTM(32, return_sequences=True)(time_in)
        class_seq = LSTM(256, return_sequences=True)(class_in)

        recurrent = LSTM(64, return_sequences=False)(concatenate([time_seq, class_seq]))

        merged = concatenate([song_in, recurrent])
        merged = Dense(512, activation='relu')(merged)
        merged = Dropout(0.05)(merged)
        merged = Dense(256, activation='sigmoid')(merged)

        classifier = Dense(dim_out[1], activation='softmax', name='output')(merged)

        return Model(inputs=[song_in, time_in, class_in], outputs=classifier)

    if model_type == 'lstm_half':
        # in_song (lin, 2D spectrogram), in_time (rec), in_class (rec)
        song_in = Input(shape=(dim_in[0][1:]), name='input_song_enc')
        time_in = Input(shape=(dim_in[1][1:]), name='input_time_lstm')
        class_in = Input(shape=(dim_in[2][1:]), name='input_class_lstm')

        song_feat = Conv2D(32, kernel_size=3, activation='relu')(song_in)
        song_feat = Flatten('channels_last')(song_feat)
        time_seq = LSTM(32, return_sequences=True)(time_in)
        class_seq = LSTM(64, return_sequences=True)(class_in)

        recurrent = LSTM(64, return_sequences=False)(concatenate([time_seq, class_seq]))

        merged = concatenate([song_feat, recurrent])
        merged = Dense(1024, activation='relu')(merged)
        merged = Dropout(0.05)(merged)
        merged = Dense(dim_out[1] * dim_out[2], activation='sigmoid')(merged)

        # multi-dim output: reshape the flat sigmoid layer back to event shape
        reshaped = Reshape(target_shape=(dim_out[1:]))(merged)

        return Model(inputs=[song_in, time_in, class_in], outputs=reshaped)
def delete_offbeats(beat_resampled, song_input, x_volume, x_onset):
    """Randomly drop off-beat frames from every song.

    A frame survives when its beat flag plus uniform noise exceeds
    config.delete_offbeats, so frames flagged as beats (value 1 —
    assumed; confirm against the resampling step) always survive while
    off-beat frames are dropped with probability config.delete_offbeats.
    All four arrays are filtered with the same mask so they stay aligned.
    The lists are modified in place and also returned.
    """
    for i_song in range(len(beat_resampled)):
        rd = np.random.rand(len(beat_resampled[i_song]))
        rd += beat_resampled[i_song]
        keep = rd > config.delete_offbeats

        beat_resampled[i_song] = beat_resampled[i_song][keep]
        # song frames are indexed on the second axis
        song_input[i_song] = song_input[i_song][:, keep]
        x_volume[i_song] = x_volume[i_song][keep]
        x_onset[i_song] = x_onset[i_song][keep]

    return beat_resampled, song_input, x_volume, x_onset


def get_beat_prop(x_song):
    """Compute per-song beat-proposal curves.

    Returns [volume_curves, onset_curves]: one normalized volume curve
    and one normalized spectral-flux curve per song in x_song.
    """
    beat_a = []
    beat_b = []
    for song in x_song:
        beat_a.append(volume_check(song))
        beat_b.append(onset_detection(song))

    return [beat_a, beat_b]


def tcn_reshape(x_input):
    """Slice each 1D curve into overlapping windows of config.tcn_len.

    Windows from all input curves are stacked along the first axis;
    output shape is (n_windows, tcn_len, 1).
    """
    x_tcn = None
    for ar in x_input:
        tcn_len = config.tcn_len
        ar_out = np.zeros((len(ar) - tcn_len, tcn_len))
        for idx in range(len(ar) - tcn_len):
            ar_out[idx] = ar[idx:idx + tcn_len]

        ar_out = ar_out.reshape(ar_out.shape[0], ar_out.shape[1], 1)
        x_tcn = numpy_shorts.np_append(x_tcn, ar_out, 0)
    return x_tcn


def volume_check(x_song):
    """Return the normalized per-frame volume of a spectrogram.

    x_song is (n_bins, n_frames); a frame's volume is its sum over
    frequency bins (vectorized sum instead of a per-frame Python loop).
    """
    volume = x_song.sum(axis=0, dtype=float)
    # normalize
    volume = numpy_shorts.minmax_3d(volume)
    return volume


def onset_detection(x_song):
    """Return the normalized spectral flux of a spectrogram.

    The flux of a frame is the sum of positive first differences over
    time; a leading zero keeps one value per input frame.
    """
    x_song = x_song.T
    # calculate the difference between consecutive frames
    diff = np.diff(x_song, axis=0)
    # keep only the positive differences (energy increases)
    pos_diff = np.maximum(0, diff)
    # sum everything to get the spectral flux
    sf = np.sum(pos_diff, axis=1)
    sf = numpy_shorts.minmax_3d(sf)

    # prepend a zero so the flux is aligned with the input frames
    sf = np.hstack((np.zeros(1), sf))

    return sf
def load_dic_dif_casting(paths):
    """Load map dictionaries (notes/events/obstacles) and cast difficulties.

    Returns the unique map names, their segment counts, the raw picture
    name/index arrays, the per-map note/event/obstacle data and one
    difficulty value per map segment.  Exits on inconsistent data.
    """
    print("Load map dictionary and difficulty casting")
    map_dict_events = []
    map_dict_notes = []
    map_dict_obstacles = []
    diff_ar = []
    name_name, name_idx = load_names_np(paths.pic_path)

    # Get unique names: name_name is ordered, so runs of equal names are
    # collapsed into one entry plus a segment count
    map_names = []
    map_names_count = []
    for count_maps in name_name:
        if len(map_names) == 0 or count_maps != map_names[-1]:
            map_names.append(count_maps)
            map_names_count.append(1)
        else:
            map_names_count[-1] = map_names_count[-1] + 1
    map_names = np.asarray(map_names)
    map_names_count = np.asarray(map_names_count)

    ###############################################
    # Test
    test = len(np.unique(np.asarray(name_name)))
    if test != len(map_names):
        print("Error when matching map names")
        exit()
    if sum(map_names_count) != len(name_name):
        print("Error when counting map names")
        exit()

    # The difficulty lookup arrays are identical for every map:
    # load them once instead of once per picture segment.
    bps_ar = np.load(paths.diff_path + "diff_ar.npy")
    names_ar = np.load(paths.diff_path + "name_ar.npy")

    #######################################################
    # Load all notes, events and obstacles in correct order
    # Cast difficulties for each map segment
    #######################################################
    map_names_count_index = 0
    print("Loading maps input data")
    for dict_name in map_names:
        # Load notes, events, obstacles, all already divided by bpm from info file! (in real sec)
        # NOTE: allow_pickle belongs to np.load, not to list.append
        map_dict_events.append(np.load(paths.dict_all_path + dict_name + "_events.dat", allow_pickle=True))
        map_dict_notes.append(np.load(paths.dict_all_path + dict_name + "_notes.dat", allow_pickle=True))
        map_dict_obstacles.append(np.load(paths.dict_all_path + dict_name + "_obstacles.dat", allow_pickle=True))

        # Test notes available in song
        if map_dict_notes[-1].shape[0] == 0:
            print("Could not find notes in " + str(dict_name) + " Exit!")
            exit()

        # Cast one difficulty for every picture segment belonging to this map
        for names_i in name_name:
            if names_i == dict_name:
                diff_ar.append(bps_to_diff(bps_ar, names_ar, names_i))

        # Test difficulty array length for every dict_name
        map_names_count_index += 1
        if len(diff_ar) != np.sum(map_names_count[:map_names_count_index]):
            print("Error: difficulty casting at " + dict_name)

    return map_names, map_names_count, name_name, name_idx, map_dict_notes, map_dict_events, map_dict_obstacles, diff_ar
\nExit") 19 | exit() 20 | inp = input('Enter value for {}: '.format(parameter)) 21 | if inp == '': 22 | print("Empty input not allowed.") 23 | tries -= 1 24 | continue 25 | 26 | if param_type is not None: 27 | 28 | # positive integer > 0 29 | if param_type == 'uint': 30 | if inp.isdigit(): 31 | inp = int(inp) 32 | if inp > 0: 33 | input_flag = False # finished 34 | else: 35 | print("Wrong input, must be greater zero.") 36 | else: 37 | print("Wrong input, only unsigned integer allowed.") 38 | 39 | elif param_type == 'int': 40 | if inp.isdigit(): 41 | inp = int(inp) 42 | input_flag = False # finished 43 | elif inp.startswith('-') and inp[1:].isdigit(): 44 | inp = int(inp) 45 | input_flag = False # finished 46 | else: 47 | print("Wrong input, only integer allowed.") 48 | 49 | elif param_type == 'float': 50 | if inp.replace('.', '', 1).isdigit(): 51 | inp = float(inp) 52 | input_flag = False # finished 53 | else: 54 | print("Wrong input, only float allowed.") 55 | 56 | elif param_type == 'string': 57 | if not inp.isdigit() and not inp.replace('.', '', 1).isdigit(): 58 | input_flag = False # finished 59 | else: 60 | print("Wrong input, only string allowed.") 61 | 62 | elif param_type == 'bool': 63 | if inp.lower() in ['y', 'true']: 64 | inp = True 65 | input_flag = False 66 | elif inp.lower() in ['n', 'false']: 67 | inp = False 68 | input_flag = False 69 | 70 | else: 71 | if inp.isdigit(): 72 | inp = int(inp) 73 | elif inp.replace('.', '', 1).isdigit(): 74 | inp = float(inp) 75 | # list of possible values 76 | if inp in param_type: 77 | input_flag = False # finished 78 | else: 79 | print("Wrong input, only {} allowed.".format(param_type)) 80 | else: 81 | # No parameter check defined -> finished 82 | input_flag = False 83 | 84 | tries -= 1 85 | 86 | print('') 87 | return inp 88 | 89 | 90 | if __name__ == '__main__': 91 | # Test 92 | parameter = 'test_parameter' 93 | # param_type = 'int' 94 | param_type = [0, 1.5, 'test'] 95 | test = ask_parameter(parameter, param_type) 
96 | print('{} = {}'.format(parameter, test)) 97 | -------------------------------------------------------------------------------- /training/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import glob 4 | import os 5 | from keras.models import load_model 6 | 7 | # Get the main script's directory 8 | import sys 9 | script_dir = os.path.dirname(os.path.realpath(__file__)) 10 | parent_dir = os.path.abspath(os.path.join(script_dir, "..")) 11 | sys.path.append(parent_dir) 12 | 13 | from tools.utils.load_and_save import load_npy 14 | from tools.config import paths, config 15 | from tools.config.mapper_selection import return_mapper_list, get_full_model_path 16 | from preprocessing.map_info_processing import get_maps_from_mapper 17 | 18 | 19 | def ai_encode_song(song): 20 | # Load pretrained model 21 | encoder_path = get_full_model_path(config.enc_version) 22 | encoder = load_model(encoder_path) 23 | # apply autoencoder to input 24 | in_song_l = encoder.predict(song, verbose=0) 25 | return in_song_l 26 | 27 | 28 | def filter_by_bps(min_limit=None, max_limit=None): 29 | if config.use_bpm_selection: 30 | print("Importing maps by BPM") 31 | # return songs in difficulty range 32 | diff_ar = load_npy(paths.diff_ar_file) 33 | name_ar = load_npy(paths.name_ar_file) 34 | 35 | if min_limit is not None: 36 | selection = diff_ar > min_limit 37 | name_ar = name_ar[selection] 38 | diff_ar = diff_ar[selection] 39 | if max_limit is not None: 40 | selection = diff_ar < max_limit 41 | name_ar = name_ar[selection] 42 | diff_ar = diff_ar[selection] 43 | else: 44 | print(f"Importing maps by mapper: {config.use_mapper_selection}") 45 | mapper_name = return_mapper_list(config.use_mapper_selection) 46 | name_ar = get_maps_from_mapper(mapper_name) 47 | diff_ar = np.ones_like(name_ar, dtype='float')*config.min_bps_limit 48 | 49 | return list(name_ar), list(diff_ar) 50 | 51 | 52 | def 
test_gpu_tf(): 53 | if tf.test.gpu_device_name(): 54 | print('Default GPU Device: {}'.format(tf.test.gpu_device_name())) 55 | return True 56 | else: 57 | print('Warning: No GPU found.\n Continue with CPU') 58 | # input('Continue with CPU?') 59 | return True 60 | # tf.test.is_built_with_cuda() 61 | # print(tf.config.list_physical_devices()) 62 | return False 63 | 64 | 65 | def load_keras_model(save_model_name, lr=None): 66 | model = None 67 | # print("Load keras model from disk") 68 | if save_model_name == "old": 69 | keras_models = glob.glob(paths.model_path + "*.h5") 70 | latest_file = max(keras_models, key=os.path.getctime) 71 | else: 72 | if not save_model_name.startswith(paths.model_path): 73 | latest_file = paths.model_path + save_model_name 74 | else: 75 | latest_file = save_model_name 76 | if not latest_file.endswith('.h5'): 77 | latest_file += '.h5' 78 | 79 | if os.path.isfile(latest_file): 80 | model = load_model(latest_file) 81 | latest_file = os.path.basename(latest_file) 82 | if config.verbose_level > 4: 83 | print("Keras model loaded: " + latest_file) 84 | else: 85 | print(f"Could not find model on disk: {latest_file}") 86 | print("Creating new model...") 87 | return None, save_model_name 88 | 89 | # # print(K.get_value(model.optimizer.lr)) 90 | # if lr is not None: 91 | # K.set_value(model.optimizer.lr, lr) 92 | # print("Set learning rate to: " + str(K.get_value(model.optimizer.lr))) 93 | return model, latest_file 94 | 95 | 96 | def categorical_to_class(cat_ar): 97 | cat_num = np.argmax(cat_ar, axis=-1) 98 | # cat_num = np.asarray(cat_num) 99 | return cat_num 100 | 101 | 102 | def calc_class_weight(np_ar): 103 | classes, counts = np.unique(np_ar, return_counts=True) 104 | 105 | counts = counts.min() / counts 106 | # counts -= 0.1 107 | class_weight = {} 108 | for i, cls in enumerate(classes): 109 | class_weight[cls] = counts[i] 110 | 111 | print(f"Weight matrix: {class_weight}") 112 | 113 | return class_weight 114 | 
def read_json_content_file(file_path: str, filename="") -> dict:
    """Read and parse a JSON beatmap file (the .dat files are JSON).

    If *filename* is given it is joined onto *file_path*.  Terminates the
    program when the file cannot be opened or parsed.
    (Annotation fixed: json.load of an info/map file yields a dict,
    not list[str].)
    """
    if filename != "":
        file_path = os.path.join(file_path, filename)
    try:
        with open(file_path) as f:
            dat_content = json.load(f)
    except Exception as e:
        print(f"Could not read file: {file_path}. Check file manually. Exit")
        print(f"Error: {e.args}")
        exit()
    return dat_content


def get_difficulty_file_names(info_file_path: str) -> dict:
    """Map difficulty names to beatmap file names for the 'Standard' set.

    Walks the difficulty beatmap sets from the back until the 'Standard'
    characteristic is found; exits when no such set exists.
    """
    # import dat file
    dat_content = read_json_content_file(info_file_path)

    beatmap_set_dict = dat_content['_difficultyBeatmapSets']
    i = -1
    beatmap_dict = beatmap_set_dict[i]
    while not beatmap_dict['_beatmapCharacteristicName'] == 'Standard':
        i -= 1
        if abs(i) > len(beatmap_set_dict):
            print(f"Error: Could not find Standard beatmap key in {info_file_path}")
            exit()
        beatmap_dict = beatmap_set_dict[i]

    diff_file_names = {}
    for diff_dict in beatmap_dict['_difficultyBeatmaps']:
        diff_file_names[diff_dict['_difficulty']] = diff_dict['_beatmapFilename']
    return diff_file_names


def check_info_name(bs_song_path):
    """Normalize every info file name below *bs_song_path* to 'info.dat'.

    Variants like 'Info.dat' are renamed; leftover 'BPMInfo.dat'-style
    files are deleted.
    """
    for root, dirs, files in os.walk(bs_song_path):
        for file in files:
            if file.lower().endswith("info.dat") and file != "info.dat":
                src_path = os.path.join(root, file)
                if file.lower().startswith("bpm"):
                    os.remove(src_path)
                else:
                    os.rename(src_path, os.path.join(root, "info.dat"))


def check_beatmap_name(bs_song_path):
    """Rename beatmap files to '<difficulty>.dat' as listed in info.dat."""
    for root, dirs, files in os.walk(bs_song_path):
        # folders with at most an info file and the song contain no beatmaps
        if len(files) <= 2:
            continue
        diff_file_names = get_difficulty_file_names(f"{root}/info.dat")
        for key, file_name in diff_file_names.items():
            src_path = os.path.join(root, file_name)
            if not os.path.isfile(src_path):
                # already renamed (or genuinely missing): skip
                continue
            os.rename(src_path, os.path.join(root, f"{key}.dat"))


def check_info_content(bs_song_path):
    """Rewrite every info.dat pretty printed to allow later line search."""
    expected_name = "info.dat"
    for folders in os.listdir(bs_song_path):
        info_file = os.path.join(bs_song_path, folders, expected_name)
        if os.path.isfile(info_file):
            # pretty print / overwrite info data to allow line search
            dat_content = read_json_content_file(info_file)
            with open(info_file, "w") as f:
                json.dump(dat_content, f, indent=9)


def clean_songs():
    """Run all cleanup passes over the configured input song folder."""
    bs_song_path = paths.bs_input_path
    print("Warning: This script is not fully tested and might break some song files.\n"
          "Do not use on your original beat saber folder!")
    print(f"Cleanup folder: {bs_song_path}")
    input("Continue with Enter")

    check_info_name(bs_song_path)

    check_beatmap_name(bs_song_path)

    check_info_content(bs_song_path)


if __name__ == "__main__":
    clean_songs()
bs_folder_name = "Beat Saber_Data/CustomLevels"
bs_folder_name2 = "SharedMaps/CustomLevels"


def check_int_input(inp, start=1, end=10):
    """Return True if *inp* is an int within [start, end]."""
    return isinstance(inp, int) and start <= inp <= end


def check_float_input(inp, start=0.0, end=10.0):
    """Return True if *inp* is a float within [start, end]."""
    return isinstance(inp, float) and start <= inp <= end


def check_str_input(inp, min_len=5):
    """Return True if *inp* is a str of at least *min_len* characters."""
    return isinstance(inp, str) and len(inp) >= min_len


def get_summary(diff1=config.difficulty_1, diff2=config.difficulty_2, diff3=config.difficulty_3,
                diff4=config.difficulty_4, diff5=config.difficulty_5) -> str:
    """Build a human readable status summary for the UI.

    Checks the song input folder, the Beat Saber export folder and the
    five difficulty settings; returns the collected Info/Error lines
    joined by newlines.
    """
    log = []
    diffs = [diff1, diff2, diff3, diff4, diff5]

    # UI widgets may hand in objects with a .value attribute instead of numbers
    if not isinstance(diffs[0], (float, int)):
        try:
            diffs = [d.value for d in diffs]
        except Exception:
            print("Error: Could not convert difficulty values to numbers.")

    # log number of songs found
    song_list = []
    try:
        files = os.listdir(paths.songs_pred)
        for file_name in files:
            ending = file_name.split('.')[-1]
            if ending in ['mp3', 'mp4', 'm4a', 'wav', 'aac', 'flv', 'wma', 'ogg', 'egg']:
                song_list.append(file_name)
        if len(song_list) > 0:
            log.append(f"Info: Found {len(song_list)} song(s).")
        else:
            log.append("Error: Found 0 songs. Please go to first tab or manually copy them to the input folder.")
    except OSError:
        # data folder missing: nothing else to report
        log.append("Error: Data folder not found.")
        return "\n".join(log)

    # check export functionality
    filename = paths.bs_song_path
    if os.path.isdir(filename) and (bs_folder_name in filename or bs_folder_name2 in filename):
        log.append("Info: Beat Saber folder found. Maps will be exported by default")
    else:
        log.append("Info: Beat Saber folder not found. Link it in the first tab to automatically export maps.")

    # check difficulty rating (one loop instead of five copy-pasted blocks)
    diff_count_values = []
    for number, diff in enumerate(diffs, start=1):
        if not isinstance(diff, (float, int)):
            log.append(f"Error: Difficulty {number} is not set. If not required, set it to 0")
        elif diff > 0:
            diff_count_values.append(diff)
    diff_count = len(diff_count_values)
    diff_count_values = [str(diff_float) for diff_float in diff_count_values]
    log.append(f"Info: Generating {diff_count} difficulties for each song: [{', '.join(diff_count_values)}]. "
               f"Total runs: {diff_count * len(song_list)}")

    return "\n".join(log)
def load_beat_data(name_ar: list, return_notes=False):
    """Load beat data for the maps in *name_ar*.

    Returns (notes, times) when return_notes is True, otherwise
    (note classes, times) with the notes clustered into classes.
    """
    print("Loading maps input data")
    map_dict_notes, _, _ = load_raw_beat_data(name_ar)
    notes_ar, time_ar = sort_beats_by_time(map_dict_notes)
    if return_notes:
        return notes_ar, time_ar
    beat_class = cluster_notes_in_classes(notes_ar)
    return beat_class, time_ar


def load_ml_data(train=True):
    """Assemble (song, timing) inputs and one-hot beat outputs for training."""
    # get name array
    name_ar, _ = filter_by_bps(min_bps_limit, max_bps_limit)

    # Reduce amount of songs
    name_ar = reduce_number_of_songs(name_ar, hard_limit=config.mapper_song_limit)

    # load beats (output)
    beat_ar, time_ar = load_beat_data(name_ar)

    # load song (input)
    song_ar, rm_index = run_music_preprocessing(name_ar, time_ar, save_file=False,
                                                song_combined=False)

    # filter invalid indices: drop every song that reported invalid frames
    idx = 0
    for rm_idx in rm_index:
        if len(rm_idx) > 0:
            # remove invalid songs
            name_ar.pop(idx)
            beat_ar.pop(idx)
            song_ar.pop(idx)
            time_ar.pop(idx)
        else:
            idx += 1

    # calculate time between beats
    timing_ar = calc_time_between_beats(time_ar)

    # concatenate all songs in one call (the previous per-song
    # vstack/hstack loop re-copied the growing arrays, i.e. O(n^2))
    song_input = np.vstack(song_ar)
    time_input = np.hstack([np.asarray(t, dtype='float16') for t in timing_ar])
    ml_output = np.hstack([np.asarray(b) for b in beat_ar])

    # onehot encode output
    ml_output = ml_output.reshape(-1, 1)
    ml_output = onehot_encode(ml_output)
    ml_output = ml_output.toarray()

    return [song_input, time_input], ml_output


def onehot_encode(ml_output):
    """Fit a OneHotEncoder on *ml_output*, persist it and return the encoding."""
    encoder = OneHotEncoder(dtype=int)
    encoder.fit(ml_output)
    ml_output = encoder.transform(ml_output)

    # save onehot encoder for reuse at map-generation time
    save_path = paths.beats_classify_encoder_file
    with open(save_path, "wb") as enc_file:
        pickle.dump(encoder, enc_file)
    # return ml data
    return ml_output


def calc_time_between_beats(time_ar):
    """Convert absolute beat times to inter-beat intervals per song.

    The first interval of every song is a default of 1 (second).
    """
    # default time for start
    dft_time = 1
    timing_input = []

    for song in time_ar:
        temp = np.concatenate(([dft_time], np.diff(song)), axis=0)
        timing_input.append(list(temp))
    return timing_input


if __name__ == '__main__':
    load_ml_data()
def _metadata_from_filename(file_path: str) -> Dict[str, str]:
    """Attempt to derive song metadata from *file_path* using configured conventions."""

    if not getattr(config, "enable_auto_metadata", False):
        return {}

    template = getattr(config, "metadata_naming_convention", "")
    if not template or "{artist}" not in template or "{title}" not in template:
        return {}

    stem = Path(file_path).stem

    # turn the literal template into a regex with named capture groups
    regex = re.escape(template)
    regex = regex.replace(r"\{artist\}", r"(?P<artist>.+?)")
    regex = regex.replace(r"\{title\}", r"(?P<title>.+?)")

    matched = re.fullmatch(regex, stem)
    if matched is None:
        return {}

    captured: Dict[str, Optional[str]] = {}
    for key, value in matched.groupdict().items():
        captured[key] = value.strip() if value is not None else value
    return _sanitize_metadata(captured)


def extract_metadata(file_path: str) -> Dict[str, str]:
    """Extract metadata from *file_path* if possible.

    Returns an empty dictionary if the metadata cannot be read and the filename does not
    match the configured metadata naming convention.
    """

    tags: Dict[str, Optional[str]] = {}

    if MutagenFile is not None:
        try:
            audio = MutagenFile(file_path, easy=True)
        except Exception:
            audio = None

        if audio and getattr(audio, "tags", None):
            for key in _METADATA_KEYS:
                entry = audio.tags.get(key)
                if entry:
                    tags[key] = entry

    cleaned = _sanitize_metadata(tags)
    if cleaned:
        return cleaned

    # fall back to parsing the filename via the configured convention
    return _metadata_from_filename(file_path)


def save_metadata(name: str, metadata: Dict[str, Optional[str]]) -> None:
    """Persist *metadata* for the song *name* (without extension)."""

    cleaned = _sanitize_metadata(metadata)
    target = os.path.join(paths.songs_pred, f"{name}.json")

    if not cleaned:
        # nothing to store: remove a stale metadata file if one exists
        if os.path.exists(target):
            os.remove(target)
        return

    os.makedirs(paths.songs_pred, exist_ok=True)
    with open(target, "w", encoding="utf-8") as file:
        json.dump(cleaned, file, ensure_ascii=False, indent=2)


def load_metadata(name: str) -> Dict[str, str]:
    """Load persisted metadata for the song *name* (without extension)."""

    source = os.path.join(paths.songs_pred, f"{name}.json")
    if not os.path.isfile(source):
        return {}

    try:
        with open(source, "r", encoding="utf-8") as file:
            payload = json.load(file)
    except (OSError, json.JSONDecodeError):
        return {}

    if not isinstance(payload, dict):
        return {}
    return _sanitize_metadata(payload)


def metadata_to_tags(metadata: Dict[str, Optional[str]] | None) -> Optional[Dict[str, str]]:
    """Prepare *metadata* for embedding into an audio file."""

    cleaned = _sanitize_metadata(metadata or {})
    return cleaned if cleaned else None
"""Standalone evaluation for the note-mapping model.

Loads the preprocessed mapping dataset, compresses the song input with the
pretrained autoencoder, restores the mapper model selected via
``config.mapper_version`` and tabulates predicted vs. real note classes on
the held-out test samples and a same-sized slice of the training data.
Requires a CUDA-compatible GPU (see ``test_gpu_tf``).
"""
import gc
from datetime import datetime
from tensorflow import keras
from keras.optimizers import Adam
from tabulate import tabulate

from helpers import *
from lighting_prediction.train_lighting import lstm_shift_events_half
from tensorflow_models import *
from preprocessing.bs_mapper_pre import load_ml_data, lstm_shift
from tools.config import config, paths
from tools.config.mapper_selection import get_full_model_path

# Check Cuda compatible GPU
if not test_gpu_tf():
    exit()

# Setup configuration
#####################
min_bps_limit = config.min_bps_limit
max_bps_limit = config.max_bps_limit
learning_rate = config.map_learning_rate
n_epochs = config.map_n_epochs
batch_size = config.map_batch_size
test_samples = config.map_test_samples
np.random.seed(3)  # fixed seed so evaluation runs are reproducible

# Data Preprocessing
####################
ml_input, ml_output = load_ml_data()
# ml_input, ml_output = lstm_shift(ml_input[0], ml_input[1], ml_output)
# [in_song, in_time_l, in_class_l] = ml_input
# in_song_l = ai_encode_song(in_song)
# compress the raw song input with the pretrained autoencoder, then window
# everything into sequences of length config.lstm_len
in_song_l = ai_encode_song(ml_input[0])
ml_input, ml_output = lstm_shift_events_half(in_song_l, ml_input[1], ml_output, config.lstm_len)
[in_song_l, in_time_l, in_class_l] = ml_input

# Sample into train/val/test
############################
last_test_samples = len(in_song_l) - test_samples
# use last samples as test data
in_song_test = in_song_l[last_test_samples:]
in_time_test = in_time_l[last_test_samples:]
in_class_test = in_class_l[last_test_samples:]
out_class_test = ml_output[last_test_samples:]

in_song_train = in_song_l[:last_test_samples]
in_time_train = in_time_l[:last_test_samples]
in_class_train = in_class_l[:last_test_samples]
out_class_train = ml_output[:last_test_samples]

# normal lstm lstm
ds_train = [in_song_train, in_time_train, in_class_train]
ds_test = [in_song_test, in_time_test, in_class_test]

# slice of the training data, same size as the test set, for comparison
ds_train_sample = [in_song_train[:test_samples], in_time_train[:test_samples], in_class_train[:test_samples]]

# per-sample input shapes and number of output classes for model creation
dim_in = [in_song_train[0].shape, in_time_train[0].shape, in_class_train[0].shape]
dim_out = out_class_train.shape[1]

# delete variables to free ram
# keras.backend.clear_session()
# del encoder
del in_class_l
del in_class_test
# del in_song
del in_song_l
del in_song_test
del in_song_train
del in_time_l
del in_time_test
del in_time_train
del ml_input
del ml_output
gc.collect()

# Create model
##############
save_model_name = get_full_model_path(config.mapper_version)
# load model
mapper_model, save_model_name = load_keras_model(save_model_name)
# create model
# NOTE(review): a freshly created model is compiled but never trained in
# this script -- presumably only a restored, pre-trained model gives a
# meaningful evaluation here; confirm.
if mapper_model is None:
    mapper_model = create_keras_model('lstm1', dim_in, dim_out)
    adam = Adam(learning_rate=learning_rate, weight_decay=learning_rate / n_epochs)
    # mapper_model.compile(loss='mean_squared_error', optimizer=adam, metrics=['accuracy'])
    mapper_model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])


# Evaluate model
################
command_len = 10  # predictions are printed in rows of 10 for readability
print("Validate model with test data...")
validation = mapper_model.evaluate(x=ds_test, y=out_class_test)
pred_result = mapper_model.predict(x=ds_test, verbose=0)

pred_class = categorical_to_class(pred_result)
real_class = categorical_to_class(out_class_test)

if test_samples % command_len == 0:
    pred_class = pred_class.reshape(-1, command_len)
    real_class = real_class.reshape(-1, command_len)

print(tabulate([['Pred', pred_class], ['Real', real_class]],
               headers=['Type', 'Result (test data)']))

print("Validate model with train data...")
validation = mapper_model.evaluate(x=ds_train_sample, y=out_class_train[:test_samples])

pred_result = mapper_model.predict(x=ds_train_sample, verbose=0)
pred_class = categorical_to_class(pred_result)
real_class = categorical_to_class(out_class_train[:test_samples])

if test_samples % command_len == 0:
    pred_class = pred_class.reshape(-1, command_len)
    real_class = real_class.reshape(-1, command_len)

print(tabulate([['Pred', pred_class], ['Real', real_class]],
               headers=['Type', 'Result (train data)']))
def get_silent_times(song_input, timings):
    """Return the timestamps whose energy falls below the silence threshold.

    ``song_input`` is a 2D spectrogram-like array (assumes frames along
    axis 1, as produced by ``find_beats`` -- TODO confirm) and ``timings``
    holds one timestamp per frame.  Frames at or below the computed
    threshold are reported as silent.
    """
    # per-frame loudness proxy: column maximum plus a small mean contribution
    energy = song_input.max(axis=0) + 0.2 * song_input.mean(axis=0)
    # base threshold from the configured quantile of the energy distribution
    cutoff = np.quantile(energy, config.silence_threshold)
    if config.check_silence_flag:
        # never drop below the configured absolute floor
        cutoff = np.max([cutoff, config.check_silence_value])
    cutoff += config.silence_thresh_hard
    # print(f"Silent threshold: {cutoff}")
    mask = energy <= cutoff
    return np.asarray(timings)[mask]
import os
import re
import numpy as np

# Running totals shared by the recursive countlines() calls.  Initialised at
# module level so the functions also work when imported (previously the
# names only existed when the file was run as a script -> NameError).
fct_count = 0
script_count = 0

# Compiled once: optional leading whitespace, 'def', then the function name.
# Raw string fixes the invalid '\s' escape of the old "^[\s]*\t?def (\w+)"
# pattern (SyntaxWarning on Python 3.12); '\s+' also accepts 'def  name'.
_DEF_PATTERN = re.compile(r"^\s*def\s+(\w+)")


def check_def(line):
    """Return the function name if *line* starts a ``def``, else ``None``."""
    m = _DEF_PATTERN.match(line)
    return m.group(1) if m else None


def countlines(start, lines=0, header=True, begin_start=None):
    """Recursively count lines and function definitions of all ``.py`` files.

    Prints one table row per file (added lines, running total, per-file and
    total function counts) and returns the accumulated line count.  The
    running totals live in the module-level globals ``fct_count`` and
    ``script_count``.
    """
    if header:
        print('{:>10} |{:>10} |{:>10} |{:>10} | {:<20}'.format('ADDED', 'TOTAL', 'CUR_FCT', 'FCT', 'FILE'))
        print('{:->11}|{:->11}|{:->11}|{:->11}|{:->20}'.format('', '', '', '', ''))

    global fct_count
    global script_count
    self_defined_functions = set()
    for entry in os.listdir(start):
        entry = os.path.join(start, entry)
        if os.path.isfile(entry) and entry.endswith('.py'):
            # explicit encoding: the locale default broke on UTF-8 sources
            with open(entry, 'r', encoding='utf-8', errors='replace') as f:
                cur_newlines = f.readlines()
            newlines = len(cur_newlines)
            lines += newlines

            # path displayed relative to the directory the walk started from
            if begin_start is not None:
                reldir_of_thing = '.' + entry.replace(begin_start, '')
            else:
                reldir_of_thing = '.' + entry.replace(start, '')

            cur_fct_count = 0
            for cur_line in cur_newlines:
                function_name = check_def(cur_line)
                if function_name:
                    self_defined_functions.add(function_name)
                    cur_fct_count += 1
            fct_count += cur_fct_count

            script_count += 1

            print('{:>10} |{:>10} |{:>10} |{:>10} | {:<20}'.format(
                newlines, lines, cur_fct_count, fct_count, reldir_of_thing))

    # recurse into subdirectories after the files of this level
    for entry in os.listdir(start):
        entry = os.path.join(start, entry)
        if os.path.isdir(entry):
            lines = countlines(entry, lines, header=False, begin_start=start)

    return lines


def count_function_calls(file_path, self_defined_functions):
    """Count how often each name in *self_defined_functions* occurs as a
    call-like token ``name(`` in *file_path*.

    Note: the ``def name(`` line itself also matches, so every defined
    function is counted at least once (kept for backward compatibility).
    """
    function_counts = {}
    with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            for match in re.findall(r'(\w+)\(', line):
                if match in self_defined_functions:
                    function_counts[match] = function_counts.get(match, 0) + 1
    return function_counts


def display_function_usage(project_path):
    """Collect all self-defined function names below *project_path* and
    print how often each one is called across the project."""
    print("\nFunction call counts for self-defined functions:")
    # first pass: gather every defined function name
    self_defined_functions = set()
    for root, dirs, files in os.walk(project_path):
        for file in files:
            if file.endswith('.py'):
                file_path = os.path.join(root, file)
                with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                    for line in f:
                        function_name = check_def(line)
                        if function_name:
                            self_defined_functions.add(function_name)

    # second pass: accumulate call counts per name
    function_counts = {}
    for root, dirs, files in os.walk(project_path):
        for file in files:
            if file.endswith('.py'):
                file_path = os.path.join(root, file)
                counts = count_function_calls(file_path, self_defined_functions)
                for function, count in counts.items():
                    function_counts[function] = function_counts.get(function, 0) + count

    if not function_counts:
        # avoid np.mean([]) -> nan + RuntimeWarning on projects without functions
        print("No self-defined functions found.")
        return

    sorted_counts = sorted(function_counts.items(), key=lambda x: x[1], reverse=True)

    print(f"Function calls mean: {np.mean(list(function_counts.values()))}")
    print("Top ten:")
    for function, count in sorted_counts[:10]:
        print(f"Function {function:<30} is called {count:>3} times in the project.")


if __name__ == '__main__':
    py_path = os.getcwd()
    countlines(py_path)
    print(f"Total of {script_count} python scripts.")

    display_function_usage(py_path)
"""Training script for the music autoencoder.

Loads all songs in the configured BPS range, trains a convolutional
autoencoder on their spectrogram slices and saves both the full
autoencoder and the encoder half under timestamped file names.
"""
import numpy as np
# import matplotlib.pyplot as plt
import time
# import tensorflow as tf

from datetime import datetime
from keras.models import Model
from keras.optimizers import Adam
from keras.layers import Input

# Make the project root importable when this file is run directly
import sys, os
script_dir = os.path.dirname(os.path.realpath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, ".."))
sys.path.append(parent_dir)

from bs_shift.bps_find_songs import bps_find_songs
from helpers import test_gpu_tf, filter_by_bps, load_keras_model
from plot_model import run_plot_autoenc
from tensorflow_models import create_keras_model
from preprocessing.music_processing import run_music_preprocessing
from tools.config import config, paths
from tools.fail_list.black_list import delete_fails
from tools.utils.numpy_shorts import reduce_number_of_songs

# # Check Cuda compatible GPU
# if not test_gpu_tf():
#     exit()

# Setup configuration
#####################
min_bps_limit = config.min_bps_limit
max_bps_limit = config.max_bps_limit
learning_rate = config.learning_rate
n_epochs = config.n_epochs
# epochs_per_input = config.epochs_per_input
# n_epochs = int(n_epochs/epochs_per_input)
batch_size = config.batch_size
test_samples = config.test_samples
np.random.seed(3)  # fixed seed for a reproducible shuffle/split

# Data Preprocessing
####################
# get name array
name_ar, _ = filter_by_bps(min_bps_limit, max_bps_limit)

# Reduce amount of songs
name_ar = reduce_number_of_songs(name_ar, hard_limit=config.autoenc_song_limit)

# Load the song input.  When preprocessing fails for a single song, the
# offending file is blacklisted/removed and the whole load is retried.
i = 0
while i == 0:
    try:
        i += 1
        song_ar, _ = run_music_preprocessing(name_ar, save_file=False,
                                             song_combined=True, channels_last=True)
    except Exception:
        # was a bare `except:`, which also swallowed KeyboardInterrupt and
        # made the retry loop impossible to abort manually
        print("Need to restart due to problem with one song.")
        delete_fails()
        time.sleep(0.1)
        bps_find_songs(info_flag=False)
        time.sleep(0.1)
        name_ar, _ = filter_by_bps(min_bps_limit, max_bps_limit)
        time.sleep(0.1)
        i -= 1

# sample into train/val/test: the first samples are held out as test data
ds_test = song_ar[:test_samples]
song_ar = song_ar[test_samples:]

# shuffle and split
np.random.shuffle(song_ar)
split = int(song_ar.shape[0] * 0.85)

# setup data loaders
ds_train = song_ar[:split]
ds_val = song_ar[split:]

# Model Building
################
# create timestamp for unique model file names
dateTimeObj = datetime.now()
timestamp = f"{dateTimeObj.month}_{dateTimeObj.day}__{dateTimeObj.hour}_{dateTimeObj.minute}"
save_model_name = f"tf_model_autoenc_{config.bottleneck_len}bneck_{timestamp}.h5"
save_enc_name = f"tf_model_enc_{config.bottleneck_len}bneck_{timestamp}.h5"
# save_model_name = "old"

# load model
auto_encoder, save_model_name = load_keras_model(save_model_name)
# create model
# NOTE(review): if a saved model were ever loaded above, `encoder` would
# stay undefined and the compile/save calls below would raise NameError --
# confirm load_keras_model can only succeed with the legacy "old" name.
if auto_encoder is None:
    encoder = create_keras_model('enc1', learning_rate)
    decoder = create_keras_model('dec1', learning_rate)
    auto_input = Input(shape=(24, 20, 1))
    encoded = encoder(auto_input)
    decoded = decoder(encoded)
    auto_encoder = Model(auto_input, decoded)

adam = Adam(learning_rate=learning_rate, weight_decay=learning_rate / n_epochs)
auto_encoder.compile(loss='mean_squared_error', optimizer=adam, metrics=['accuracy'])
encoder.compile(loss='mean_squared_error', optimizer=adam, metrics=['accuracy'])

# Model Training
################
# Training (input == target: reconstruction objective)
training = auto_encoder.fit(x=ds_train, y=ds_train, validation_data=(ds_val, ds_val),
                            epochs=n_epochs, batch_size=batch_size,
                            shuffle=True, verbose=1)

# Model Evaluation
##################
print("\nEvaluating test data...")
eval_result = auto_encoder.evaluate(ds_test, ds_test)  # renamed: `eval` shadowed the builtin
# print(f"Test loss: {eval_result[0]:.4f}, test accuracy: {eval_result[1]:.4f}")
try:
    run_plot_autoenc(encoder, auto_encoder, ds_test, save=True)
except Exception as e:
    # plotting is best-effort; never lose a trained model over it
    print(f"Error: {type(e).__name__}")
    print(f"Error message: {e}")
    print("Error in music autoencoder plotting. Continue with saving.")

# Save Model
############
print(f"Saving model: {save_model_name} at: {paths.model_path}")
auto_encoder.save(paths.model_path + save_model_name)
encoder.save(paths.model_path + save_enc_name)

print("\nFinished Training")
def get_integer_note_color(sideL: bool):
    """Map the left/right hand flag to the note color integer (0=left, 1=right)."""
    if sideL:
        side_integer = 0
    else:
        side_integer = 1
    return side_integer


def get_side_time_gaps(sideL, notes, timings):
    """Collect the time gaps between consecutive notes of one color.

    Notes are stored flat as 4-value groups per beat, with the type/color
    at offset 2 (see ``get_position_of_note``).  Returns two parallel
    lists: the gap (in the time units of *timings*) and the ``[t0, t1]``
    index pair that produced it; the first entry of a color uses ``-1`` as
    its start index and measures the gap from the song start.  Beats
    containing bombs (type 3) and the first beat after a bomb are skipped.
    """
    side_integer = get_integer_note_color(sideL)
    real_time_gap = []
    real_time_gap_indices = []
    b_flag = False  # set while skipping past a bomb beat
    t0 = -1
    for i in range(0, len(timings)):
        # check for notes
        if len(notes[i]) > 0:
            # skip if bomb is found
            if 3 in notes[i][2::4]:
                b_flag = True
                continue
            elif b_flag:
                # t0 = -1
                b_flag = False
                continue
            # check note color (every 4th entry starting at offset 2)
            if side_integer in notes[i][2::4]:
                if t0 < 0:
                    # first note of this color: gap measured from song start
                    t0 = i
                    real_time_gap.append(timings[t0])
                    real_time_gap_indices.append([-1, i])
                    continue
                # found it!
                t1 = i
                real_time_gap.append(timings[t1] - timings[t0])
                real_time_gap_indices.append([t0, t1])
                t0 = i
    # if real_time_gap_indices[-1] == [286, 290]:
    #     print("")
    return real_time_gap, real_time_gap_indices


def get_position_of_note(notes, ix, side_integer):
    """Return (x, y, direction) of the note with color *side_integer* at beat *ix*.

    Each note occupies 4 consecutive values: x, y, type, direction.
    """
    for n in range(int(len(notes[ix]) / 4)):
        if side_integer == notes[ix][2 + n*4]:
            # found it!
            x = notes[ix][0 + n*4]
            y = notes[ix][1 + n*4]
            d = notes[ix][3 + n*4]
            return x, y, d

    # NOTE(review): hard process exit in a library helper -- callers rely
    # on get_side_time_gaps only yielding indices that contain this color
    print("Error: Could not find note to attach arc.")
    exit()


def get_side_sliders(sideL, notes, timings, tg, tg_index):
    """Build arc-slider entries for one color from get_side_time_gaps output.

    Each slider is a 12-value list:
    [t0, color, x0, y0, d0, mult, t1, x1, y1, d1, mult, anchor_mode].
    Sliders are only created for gaps inside ``config.slider_time_gap`` and
    may be dropped by the movement-minimum and probability filters.
    """
    side_integer = get_integer_note_color(sideL)
    sliders = []
    for idx, time_gap in enumerate(tg):
        if idx == 0:
            if config.slider_turbo_start:
                # optional lead-in slider starting at 30% of the first note time
                i0 = tg_index[idx][0]
                i1 = tg_index[idx][1]
                if i0 < 0:
                    x1, y1, d1 = get_position_of_note(notes, i1, side_integer)
                    # d0 = 8
                    d0 = d1
                    anchor_mode = 1
                    # direction indices appear to follow the Beat Saber
                    # cut-direction convention -- TODO confirm mapping
                    if side_integer == 0:
                        if d1 in [4, 0, 5]:  # left side up
                            anchor_mode = 2
                    else:
                        if d1 in [6, 1, 7]:  # right side down
                            anchor_mode = 2
                    sliders.append([0.3*timings[i1], side_integer, x1, y1, d0, config.slider_radius_multiplier,
                                    timings[i1], x1, y1, d1, config.slider_radius_multiplier, anchor_mode])
            continue

        if config.slider_time_gap[0] < time_gap < config.slider_time_gap[1]:
            # get corresponding note position
            i0 = tg_index[idx][0]
            i1 = tg_index[idx][1]
            x0, y0, d0 = get_position_of_note(notes, i0, side_integer)
            x1, y1, d1 = get_position_of_note(notes, i1, side_integer)

            if config.slider_movement_minimum > 0:
                # delete some random sliders based on movement distance
                nl_last = [x0, y0, 0, d0]
                nl_new = [x1, y1, 0, d1]
                speed = calc_note_speed(nl_last, nl_new, 1,
                                        cdf_lr=1.5)
                if speed < config.slider_movement_minimum:
                    continue
            if 0 <= config.slider_probability < 1:
                # randomly thin out sliders to the configured probability
                if random() > config.slider_probability:
                    continue

            sliders.append([timings[i0], side_integer, x0, y0, d0, config.slider_radius_multiplier,
                            timings[i1], x1, y1, d1, config.slider_radius_multiplier, 0])
    return sliders
get_side_time_gaps(True, notes, timings) 114 | slidersL = get_side_sliders(True, notes, timings, tg, tg_index) 115 | tg, tg_index = get_side_time_gaps(False, notes, timings) 116 | slidersR = get_side_sliders(False, notes, timings, tg, tg_index) 117 | sliders_combined.extend(slidersL) 118 | sliders_combined.extend(slidersR) 119 | if config.verbose_level > 3: 120 | print(f"Generated {len(sliders_combined)} arc sliders.") 121 | return sliders_combined 122 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | 2 | <br/> 3 | <div align="center"> 4 | <a href="https://github.com/fred-brenner/InfernoSaber---BeatSaber-Automapper/edit/main_app"> 5 | <img src="https://github.com/fred-brenner/InfernoSaber---BeatSaber-Automapper/blob/main_app/app_helper/cover.jpg" alt="Logo" width="80" height="80"> 6 | </a> 7 | <h3 align="center">InfernoSaber</h3> 8 | <p align="center"> 9 | Flexible Automapper for Beatsaber made for any difficulty 10 | <br/> 11 | <br/> 12 | <a href="https://www.youtube.com/watch?v=GpdHE6puDng"><strong>Installation walkthrough »</strong></a> 13 | <br/> 14 | <br/> 15 | <a href="https://www.youtube.com/watch?v=wJSOBuKs42Q">View Demo .</a> 16 | <a href="https://github.com/fred-brenner/InfernoSaber---BeatSaber-Automapper/issues">Report Bug .</a> 17 | <a href="https://github.com/fred-brenner/InfernoSaber---BeatSaber-Automapper/discussions">Request Feature</a> 18 | </p> 19 | </div> 20 | 21 | ## About The Project 22 | 23 | ![Screenshot played by RamenBot](https://i.imgur.com/ECXMxY5.jpeg) 24 | 25 | Automapper with fully adjustable difficulty (inpsired by star difficulty) ranging from easy maps (< 1) to Expert+++ maps (10+) 26 | 27 | Update Jan 2025: App is finally available via Pinokio: https://program.pinokio.computer/#/ 28 | Just got to "Discover" and then "Download from URL": https://github.com/fred-brenner/InfernoSaber-App 29 | 30 | This 
installs all dependencies in the encapsulated environment of Pinokio and loads the application from (this) main repository:
This can only be done locally by cloning the repo and using a GPU (a single good consumer GPU is enough). A guide to train the 4 models is included in the repo: 'How_to_Train_InfernoSaber.docx'
import pickle
import numpy as np

from tools.config import paths, config


def load_raw_beat_data(name_ar):
    """Load the pickled notes/events/obstacles arrays for every song in *name_ar*.

    All times inside the files are already divided by the bpm from the
    info file, i.e. given in real seconds.
    """
    map_dict_events = []
    map_dict_notes = []
    map_dict_obstacles = []
    for song_name in name_ar:
        # the .dat files are numpy pickles, hence allow_pickle=True
        map_dict_events.append(np.load(paths.dict_all_path + song_name + "_events.dat", allow_pickle=True))
        map_dict_notes.append(np.load(paths.dict_all_path + song_name + "_notes.dat", allow_pickle=True))
        map_dict_obstacles.append(np.load(paths.dict_all_path + song_name + "_obstacles.dat", allow_pickle=True))

    return map_dict_notes, map_dict_events, map_dict_obstacles


def sort_beats_by_time(song_ar: list):
    """Group each song's raw note columns into per-beat note sets.

    Invalid notes are dropped by the sanity check.  Notes closer together
    than ``config.min_time_diff`` are merged into the previous beat
    (stacked vertically) when their position changed.  Returns the grouped
    notes and a parallel timeline (in seconds) per song.
    """
    new_song_ar = []
    new_time_ar = []
    min_time = config.min_time_diff

    for map_ar in song_ar:
        # setup temp variables
        last_time = -1
        new_map_ar = []
        timeline = []

        for idx in range(map_ar.shape[1]):
            cur_map = map_ar[:, idx]
            # drop notes with out-of-range values
            if not run_notes_sanity_check(cur_map):
                continue
            cur_time = cur_map[0]
            cur_map = cur_map[1:].astype('int16')  # strip the time row
            if cur_time > last_time + min_time:
                # append new beat
                new_map_ar.append(cur_map)
                last_time = cur_time
                timeline.append(cur_time)
            elif pos_changed(new_map_ar[-1], cur_map):
                # time unchanged but position changed: add notes to last beat
                new_map_ar[-1] = np.vstack((new_map_ar[-1], cur_map))

        new_song_ar.append(new_map_ar)
        new_time_ar.append(timeline)

    return new_song_ar, new_time_ar


def run_notes_sanity_check(beat_matrix, verbose=True):
    """Validate one raw note column: [time, lineIndex, lineLayer, type, cutDir].

    Returns False (optionally with a log line) for any value outside the
    vanilla map format boundaries, so modded values are excluded.
    """
    if beat_matrix[0] < 0:
        if verbose:
            print("beat_sanity_check: Found time below zero, skipping.")
        return False
    if beat_matrix[1] not in [0, 1, 2, 3]:
        if verbose:
            print("beat_sanity_check: Found wrong _lineIndex, skipping.")
        return False
    if beat_matrix[2] not in [0, 1, 2]:
        if verbose:
            print("beat_sanity_check: Found wrong _lineLayer, skipping.")
        return False
    # _type 2 is deliberately excluded here -- presumably unused in the
    # map format (0/1 = colors, 3 = bomb); confirm against the BS spec
    if beat_matrix[3] not in [0, 1, 3]:
        if verbose:
            print("beat_sanity_check: Found wrong _type, skipping.")
        return False
    if beat_matrix[4] not in [0, 1, 2, 3, 4, 5, 6, 7, 8]:
        if verbose:
            print("beat_sanity_check: Found wrong _cutDir, skipping.")
        return False
    # check completed
    return True


def pos_changed(last_map, cur_map):
    """Return True if lineIndex or lineLayer differs from the newest note in *last_map*."""
    if len(last_map.shape) > 1:
        # stacked beat: compare against its most recently added note
        last_map = last_map[-1]
    # after the time row is stripped, index 0 = lineIndex, 1 = lineLayer
    return last_map[0] != cur_map[0] or last_map[1] != cur_map[1]


def cluster_notes_in_classes(notes_ar):
    """Map every beat pattern to a stable integer class ID.

    Patterns are encoded to strings and numbered in order of first
    appearance.  The ordered pattern list is pickled to
    ``paths.notes_classify_dict_file`` for decoding at inference time.
    Returns the per-song lists of class IDs.
    """
    class_key = []    # insertion-ordered unique encoded patterns (pickled)
    class_index = {}  # pattern -> class ID; O(1) lookup instead of the old
                      # O(n) class_key.index() per beat (quadratic overall)
    new_song_ar = []
    for song in notes_ar:
        new_class_ar = []
        for beat in song:
            beat = encode_beat_ar(beat)
            key_idx = class_index.get(beat)
            if key_idx is None:
                # unknown pattern: assign the next free class ID
                key_idx = len(class_key)
                class_index[beat] = key_idx
                class_key.append(beat)
            new_class_ar.append(key_idx)
        new_song_ar.append(new_class_ar)

    # save classify dictionary
    with open(paths.notes_classify_dict_file, "wb") as dict_file:
        pickle.dump(class_key, dict_file)

    return new_song_ar


def encode_beat_ar(beat):
    """Encode a beat (1D note or stacked 2D notes) into a string key."""

    # Remove all double notes
    if config.remove_double_notes and len(beat.shape) > 1:
        # keep only the first note of each type within the beat
        new_beat = []
        notes_done = []
        for idx in range(len(beat)):
            cur_note = beat[idx, 2]  # index 2 = note type
            if cur_note not in notes_done:
                new_beat.append(beat[idx])
                notes_done.append(cur_note)
        beat = np.asarray(new_beat)

    # flatten and concatenate all values into one string key
    return "".join(f"{el}" for el in beat.reshape(-1))
in_class_l] = ml_input 34 | # in_song_l = ai_encode_song(in_song) 35 | 36 | in_song_l = ai_encode_song(ml_input[0]) 37 | ml_input, ml_output = lstm_shift_events_half(in_song_l, ml_input[1], ml_output, config.lstm_len) 38 | [in_song_l, in_time_l, in_class_l] = ml_input 39 | 40 | # Sample into train/val/test 41 | ############################ 42 | last_test_samples = len(in_song_l) - test_samples 43 | # use last samples as test data 44 | in_song_test = in_song_l[last_test_samples:] 45 | in_time_test = in_time_l[last_test_samples:] 46 | in_class_test = in_class_l[last_test_samples:] 47 | out_class_test = ml_output[last_test_samples:] 48 | 49 | in_song_train = in_song_l[:last_test_samples] 50 | in_time_train = in_time_l[:last_test_samples] 51 | in_class_train = in_class_l[:last_test_samples] 52 | out_class_train = ml_output[:last_test_samples] 53 | 54 | # normal lstm lstm 55 | ds_train = [in_song_train, in_time_train, in_class_train] 56 | ds_test = [in_song_test, in_time_test, in_class_test] 57 | 58 | ds_train_sample = [in_song_train[:test_samples], in_time_train[:test_samples], in_class_train[:test_samples]] 59 | 60 | # dim_in = [in_song_train[0].shape, in_time_train[0].shape, in_class_train[0].shape] 61 | # dim_out = out_class_train.shape[1] 62 | dim_in = [x.shape for x in ds_train] 63 | dim_out = out_class_train.shape 64 | 65 | # delete variables to free ram 66 | # keras.backend.clear_session() 67 | # del encoder 68 | del in_class_l 69 | del in_class_test 70 | # del in_song 71 | del in_song_l 72 | del in_song_test 73 | del in_song_train 74 | del in_time_l 75 | del in_time_test 76 | del in_time_train 77 | del ml_input 78 | del ml_output 79 | gc.collect() 80 | 81 | 82 | # Create model 83 | ############## 84 | # create timestamp 85 | dateTimeObj = datetime.now() 86 | timestamp = f"{dateTimeObj.month}_{dateTimeObj.day}__{dateTimeObj.hour}_{dateTimeObj.minute}" 87 | save_model_name = f"tf_model_mapper_{min_bps_limit}-{max_bps_limit}_{timestamp}.h5" 88 | # load model 89 | 
mapper_model, save_model_name = load_keras_model(save_model_name) 90 | # create model 91 | if mapper_model is None: 92 | mapper_model = create_keras_model('lstm_half', dim_in, dim_out) 93 | adam = Adam(learning_rate=learning_rate, weight_decay=2*learning_rate / n_epochs) 94 | # mapper_model.compile(loss='mean_squared_error', optimizer=adam, metrics=['accuracy']) 95 | mapper_model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy']) 96 | 97 | mapper_model.summary() 98 | 99 | # Model training 100 | ################ 101 | training = mapper_model.fit(x=ds_train, y=out_class_train, 102 | epochs=n_epochs, batch_size=batch_size, 103 | shuffle=True, verbose=1) 104 | 105 | # Evaluate model 106 | ################ 107 | try: 108 | command_len = 10 109 | print("Validate model with test data...") 110 | validation = mapper_model.evaluate(x=ds_test, y=out_class_test) 111 | pred_result = mapper_model.predict(x=ds_test, verbose=0) 112 | 113 | pred_class = categorical_to_class(pred_result) 114 | real_class = categorical_to_class(out_class_test) 115 | 116 | pred_class = pred_class.flatten().tolist() 117 | real_class = real_class.flatten().tolist() 118 | 119 | print(tabulate([['Pred', pred_class], ['Real', real_class]], headers=['Type', 'Result (test data)'])) 120 | 121 | print("Validate model with train data...") 122 | validation = mapper_model.evaluate(x=ds_train_sample, y=out_class_train[:test_samples]) 123 | 124 | pred_result = mapper_model.predict(x=ds_train_sample, verbose=0) 125 | pred_class = categorical_to_class(pred_result) 126 | real_class = categorical_to_class(out_class_train[:test_samples]) 127 | 128 | pred_class = pred_class.flatten().tolist() 129 | real_class = real_class.flatten().tolist() 130 | 131 | print(tabulate([['Pred', pred_class], ['Real', real_class]], headers=['Type', 'Result (train data)'])) 132 | 133 | except Exception as e: 134 | print(f"Error: {type(e).__name__}") 135 | print(f"Error message: {e}") 136 | print("Error in 
import os
import shutil

from pydub import AudioSegment, effects

from tools.config import config, paths
from tools.utils.song_metadata import extract_metadata, metadata_to_tags, save_metadata


def shutil_copy_maps(song_name, index="1234_"):
    """Copy a finished map folder into the Beat Saber custom-song directory.

    Args:
        song_name: folder name of the generated map (without the index prefix).
        index: prefix used for the exported folder name.

    Returns:
        True on success, False when the Beat Saber song folder does not exist
        (automatic export is then skipped with a warning).
    """
    if not os.path.isdir(paths.bs_song_path):
        print("Warning: Beatsaber folder not found, automatic export disabled.")
        return False

    src = f'{paths.new_map_path}{index}{song_name}'
    dst = f'{paths.bs_song_path}{index}{song_name}'
    # dirs_exist_ok allows re-exporting the same song without failing
    shutil.copytree(src=src, dst=dst, dirs_exist_ok=True)
    return True


def check_music_files(files, dir_path):
    """Prepare input music files for map generation.

    Converts supported formats to ogg (saved with the Beat Saber '.egg'
    extension), renames files containing unsupported characters, optionally
    normalizes the volume, and persists each song's metadata.

    Args:
        files: iterable of file names located in `dir_path`.
        dir_path: directory containing the files (must end with a separator,
            since paths are built by plain string concatenation).

    Returns:
        List of resulting '.egg' file names ready for processing.
    """
    song_list = []
    metadata_for_song = {}
    for file_name in files:
        ending = file_name.split('.')[-1].lower()
        # if ending in ['mp3', 'mp4', 'm4a', 'wav', 'aac', 'flv', 'wma']:  # (TODO: test more music formats)
        if ending in ['mp3', 'mp4', 'm4a']:  # (TODO: test more music formats)
            # if ending in ['NOT_WORKING!']:
            # convert music to ogg format
            output_file = f"{file_name[:-4]}.egg"
            source_path = dir_path + file_name
            metadata_for_song[output_file] = extract_metadata(source_path)
            convert_music_file(source_path, dir_path + output_file,
                               metadata=metadata_for_song[output_file])
            os.remove(dir_path + file_name)
            song_list.append(output_file)
        elif ending in ['ogg']:
            # Fix: build the new name by slicing off the 4-char extension.
            # The previous str.replace('.ogg', '.egg') replaced the FIRST
            # occurrence (wrong for names like 'a.ogg.b.ogg') and missed
            # upper-case extensions ('.OGG'), which pass the lower-cased
            # `ending` check above.
            new_file = file_name[:-4] + '.egg'
            source = dir_path + file_name
            destination = dir_path + new_file
            metadata_for_song[new_file] = extract_metadata(source)
            shutil.move(source, destination)
            song_list.append(new_file)
        elif ending in ['egg']:
            song_list.append(file_name)
            metadata_for_song[file_name] = extract_metadata(dir_path + file_name)
        else:
            print(f"Warning: Can not read {file_name} as music file.")
            pass

    # Check the file name for unsupported characters
    for idx, song_name in enumerate(song_list):
        new_name = song_name.replace(' &', ',')
        new_name = new_name.replace('&', ',')
        new_name = new_name.replace('{', '(')
        new_name = new_name.replace('}', ')')
        # strip trailing spaces directly before the '.egg' extension
        while new_name[:-4].endswith(' '):
            new_name = f"{new_name[:-5]}.egg"
        if new_name != song_name:
            shutil.move(dir_path + song_name, dir_path + new_name)
            # keep the metadata entry in sync with the renamed file
            metadata_for_song[new_name] = metadata_for_song.pop(song_name, {})
            song_list[idx] = new_name

    # Normalize the volume for each song in advance
    if config.normalize_song_flag:
        print("Running volume check for input songs...")
        for song_name in song_list:
            audio = AudioSegment.from_file(dir_path + song_name, format="ogg")
            if config.increase_volume_flag:
                rms = audio.rms / 1e9
                # print(f"Audio rms: {rms:.2f} x1e9")
            else:
                # dummy high value so only clipping songs get normalized
                rms = 10
            if audio.max_dBFS > 0.0 or rms < config.audio_rms_goal:
                # normalize if volume is below max, else skip
                headroom = -1 * (0.42 + (config.audio_rms_goal - rms) * 16)
                normalized_song = effects.normalize(audio, headroom=headroom)
                error_flag = 0
                # iteratively lower the headroom until the RMS goal is met
                while normalized_song.rms / 1e9 < config.audio_rms_goal:
                    error_flag += 1
                    headroom -= 0.5
                    # headroom = -1 * (0.42 + (config.audio_rms_goal - rms) * 16)
                    normalized_song = effects.normalize(audio, headroom=headroom)
                    if error_flag > 10:
                        print(f"Warning: Maximum iterations for normalizing song exceeded: {song_name}. Continue")
                        break

                tags = metadata_to_tags(metadata_for_song.get(song_name))
                normalized_song.export(dir_path + song_name, format="ogg", tags=tags)
                print(f"Normalized volume of song: {song_name} with new RMS: {normalized_song.rms/1e9:.2f}")
            else:
                # ensure metadata is kept even if no normalization is needed
                metadata_for_song[song_name] = metadata_for_song.get(song_name, {})

    # persist metadata for every song (extract lazily if still missing)
    for song_name in song_list:
        metadata = metadata_for_song.get(song_name)
        if metadata is None:
            metadata = extract_metadata(dir_path + song_name)
        save_metadata(song_name[:-4], metadata)
    return song_list


def convert_music_file(file_name, output_file, metadata=None):
    """Convert a music file to ogg ('.egg') or mp3, carrying over metadata.

    Args:
        file_name: path to the source audio file; format is inferred from
            the extension ('.egg' is treated as ogg).
        output_file: destination path; must end in '.ogg', '.egg' or '.mp3'.
        metadata: optional metadata dict converted to export tags.
    """
    # Load the source file (pydub infers the decoder from `format`)
    m_format = file_name.split('.')[-1]
    if m_format == 'egg':
        m_format = 'ogg'
    audio = AudioSegment.from_file(file_name, format=m_format)

    output_file_format = output_file.split('.')[-1]

    # Export the audio as ogg file
    tags = metadata_to_tags(metadata)
    if output_file_format == 'ogg' or output_file_format == 'egg':
        audio.export(output_file, format="ogg", tags=tags)
    elif output_file_format == 'mp3':
        audio.export(output_file, format="mp3", tags=tags)
    else:
        # NOTE(review): hard exit kept from original behavior; callers do
        # not handle a raised exception here.
        print(f"Error in music converter: Format unknown: {m_format}")
        exit()
setup for inference/app execution:** 12 | - **GPU:** None 13 | - **RAM:** ≥ 8GB 14 | - **OS:** Windows, Linux 15 | 16 | --- 17 | 18 | ## Using WSL2 on Windows 19 | 20 | You will need Linux for training (and only for that). 21 | If you’re not using Linux natively or via dual boot, follow guides online to set up WSL2 and increase its memory allocation (reserving less for Windows). Tested with NVIDIA 30-series GPUs. Failed with NVIDIA 50-series GPUs due to driver issues. 22 | 23 | --- 24 | 25 | ## Installation 26 | 27 | 1. **Clone the repository**: 28 | [InfernoSaber GitHub](https://github.com/fred-brenner/InfernoSaber---BeatSaber-Automapper) 29 | Use the `main` branch for the latest stable version for training. Use the `main_app` branch for inference. 30 | 31 | 2. **Editor Recommendation**: 32 | Use **PyCharm** or **VSCode** for project editing. 33 | 34 | 3. **Install WSL2 and Python 3.10**: 35 | [Guide](https://learn.microsoft.com/en-us/windows/python/web-frameworks#install-windows-subsystem-for-linux) 36 | 37 | 4. **Update and upgrade packages**: 38 | ```bash 39 | sudo apt update && sudo apt upgrade 40 | 41 | 5. **Create your preferred Python env**: 42 | go to InfernoSaber folder with cd, ls: 43 | ```bash 44 | cd mnt/c/Users/YourUsername/Desktop/BS_Automapper/InfernoSaber---BeatSaber-Automapper 45 | ``` 46 | ```bash 47 | sudo apt install -y software-properties-common 48 | sudo add-apt-repository -y ppa:deadsnakes/ppa 49 | sudo apt install -y python3.10 python3.10-venv python3.10-dev 50 | python3.10 --version 51 | python3.10 -m venv ubuntu_venv 52 | ``` 53 | Always after startup: 54 | ```bash 55 | source ubuntu_venv/bin/activate 56 | ``` 57 | 58 | 6. **Install TensorFlow with CUDA (for NVIDIA)**: 59 | ```bash 60 | pip install tensorflow[and-cuda]==2.15 61 | ``` 62 | Tested with TensorFlow 2.15 63 | For app install only (no training), you don't need CUDA: 64 | ```bash 65 | pip install tensorflow==2.15 66 | ``` 67 | 68 | 7. 
**Install required dependencies**: 69 | ```bash 70 | sudo apt install libswresample-dev libsamplerate-dev libsndfile-dev txt2man doxygen 71 | sudo apt install python3-aubio aubio-tools ffmpeg libavcodec-extra 72 | sudo apt install libavcodec-dev libavformat-dev libavutil-dev libswresample-dev 73 | ``` 74 | 75 | 8. **In case of `aubio` issues (skip else)**: 76 | ```bash 77 | pip uninstall -yv aubio 78 | pip install --force-reinstall --no-cache-dir --verbose aubio 79 | ``` 80 | 81 | 9. **Install Python requirements**: 82 | ```bash 83 | pip install git+https://git.aubio.org/aubio/aubio/ 84 | ``` 85 | Aubio tends to make problems. Alternative is to install via Conda or try to use pip install aubio. 86 | 87 | Make sure all already installed versions are removed from the requirements.txt (tensorflow, keras, aubio) 88 | ```bash 89 | pip install -r requirements.txt 90 | ``` 91 | 92 | 10. **Configure your paths**: 93 | Edit `/tools/config/paths.py` and set desired folders. 94 | 95 | --- 96 | 97 | ## Configuration 98 | 99 | Edit `config.py` for training setup: 100 | 101 | ```python 102 | use_mapper_selection = "your_model_name" 103 | use_bpm_selection = True # Set to False for advanced sorting 104 | min_bps_limit = 1 # Songs with less will be excluded 105 | max_bps_limit = 500 # Optional upper bound 106 | training_songs_diff = "Expert" 107 | training_songs_diff2 = "Hard" 108 | allow_training_diff2 = True 109 | vram_limit = 8 # Adjust based on GPU 110 | autoenc_song_limit = 100 # Reduce if needed 111 | mapper_song_limit = 200 # Approx. 200 per 30GB RAM 112 | beat_song_limit = 200 # Same as above 113 | ``` 114 | 115 | Songs exceeding these limits will be randomly discarded. 116 | 117 | --- 118 | 119 | ## Folder Structure 120 | 121 | Before starting, review the folder architecture from the **Pinokio app** and replicate it in your working "Data" directory. Configure the `config.py` to match. 
When using the app from the `main_app` branch, a setup script runs automatically to create the necessary folder structure.
# import numpy as np
# import matplotlib.pyplot as plt
# import gc
# import pickle

# from datetime import datetime
# from sklearn.preprocessing import OneHotEncoder

# from lighting_prediction.tf_lighting import create_tf_model
from lighting_prediction.train_lighting import lstm_shift_events_half
from map_creation.class_helpers import get_class_size, update_out_class, add_favor_factor_next_class, \
    cast_y_class, decode_onehot_class

# from preprocessing.music_processing import run_music_preprocessing

# from tools.config import config, paths
# from tools.utils import numpy_shorts

# NOTE(review): star import — np, config, paths and load_keras_model are
# presumably re-exported by training.helpers; verify before refactoring.
from training.helpers import *


def decode_class_string(y_class_num):
    """Decode ';'-joined event strings into an (n, 2) numeric array.

    Assumes each entry of `y_class_num` is a one-element sequence holding a
    string like 'x;y' — TODO confirm against decode_onehot_class output.
    """
    y_class = np.zeros((len(y_class_num), 2))
    for idx in range(len(y_class_num)):
        # split the single 'a;b' string into its two components
        y_class[idx] = [y_class_num[idx][0].split(';')][0]
    return y_class


def generate(l_in_song, time_ar, save_model_name, lstm_len, encoder_file):
    """Run the trained LSTM to generate event/note classes for a song.

    The model is applied autoregressively: each step feeds the previous
    step's one-hot output back in as the class input.

    Args:
        l_in_song: per-timestep song encoding (trimmed to len(time_ar)).
        time_ar: timestamps of the beats to generate classes for.
        save_model_name: keras model file to load.
        lstm_len: LSTM window length used when shifting the inputs.
        encoder_file: one-hot encoder file; also selects event vs. note
            decoding of the result.

    Returns:
        Decoded class output; an (n, 2) float array for the lighting-event
        encoder, otherwise the raw decoded class values.
    """
    class_size = get_class_size(encoder_file)
    # gather input
    ##############
    # some timings may have been removed in sanity check
    l_in_song = l_in_song[:len(time_ar)]

    # time deltas between consecutive beats; leading 1 keeps the length
    time_diff = np.concatenate(([1], np.diff(time_ar)), axis=0)

    x_input, _ = lstm_shift_events_half(l_in_song, time_diff, None, lstm_len)
    [in_song_l, in_time_l, _] = x_input
    # x_input = x_input[:2]

    # setup ML model
    ################
    model, _ = load_keras_model(save_model_name)
    if model is None:
        # NOTE(review): only prints; execution continues and will fail at
        # the first model.predict call below.
        print(f"Error. Could not load model {save_model_name}")

    """Model light"""
    # y_class = model.predict(x_input)
    # # cast to 0 and 1
    # y_arg_max = np.argmax(y_class, axis=1)
    # y_class_map = np.zeros(y_class.shape, dtype=int)
    # for idx in range(len(y_arg_max)):
    #     y_class_map[idx][y_arg_max[idx]] = 1

    """Model full"""
    # # apply event model
    # ###################
    # y_class = None
    # y_class_map = []
    # y_class_last = None
    # class_size = get_class_size(paths.events_classify_encoder_file)
    # for idx in range(len(in_song_l)):
    #     if y_class is None:
    #         in_class_l = np.zeros((len(in_song_l), config.event_lstm_len, class_size))
    #
    #     in_class_l = update_out_class(in_class_l, y_class, idx)
    #
    #     # normal lstm lstm
    #     ds_train = [in_song_l[idx:idx + 1], in_time_l[idx:idx + 1], in_class_l[idx:idx + 1]]
    #     y_class = model.predict(x=ds_train)
    #
    #     # add factor to NEXT class
    #     y_class = add_favor_factor_next_class(y_class, y_class_last)
    #
    #     # find class winner
    #     y_class = cast_y_class(y_class)
    #
    #     y_class_last = y_class.copy()
    #     y_class_map.append(y_class)

    """Model half"""
    # apply note/event model
    ###################
    y_class = None
    rd_counter = 0
    rd_distribution = None
    y_class_map = np.zeros((in_time_l.shape[0], in_time_l.shape[1], class_size), dtype=int)
    for idx in range(len(in_song_l)):
        if y_class is None:
            # first step: class input is all zeros (no history yet)
            in_class_l = np.zeros((len(in_song_l), lstm_len, class_size))
        else:
            # autoregressive feedback: previous one-hot output becomes input
            in_class_l[idx] = y_class_map[idx - 1]

        # normal lstm lstm
        ds_train = [in_song_l[idx:idx + 1], in_time_l[idx:idx + 1], in_class_l[idx:idx + 1]]
        y_class = model.predict(x=ds_train, verbose=0)

        # optionally resample the raw prediction with a random emphasis map
        y_class, rd_distribution, rd_counter = apply_random_mapper(y_class, rd_distribution, rd_counter)

        # (TODO: add favor_bombs flag)
        # find class winner
        y_arg_max = np.argmax(y_class, axis=2)[0]
        for imax in range(len(y_arg_max)):
            y_class_map[idx, imax][y_arg_max[imax]] = 1

    # decode event class output
    y_class_map = y_class_map.reshape(-1, y_class_map.shape[2])
    y_class_num = decode_onehot_class(y_class_map, encoder_file)
    if encoder_file == paths.events_classify_encoder_file:
        # lighting events are stored as 'x;y' strings; split them apart
        y_class_num = decode_class_string(y_class_num)
        # print("Finished lighting generator")
        # events_out = np.concatenate((time_ar[config.event_lstm_len+1:].reshape(-1, 1), y_class_num), axis=1)
    else:
        # print("Finished mapping generator")
        pass
    return y_class_num


def apply_random_mapper(y_class, rd_distribution, rd_counter):
    """Randomly re-weight model output to diversify the generated map.

    Scales the prediction batch to [0, 1] and multiplies it with a random
    emphasis distribution that boosts a window of classes around a random
    center. The distribution is re-drawn every
    config.random_note_map_change steps; a factor of 0 disables the
    feature entirely.

    Returns the (possibly re-weighted) y_class plus the distribution and
    counter state to pass back in on the next call.
    """
    # Warning: Prediction is not stable enough
    # May lead to random resampling of ml output

    # initiate random map center
    c_window = 60
    c_val = config.random_note_map_factor
    if c_val == 0:
        # feature disabled: pass prediction through unchanged
        return y_class, rd_distribution, rd_counter

    # scale batch to [0, 1]
    y_class_sc = y_class / np.max(y_class)

    if rd_counter <= 0:
        # initialization
        rd_distribution = np.random.rand(y_class.shape[0], y_class.shape[1], y_class.shape[2]) * c_val
        center = int(np.random.rand(1)[0] * y_class.shape[2])
        # check center is in bounds
        if center < c_window:
            center = c_window
        elif center > y_class.shape[2] - c_window - 1:
            center = y_class.shape[2] - c_window - 1
        center_start = center - c_window
        center_end = center + c_window
        # shift emphasis towards center
        rd_distribution[:, :, center_start:center_end] += c_val
        # baseline of 1 so the original prediction is never zeroed out
        rd_distribution += 1
        rd_counter = config.random_note_map_change

    rd_counter -= 1

    y_class = y_class_sc * rd_distribution

    return y_class, rd_distribution, rd_counter


if __name__ == '__main__':
    # generate()
    pass
import gc
import numpy as np
from datetime import datetime
from keras.optimizers import Adam
from PIL import Image
import matplotlib.pyplot as plt

# Get the main script's directory
import sys, os
script_dir = os.path.dirname(os.path.realpath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, ".."))
# allow running this file directly by making the repo root importable
sys.path.append(parent_dir)

from beat_prediction.find_beats import find_beats, get_pitch_times, samplerate_beats
# from beat_prediction.beat_to_lstm import beat_to_lstm
from beat_prediction.beat_prop import get_beat_prop, delete_offbeats

from preprocessing.bs_mapper_pre import load_beat_data

from training.helpers import test_gpu_tf, filter_by_bps, calc_class_weight
from training.tensorflow_models import create_music_model

from tools.config import config, paths
from tools.utils import numpy_shorts
from tools.utils.numpy_shorts import reduce_number_of_songs


def main():
    """Train the TCN beat-generator model on the prepared song data.

    Pipeline: select songs by bps limits, compute song/pitch features and
    beat proposals, balance the dataset, reshape everything into LSTM/TCN
    windows, train a keras model and save it. Intermediate arrays are
    deleted along the way to keep RAM usage down — preserve the statement
    order when modifying this function.
    """
    # Setup configuration
    #####################
    # # Check Cuda compatible GPU
    # if not test_gpu_tf():
    #     exit()

    # model name setup
    ##################
    # create timestamp
    dateTimeObj = datetime.now()
    timestamp = f"{dateTimeObj.month}_{dateTimeObj.day}__{dateTimeObj.hour}_{dateTimeObj.minute}"
    save_model_name = f"tf_beat_gen_{config.min_bps_limit}_{config.max_bps_limit}_{timestamp}.h5"

    # gather input
    ##############
    print("Gather input data:", end=' ')

    # ram_limit = int(11 * config.ram_limit)  # 100 songs ~9gb
    name_ar, _ = filter_by_bps(config.min_bps_limit, config.max_bps_limit)

    # Reduce amount of songs
    name_ar = reduce_number_of_songs(name_ar, hard_limit=config.beat_song_limit)

    # if len(name_ar) > ram_limit:
    #     print(f"Info: Loading reduced song number into generator to not overload the RAM "
    #           f"(previous {len(name_ar)}")
    #     name_ar = name_ar[:ram_limit]

    # print(f"Importing {len(name_ar)} songs")
    song_input, pitch_input = find_beats(name_ar, train_data=True)

    # calculate discrete timings
    pitch_times = []
    # n_x: feature height of the song spectrogram-like input — TODO confirm
    n_x = song_input[0].shape[0]
    for idx in range(len(pitch_input)):
        pitch_times.append(get_pitch_times(pitch_input[idx]))
        # resize song input to fit pitch algorithm
        im = Image.fromarray(song_input[idx])
        im = im.resize((len(pitch_input[idx]), n_x))
        song_input[idx] = np.asarray(im)
        # # test song input
        # plt.imshow(song_input[idx])
        # plt.show()

    # get beat proposals
    [x_volume, x_onset] = get_beat_prop(song_input)

    # load real beats
    _, real_beats = load_beat_data(name_ar)

    # resample ground-truth beats onto the pitch-time grid
    beat_resampled = samplerate_beats(real_beats, pitch_times)

    # free ram
    del _, real_beats
    del im
    del pitch_input, pitch_times
    gc.collect()

    # delete a fraction of the offbeats to balance the dataset (n_beats << n_empty)
    beat_resampled, song_input, x_volume, x_onset = delete_offbeats(beat_resampled, song_input,
                                                                    x_volume, x_onset)

    print("Reshape input for AI model...")

    def lstm_reshape_half(song_list, y=False):
        """Cut each song into windows of config.tcn_len and stack them.

        2D inputs become (n, tcn_len, features); 1D inputs become
        (n, tcn_len, 1). With y=True the 1D target keeps only the last
        value of each window. Trailing samples that do not fill a whole
        window are dropped.
        """
        tcn_len = config.tcn_len
        song_ar = None
        for song in song_list:
            # drop the remainder so the length is divisible by tcn_len
            delete = song.shape[-1] % tcn_len
            if delete != 0:
                if len(song.shape) == 2:
                    song = song[:, :-delete]
                else:
                    song = song[:-delete]
            if len(song.shape) == 2:
                song = song.reshape(-1, tcn_len, song.shape[0])
            elif len(song.shape) == 1 and not y:
                song = song.reshape(-1, tcn_len, 1)
            elif len(song.shape) == 1 and y:
                # one label per window: the last timestep's value
                song = song[tcn_len - 1::tcn_len]
            song_ar = numpy_shorts.np_append(song_ar, song, axis=0)
        return song_ar

    x_volume = lstm_reshape_half(x_volume)
    x_onset = lstm_reshape_half(x_onset)
    x_song = lstm_reshape_half(song_input)
    y = lstm_reshape_half(beat_resampled, y=True)

    # x_volume = tcn_reshape(x_volume)
    # x_onset = tcn_reshape(x_onset)
    # x_song, y = beat_to_lstm(song_input, beat_resampled)

    x_song = numpy_shorts.minmax_3d(x_song)
    # class weights compensate the beat/offbeat imbalance during training
    cw = calc_class_weight(y)

    x_input = [x_song, x_volume, x_onset]
    test_len = config.tcn_test_samples
    # keep a small slice for the post-training sanity plot below
    x_part = [x_song[:test_len], x_volume[:test_len], x_onset[:test_len]]
    y_part = y[:test_len]

    del x_song, x_volume, x_onset
    gc.collect()

    # x_last_beats = last_beats_to_lstm(y)

    # setup ML model
    ################
    model = create_music_model('tcn', song_input[0].shape[0], config.tcn_len)
    adam = Adam(learning_rate=config.beat_learning_rate,
                weight_decay=config.beat_learning_rate * 2 / config.beat_n_epochs)
    model.compile(loss='binary_crossentropy', optimizer=adam,
                  metrics=['accuracy'])

    model.summary()

    model.fit(x=x_input, y=y, epochs=config.beat_n_epochs, shuffle=True,
              batch_size=config.beat_batch_size, verbose=1, class_weight=cw)

    # save model
    ############
    print(f"Saving model at: {paths.model_path + save_model_name}")
    model.save(paths.model_path + save_model_name)

    # plot test result
    ##################
    if True:
        try:
            y_pred = model.predict(x_part, verbose=0)
            # bound prediction to 0 or 1
            thresh = 0.5
            y_pred[y_pred > thresh] = 1
            y_pred[y_pred <= thresh] = 0

            fig = plt.figure()
            # plt.plot(y, 'b-', label='original')
            # positions of the true beats (0 entries collapse onto x=0)
            y_count = np.arange(0, len(y_part), 1)
            y_count = y_part * y_count

            plt.vlines(y_count, ymin=-0.1, ymax=1.1, colors='k', label='original', linewidth=2)
            plt.plot(y_pred, 'b-', label='prediction', linewidth=1)
            plt.legend()
            plt.show()
        except Exception as e:
            # plotting is best-effort; training results are already saved
            print(f"Error: {type(e).__name__}")
            print(f"Error message: {e}")
            print("Error in plotting beat generator evaluation. Continue.")

    print("Finished beat generator training")


if __name__ == '__main__':
    main()
def calculate_beat_accuracy(beat_pred, beat_real, tolerance=0.05):
    """Score predicted beat times against ground-truth beat times.

    A prediction counts as a true positive when it lies within ``tolerance``
    seconds of its nearest ground-truth beat.

    :param beat_pred: iterable of predicted beat times (seconds)
    :param beat_real: iterable of ground-truth beat times (seconds)
    :param tolerance: maximum allowed |pred - real| distance in seconds
    :return: tuple (precision, recall, f_measure)
    """
    # Guard: with no ground-truth beats there is nothing to match against
    # (and min() over an empty sequence would raise ValueError).
    if not len(beat_real):
        return 0, 0, 0

    true_positives = 0
    false_positives = 0

    for beat_time_pred in beat_pred:
        # Find the nearest beat in the ground truth
        closest_beat_real = min(beat_real, key=lambda x: abs(x - beat_time_pred))

        # Check if the difference is within the tolerance
        if abs(beat_time_pred - closest_beat_real) <= tolerance:
            true_positives += 1
        else:
            false_positives += 1

    # Clamp at zero: several predictions may match the same real beat,
    # which would otherwise drive this negative (and push recall above 1).
    false_negatives = max(0, len(beat_real) - true_positives)

    precision = true_positives / (true_positives + false_positives) if true_positives + false_positives > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if true_positives + false_negatives > 0 else 0

    f_measure = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0

    return precision, recall, f_measure
####################### 93 | bool_options = [True] 94 | # float_1_options = np.arange(60, 120, 10).tolist() 95 | # float_1_options = np.arange(0, 0.9, 0.1).tolist() 96 | float_1_options = [0.22] 97 | # float_2_options = np.arange(0, 0.51, 0.03).tolist() 98 | float_2_options = [0.1] 99 | # float_3_options = np.arange(0.15, 0.81, 0.05).tolist() 100 | float_3_options = [0.3] 101 | float_4_options = [2] 102 | float_5_options = [3] 103 | float_6_options = [None] 104 | # int_1_options = list(range(1, 50)) 105 | int_1_options = [None] 106 | 107 | # Set Tuning function 108 | ##################### 109 | # Perform grid search 110 | best_accuracy = 0.0 111 | best_parameters = None 112 | 113 | total_iterations = len(bool_options) 114 | total_iterations *= len(float_1_options) 115 | total_iterations *= len(float_2_options) 116 | total_iterations *= len(float_3_options) 117 | total_iterations *= len(float_4_options) 118 | total_iterations *= len(float_5_options) 119 | total_iterations *= len(float_6_options) 120 | total_iterations *= len(int_1_options) 121 | 122 | iteration = 0 123 | 124 | # Run analysis 125 | ############## 126 | for (bool_1, 127 | float_1, 128 | float_2, 129 | float_3, 130 | float_4, 131 | float_5, 132 | float_6, 133 | int_1) in product(bool_options, 134 | float_1_options, 135 | float_2_options, 136 | float_3_options, 137 | float_4_options, 138 | float_5_options, 139 | float_6_options, 140 | int_1_options): 141 | iteration += 1 142 | print(f"Iteration {iteration} of {total_iterations}") 143 | # overwrite parameters 144 | config.add_silence_flag = bool_1 145 | config.add_beat_intensity_orig = 50 146 | config.silence_threshold_orig = float_1 147 | config.thresh_beat = float_2 148 | config.map_filler_iters = 0 149 | config.thresh_pitch = float_3 150 | config.factor_pitch_certainty = float_4 151 | config.factor_pitch_meanmax = float_5 152 | 153 | # Call your beat algorithm with the current parameter values and calculate accuracy 154 | beats_algo = 
gen_beats_main(name_ar, debug_beats=True) 155 | # Calculate accuracy 156 | precision, recall, f_measure = calculate_beat_accuracy(beats_algo, real_beats, tolerance) 157 | # accuracy = (precision + recall) / 2 158 | accuracy = f_measure 159 | 160 | # Update the best parameters if the current combination performs better 161 | parameters = (bool_1, float_1, float_2, float_3, float_4, float_5, float_6, int_1) 162 | print(f"Current Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}") 163 | if accuracy >= best_accuracy: 164 | best_accuracy = accuracy 165 | best_parameters = parameters 166 | print("New best Parameters:", best_parameters) 167 | else: 168 | print("Bad Parameters:", parameters) 169 | 170 | print("Best Parameters:", best_parameters) 171 | print("Best Accuracy:", best_accuracy) 172 | 173 | if total_iterations == 1: 174 | # Compare results 175 | ################# 176 | precision, recall, f_measure = calculate_beat_accuracy(beats_algo, real_beats, tolerance=tolerance) 177 | print("Precision: ", precision) 178 | print("Recall: ", recall) 179 | print("F-measure: ", f_measure) 180 | 181 | plot_beat_vs_real(beats_algo, real_beats) 182 | -------------------------------------------------------------------------------- /preprocessing/music_processing.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFilter 2 | # from progressbar import ProgressBar 3 | from scipy import signal 4 | import aubio 5 | import numpy as np 6 | 7 | from bs_shift.bps_find_songs import bps_find_songs 8 | from tools.config import paths, config 9 | from tools.fail_list.black_list import append_fail, delete_fails 10 | from tools.utils import numpy_shorts 11 | from tools.utils.load_and_save import save_npy 12 | 13 | # from line_profiler_pycharm import profile 14 | 15 | 16 | def load_song(data_path: str, time_ar: list, return_raw=False) -> np.array: 17 | total_read = 0 18 | samples_list = [] 19 | src = 
def log_specgram(audio, sample_rate, window_size,
                 step_size=1, eps=1e-10):
    """Compute a log-compressed spectrogram of ``audio``.

    ``window_size`` and ``step_size`` are given in milliseconds and are
    converted to sample counts before calling scipy.

    :param audio: 1-D array of audio samples
    :param sample_rate: sample rate of ``audio`` in Hz
    :param window_size: analysis window length in milliseconds
    :param step_size: hop size in milliseconds
    :param eps: small constant to avoid log(0)
    :return: tuple (frequency bins, spectrogram of shape (time, freq), float32)
    """
    # Millisecond lengths -> sample counts for the STFT segments.
    samples_per_segment = int(round(sample_rate * window_size / 1e3))
    overlap_samples = int(round(sample_rate * step_size / 1e3))

    freqs, _, spec = signal.spectrogram(
        audio,
        fs=sample_rate,
        window='hann',
        nperseg=samples_per_segment,
        noverlap=overlap_samples,
        detrend=False,
    )

    # Transpose to (time, freq) and log-compress the magnitudes.
    log_spec = np.log(spec.T.astype(np.float32) + eps)
    return freqs, log_spec
amplify signal 81 | song_ar *= 100 82 | 83 | # convert into spectrogram 84 | window_size = 100 85 | sample_rate = int(config.samplerate_music / 1) 86 | 87 | spectrogram_ar = [] 88 | n_x = None 89 | for n in range(song_ar.shape[0]): 90 | _, spectrogram = log_specgram(song_ar[n], sample_rate, window_size) 91 | # shift spectrogram to 0+ 92 | spectrogram -= spectrogram.min() 93 | 94 | # plt.imshow(spectrogram.T, aspect='auto', origin='lower') 95 | # plt.show() 96 | 97 | # resize and filter spectrogram 98 | im = Image.fromarray(spectrogram) 99 | # im = im.filter(ImageFilter.MaxFilter(config.max_filter_size)) 100 | if n_x is None: 101 | n_x = spectrogram.shape[0] 102 | im = im.resize((config.specgram_res, n_x)) 103 | 104 | # transpose and save spectrogram 105 | im = np.asarray(im).T 106 | spectrogram_ar.append(im) 107 | 108 | # plt.imshow(im, aspect='auto', origin='lower') 109 | # plt.show() 110 | return np.asarray(spectrogram_ar) 111 | 112 | 113 | def run_music_preprocessing(names_ar: list, time_ar=None, save_file=True, song_combined=True, 114 | channels_last=True, predict_path=False): 115 | # load song notes 116 | ending = ".egg" 117 | song_ar = [] 118 | rm_index_ar = [] 119 | errors_appeared = 0 120 | # rm_index = None 121 | 122 | # bar = ProgressBar(max_value=len(names_ar)) 123 | 124 | # print(f"Importing {len(names_ar)} songs") 125 | for idx, n in enumerate(names_ar): 126 | # bar.update(idx+1) 127 | if time_ar is None: 128 | time = None 129 | else: 130 | time = time_ar[idx] 131 | if not n.endswith(ending): 132 | n += ending 133 | try: 134 | if predict_path: 135 | song, remove_idx = load_song(paths.songs_pred + n, time_ar=time) 136 | else: 137 | song, remove_idx = load_song(paths.copy_path_song + n, time_ar=time) 138 | except Exception as e: 139 | print(f"Problem with song: {n}") 140 | print(f"Exception details: {str(e)}") 141 | # print(paths.copy_path_song) 142 | # exit() 143 | append_fail(n[:-4]) 144 | errors_appeared += 1 145 | continue 146 | 147 | 
rm_index_ar.append(remove_idx) 148 | ml_input_song = process_song(song) 149 | 150 | # if song_combined: # does not work 151 | # song_ar.extend(ml_input_song) 152 | # else: 153 | song_ar.append(ml_input_song) 154 | if errors_appeared > 0: 155 | delete_fails() 156 | bps_find_songs() 157 | print("Deleted failed maps, please re-run!") 158 | exit() 159 | # scale song to 0-1 160 | # if len(np.asarray(song_ar).shape) > 1: 161 | # song_ar = np.asarray(song_ar) 162 | # # song_ar = song_ar.clip(min=0) 163 | # # song_ar /= song_ar.max() 164 | # song_ar = numpy_shorts.minmax_3d(song_ar) 165 | # if channels_last: 166 | # song_ar = song_ar.reshape((song_ar.shape[0], song_ar.shape[1], song_ar.shape[2], 1)) 167 | # else: 168 | # song_ar = song_ar.reshape((song_ar.shape[0], 1, song_ar.shape[1], song_ar.shape[2])) 169 | # else: 170 | for idx, song in enumerate(song_ar): 171 | song = numpy_shorts.minmax_3d(song) 172 | if channels_last: 173 | song_ar[idx] = song.reshape((song.shape[0], song.shape[1], song.shape[2], 1)) 174 | else: 175 | song_ar[idx] = song.reshape((song.shape[0], 1, song.shape[1], song.shape[2])) 176 | 177 | if song_combined: 178 | song_ar = np.concatenate(song_ar, axis=0) 179 | 180 | if save_file: 181 | save_npy(song_ar, paths.ml_input_song_file) 182 | else: 183 | return song_ar, rm_index_ar 184 | -------------------------------------------------------------------------------- /lighting_prediction/train_lighting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import matplotlib.pyplot as plt 3 | # import gc 4 | import random 5 | import pickle 6 | 7 | from datetime import datetime 8 | from keras.optimizers import Adam 9 | from sklearn.preprocessing import OneHotEncoder 10 | from tabulate import tabulate 11 | 12 | # Get the main script's directory 13 | import sys, os 14 | script_dir = os.path.dirname(os.path.realpath(__file__)) 15 | parent_dir = os.path.abspath(os.path.join(script_dir, "..")) 16 | 
def lstm_shift_events_half(song_in, time_in, ml_out, lstm_len):
    """Chunk the per-event inputs into fixed-length LSTM windows.

    The sample count is truncated to a multiple of ``lstm_len``; the class
    chunks are shifted by one window so that window *k* of the input predicts
    window *k+1* of the output (teacher forcing at window granularity).

    :param song_in: 2-D array of encoded song features per event
    :param time_in: 1-D array of per-event time values
    :param ml_out: 2-D one-hot class array, or None at inference time
    :param lstm_len: window length in events
    :return: ([song windows, time windows, class-input windows], class-output windows)
    """
    surplus = len(time_in) % lstm_len

    # Class stream: reshape once, then offset by one window for in/out pairs.
    if ml_out is None:
        chunked_out = None
        chunked_in = []
    else:
        trimmed = ml_out if surplus == 0 else ml_out[:-surplus]
        stacked = trimmed.reshape(-1, lstm_len, trimmed.shape[1])
        chunked_in = stacked[:-1]
        chunked_out = stacked[1:]

    if surplus != 0:
        song_in = song_in[:-surplus]
        time_in = time_in[:-surplus]

    # Shapes: (windows, lstm_len, features[, 1]); drop the last window to
    # stay aligned with the shifted class output.
    song_chunks = song_in.reshape(-1, lstm_len, song_in.shape[1], 1)[:-1]
    time_chunks = time_in.reshape(-1, lstm_len, 1)[:-1]

    return [song_chunks, time_chunks, chunked_in], chunked_out


def lstm_shift_events(song_in, time_in, ml_out):
    """Build sliding-window LSTM inputs with per-event (not per-window) shift.

    Each sample at index ``i`` sees the previous ``config.event_lstm_len``
    time/class values and predicts the class at ``i``.

    :param song_in: 2-D array of encoded song features per event
    :param time_in: 1-D array of per-event time values
    :param ml_out: 2-D one-hot class array, or None at inference time
    :return: ([song features, time windows, class-history windows], shifted class output)
    """
    lstm_len = config.event_lstm_len
    start = lstm_len + 1
    n_samples = len(time_in)

    l_ml_out = None if ml_out is None else ml_out[start:]

    history_windows = []
    time_windows = []
    for end in range(start, n_samples):
        if ml_out is not None:
            history_windows.append(ml_out[end - start:end - 1])
        time_windows.append(time_in[end - start:end - 1])

    l_time_in = np.asarray(time_windows).reshape((-1, lstm_len, 1))
    l_out_in = np.asarray(history_windows)

    return [song_in[start:], l_time_in, l_out_in], l_ml_out
def get_time_from_events(events, diff=False):
    """Extract per-map event timestamps (or their deltas) from raw event data.

    :param events: list of per-map event arrays; row 0 of each holds the times
    :param diff: when True, return consecutive time differences instead of
                 absolute times (first delta padded with 1 to keep the length)
    :return: tuple (list of time arrays, indices of empty maps in reverse
             order so callers can pop() them without reindexing)
    """
    times = []
    empty_indices = []

    for pos, event_ar in enumerate(events):
        if diff:
            # Prepend 1 so the delta array matches the timestamp count.
            deltas = np.concatenate(([1], np.diff(event_ar[0])), axis=0)
            times.append(deltas)
        elif len(event_ar) == 0:
            empty_indices.append(pos)
        else:
            times.append(event_ar[0])

    empty_indices.reverse()
    return times, empty_indices
get_time_from_events(events, diff=False) 143 | [name_ar.pop(rm) for rm in rm_idx] 144 | _, events, _ = load_raw_beat_data(name_ar) 145 | 146 | # load song data 147 | song_ar, rm_idx = run_music_preprocessing(name_ar, time_ar, save_file=False, 148 | song_combined=False) 149 | 150 | print("Reshape input for AI model...") 151 | # remove wrong time indices 152 | for idx in range(len(rm_idx)): 153 | if len(rm_idx[idx]) > 0: 154 | events[idx] = np.delete(events[idx], rm_idx[idx], axis=-1) 155 | 156 | time_ar, _ = get_time_from_events(events, diff=True) 157 | time_ar = np.concatenate(time_ar, axis=0) 158 | events = np.concatenate(events, axis=1) 159 | in_event = events[1:].T 160 | y_out = onehot_encode_events(in_event) 161 | 162 | # encode song data 163 | song_ar = np.concatenate(song_ar, axis=0) 164 | in_song = ai_encode_song(song_ar) 165 | 166 | # x_input, y_out = lstm_shift_events(in_song, time_ar, y_out) 167 | x_input, y_out = lstm_shift_events_half(in_song, time_ar, y_out, config.event_lstm_len) 168 | 169 | # only use song and time data, not class as input 170 | # x_input = x_input[:2] 171 | 172 | # [in_song_l, in_time_l, in_class_l] = x_input 173 | # input_song_enc, input_time_lstm, input_class_lstm 174 | x_input_shape = [x.shape for x in x_input] 175 | 176 | # setup ML model 177 | ################ 178 | model = create_tf_model('lstm_half', x_input_shape, y_out.shape) 179 | adam = Adam(learning_rate=config.event_learning_rate, 180 | weight_decay=config.event_learning_rate * 2 / config.event_n_epochs) 181 | model.compile(loss='binary_crossentropy', optimizer=adam, 182 | metrics=['accuracy']) 183 | 184 | model.summary() 185 | 186 | model.fit(x=x_input, y=y_out, epochs=config.event_n_epochs, shuffle=True, 187 | batch_size=config.event_batch_size, verbose=1) 188 | 189 | # save model 190 | ############ 191 | print(f"Saving model at: {paths.model_path + save_model_name}") 192 | model.save(paths.model_path + save_model_name) 193 | 194 | # Evaluate model 195 | 
################ 196 | test_samples = 10 197 | command_len = 10 198 | try: 199 | x_test = [x[:test_samples] for x in x_input] 200 | y_test = y_out[:test_samples] 201 | print("Validate model...") 202 | # validation = model.evaluate(x=x_test, y=y_test) 203 | pred_result = model.predict(x=x_test, verbose=0) 204 | 205 | pred_class = categorical_to_class(pred_result) 206 | real_class = categorical_to_class(y_test) 207 | 208 | pred_class = pred_class.flatten().tolist() 209 | real_class = real_class.flatten().tolist() 210 | 211 | print(tabulate([['Pred', pred_class], ['Real', real_class]], 212 | headers=['Type', 'Result (test data)'])) 213 | 214 | except Exception as e: 215 | print(f"Error: {type(e).__name__}") 216 | print(f"Error message: {e}") 217 | print("Error in displaying lighting evaluation. Continue.") 218 | 219 | print("Finished lighting generator training") 220 | 221 | 222 | if __name__ == '__main__': 223 | start_training() 224 | -------------------------------------------------------------------------------- /training/tensorflow_models.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Dense, Input, LSTM, Flatten, Dropout, \ 2 | MaxPooling2D, Conv2D, BatchNormalization, SpatialDropout2D, concatenate, \ 3 | Reshape, Conv2DTranspose, UpSampling2D 4 | from tcn import TCN # pip install keras-tcn 5 | from keras.models import Model 6 | import numpy as np 7 | 8 | from tools.config import config 9 | 10 | 11 | def create_keras_model(model_type, dim_in=[], dim_out=None): 12 | print("Setup keras model") 13 | if model_type == 'lstm1': 14 | # in_song (lin), in_time (rec), in_class (rec) 15 | input_a = Input(shape=(dim_in[0]), name='input_song_enc') 16 | input_b = Input(shape=(dim_in[1]), name='input_time_lstm') 17 | input_c = Input(shape=(dim_in[2]), name='input_class_lstm') 18 | 19 | lstm_b = LSTM(32, return_sequences=True)(input_b) 20 | lstm_c = LSTM(256, return_sequences=True)(input_c) 21 | 22 | lstm_in = 
concatenate([lstm_b, lstm_c]) 23 | lstm_out = LSTM(64, return_sequences=False)(lstm_in) 24 | 25 | x = concatenate([input_a, lstm_out]) 26 | x = Dense(256, activation='relu')(x) 27 | x = Dropout(0.05)(x) 28 | x = Dense(512, activation='sigmoid')(x) 29 | 30 | out = Dense(dim_out, activation='softmax', name='output')(x) 31 | 32 | model = Model(inputs=[input_a, input_b, input_c], outputs=out) 33 | return model 34 | 35 | elif model_type == 'lstm_half': 36 | # in_song (lin), in_time (rec), in_class (rec) 37 | input_a = Input(shape=(dim_in[0][1:]), name='input_song_enc') 38 | input_b = Input(shape=(dim_in[1][1:]), name='input_time_lstm') 39 | input_c = Input(shape=(dim_in[2][1:]), name='input_class_lstm') 40 | 41 | conv = Conv2D(32, kernel_size=3, activation='relu')(input_a) 42 | conv = Flatten('channels_last')(conv) 43 | lstm_b = LSTM(32, return_sequences=True)(input_b) 44 | lstm_c = LSTM(64, return_sequences=True)(input_c) 45 | 46 | lstm_in = concatenate([lstm_b, lstm_c]) 47 | lstm_out = LSTM(64, return_sequences=False)(lstm_in) 48 | 49 | x = concatenate([conv, lstm_out]) 50 | x = Dense(1024, activation='relu')(x) 51 | x = Dropout(0.05)(x) 52 | x = Dense(dim_out[1] * dim_out[2], activation='sigmoid')(x) 53 | 54 | out = Reshape(target_shape=(dim_out[1:]))(x) 55 | # out = Dense(dim_out[1:], activation='softmax', name='output')(x) 56 | 57 | model = Model(inputs=[input_a, input_b, input_c], outputs=out) 58 | return model 59 | 60 | # autoencoder 61 | elif model_type == 'enc1': 62 | input_img = Input(shape=(24, 20, 1)) 63 | # Conv2d(1, 32, 3, padding=1) with Relu 64 | x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img) 65 | # Dropout2d(0.2) 66 | # x = SpatialDropout2D(0.05)(x) 67 | # MaxPool2d(2, 2) 68 | x = MaxPooling2D((2, 2), padding='same')(x) 69 | 70 | # Conv2d(32, 16, 3, padding=1) with Relu 71 | x = Conv2D(16, (3, 3), activation='relu', padding='same')(x) 72 | # MaxPool2d(2, 2) 73 | x = MaxPooling2D((2, 2), padding='same')(x) 74 | 75 | # 
Flatten(start_dim=1) 76 | x = Flatten('channels_last')(x) 77 | # Dropout(0.1) 78 | x = Dropout(0.05)(x) 79 | # Relu(in=480, out=128) 80 | x = Dense(128, activation='relu')(x) 81 | # Relu(in=128, out=config.bottleneck_len) 82 | x = Dense(config.bottleneck_len, activation='relu')(x) 83 | 84 | model = Model(input_img, x) 85 | return model 86 | 87 | elif model_type == 'dec1': 88 | input_img = Input(shape=config.bottleneck_len) 89 | # Relu(in=config.bottleneck_len, out=128) 90 | x = Dense(128, activation='relu')(input_img) 91 | # Relu(in=128, out=480) 92 | x = Dense(640, activation='relu')(x) 93 | # Unflatten(1, (16, 6, 5)) 94 | x = Reshape(target_shape=(10, 8, 8))(x) 95 | # input shape (batch_size, rows, cols, channels) 96 | 97 | # x = Conv2D(32, kernel_size=2, activation='relu')(x) 98 | # x = UpSampling2D((2, 2))(x) 99 | # x = Conv2D(1, kernel_size=2, activation='sigmoid')(x) 100 | # x = UpSampling2D((2, 2))(x) 101 | 102 | # # ConvTranspose2d(16, 32, 2, stride=2) with Relu 103 | x = Conv2DTranspose(32, kernel_size=5, strides=2, activation='relu')(x) 104 | # ConvTranspose2d(32, 1, 2, stride=2) with Sigmoid 105 | x = Conv2DTranspose(1, kernel_size=2, strides=1, activation='sigmoid')(x) 106 | # output shape (batch_size, new_rows, new_cols, filters) 107 | 108 | model = Model(input_img, x) 109 | return model 110 | 111 | 112 | def create_music_model(model_type, dim_in, tcn_len): 113 | # https://github.com/giusenso/seld-tcn/blob/master/keras_model.py 114 | # https://github.com/philipperemy/keras-tcn 115 | print("Setup music model") 116 | # input song 264.X (batch_size, timesteps, input_dim) 117 | # output beats 1.X (beat IO x sample) 118 | 119 | if model_type == 'tcn': 120 | # input song (batch_size, timesteps, input_dim) 121 | input_a = Input(shape=(tcn_len, dim_in), name='input_song_img') 122 | # input beat (batch_size, time_steps, 1) 123 | input_b = Input(shape=(tcn_len, 1), name='input_beat_prop') 124 | input_c = Input(shape=(tcn_len, 1), name='input_onset_detection') 
def create_post_model(model_type, lstm_len: int, dim_out=2):
    """Build the Keras model used for note post-processing.

    :param model_type: architecture selector; only 'lstm1' is implemented
    :param lstm_len: number of timesteps in the input window
    :param dim_out: number of output values
    :return: an uncompiled keras ``Model`` (None for unknown ``model_type``)
    """
    print("Setup keras model")
    if model_type == 'lstm1':
        # Single recurrent branch over the 5 per-step features
        # (type, cut, time diff, last notes).
        feature_input = Input(shape=(lstm_len, 5), name='input_type_cut_timeDiff_lastNotes_lstm')

        recurrent = LSTM(256, return_sequences=True)(feature_input)
        recurrent = LSTM(128, return_sequences=False)(recurrent)

        # Dense head with light dropout before the regression output.
        dense = Dense(512, activation='relu')(recurrent)
        dense = Dropout(0.05)(dense)
        dense = Dense(256, activation='relu')(dense)

        out = Dense(dim_out, activation='relu', name='output')(dense)

        return Model(inputs=[feature_input], outputs=out)
def read_dat_file(file_path: str, filename="") -> list[str]:
    """Read a Beat Saber ``.dat`` file and return its raw text lines.

    :param file_path: directory or full path of the file
    :param filename: optional file name joined onto ``file_path``
    :return: list of lines including trailing newlines
    """
    if filename != "":
        file_path = os.path.join(file_path, filename)
    # .dat map files are JSON and therefore UTF-8; relying on the platform
    # default encoding breaks on Windows for non-ASCII song/author names.
    with open(file_path, encoding="utf-8") as f:
        dat_content = f.readlines()
    return dat_content


def read_json_content_file(file_path: str, filename="") -> dict:
    """Parse a JSON ``.dat`` file (e.g. info.dat) into a Python object.

    :param file_path: directory or full path of the file
    :param filename: optional file name joined onto ``file_path``
    :return: the parsed JSON document (info.dat is a JSON object)
    """
    if filename != "":
        file_path = os.path.join(file_path, filename)
    with open(file_path, encoding="utf-8") as f:
        dat_content = json.load(f)
    return dat_content


def delete_old_files():
    """Remove previously shifted maps and songs so the next copy starts clean."""
    print("Delete old files")
    for de in os.listdir(copy_path_song):
        os.remove(copy_path_song + de)
    for de in os.listdir(copy_path_map):
        os.remove(copy_path_map + de)
99 | # print(files) 100 | 101 | # get song name 102 | info_file = False 103 | try: 104 | for n_file in files: 105 | if n_file.lower() == "info.dat": 106 | num_cur += 1 107 | # bar.update(num_cur) 108 | # import dat file 109 | dat_content = read_dat_file(os.path.join(root, n_file)) 110 | 111 | if config.exclude_requirements: 112 | search_string = '"_requirements":' 113 | for s in dat_content: 114 | if search_string in s: 115 | if '[]' not in s: 116 | print("Excluding maps with custom mod requirements.") 117 | append_fail(os.path.basename(root)) 118 | 119 | search_string = '"_songName"' 120 | # get name line 121 | for s in dat_content: 122 | if search_string in s: 123 | song_name, _ = return_find_str(0, s, search_string, True) 124 | info_file = True 125 | break 126 | 127 | while song_name.lower() in song_name_list: 128 | # song found in different versions 129 | song_name += "_2" 130 | song_name_list.append(song_name.lower()) 131 | # Finished name 132 | except: 133 | print("Could not understand .info formatting: " + os.path.basename(root)) 134 | # Append to blacklist 135 | append_fail(os.path.basename(root)) 136 | break 137 | # print(song_name) 138 | 139 | if not info_file: 140 | print("Skipping... 
No info.dat file found in " + root) 141 | break 142 | # copy files 143 | count += 1 144 | test_copy = 0 145 | for copy_file in files: 146 | if copy_file.endswith(".egg"): 147 | shutil.copyfile(root + "/" + copy_file, copy_path_song + song_name + copy_file[-4:]) 148 | test_copy += 2 149 | elif copy_file.endswith(exp_name) and allow_diff2: 150 | shutil.copyfile(root + "/" + copy_file, copy_path_map + song_name + copy_file[-4:]) 151 | test_copy += 1 152 | elif copy_file.lower().endswith("info.dat"): 153 | copy_file = "info.dat" 154 | shutil.copyfile(root + "/" + copy_file, copy_path_map + song_name + "_" + copy_file) 155 | test_copy += 2 156 | 157 | # Overwrite Expert with ExpertPlus if available 158 | for copy_file in files: 159 | if copy_file.endswith(exp_plus_name): 160 | shutil.copyfile(root + "/" + copy_file, copy_path_map + song_name + copy_file[-4:]) 161 | test_copy += 1 162 | break 163 | 164 | # print(song_name) 165 | 166 | # test complete copy process 167 | if test_copy < 5 or test_copy > 6: 168 | if test_copy > 6: 169 | print("Too many files copied - probably two music files found.") 170 | else: 171 | print("Not enough files copied") 172 | print("Error: Copy process failed at: " + os.path.basename(root)) 173 | append_fail(song_name) 174 | if len(song_name_list) != count: 175 | print("Error: count != song_name_list") 176 | exit() 177 | 178 | if len(os.listdir(copy_path_map)) / 2 != count: 179 | print(f"Count {count} maps vs directory: {len(os.listdir(copy_path_map)) / 2}") 180 | exit() 181 | if len(os.listdir(copy_path_song)) != count: 182 | print(f"Count {count} songs vs directory: {len(os.listdir(copy_path_song))}") 183 | exit() 184 | 185 | print("\nFinished Shift from BS directory to project") 186 | 187 | # Delete uncompleted samples 188 | delete_fails() 189 | 190 | 191 | if __name__ == '__main__': 192 | delete_old_files() 193 | 194 | shift_bs_songs() 195 | 196 | # Start casting to dictionary (notes, events, etc) 197 | map_to_dict_all() 198 | 199 | # 
Calculate notes per sec for each song
    bps_find_songs()

    print("Finished shifting")
--------------------------------------------------------------------------------
/map_creation/gen_obstacles.py:
--------------------------------------------------------------------------------
import numpy as np
from random import randint

from tools.config import config


def add_obstacle(obstacles: list, position: int, first_time, last_time) -> list:
    """Append one randomly shaped obstacle candidate to lane ``position``.

    ``obstacles`` is a list of four per-lane lists; the new candidate is
    appended to ``obstacles[position]`` and the (mutated) container is
    returned.  ``first_time``/``last_time`` delimit the free window; the
    configured time gaps are trimmed from both ends before the duration
    is computed.
    """
    # check for multi occurrences of obstacles
    # obs_counter = 0
    # obs_break_counter = 4
    # for obs_last in obstacles[::-1]:
    #     obs_last_time = obs_last[0] + obs_last[3]
    #     if first_time <= obs_last_time:
    #         obs_counter += 1
    #         if obs_counter >= config.obstacle_max_count:
    #             break
    #     else:
    #         obs_break_counter -= 1
    #         if obs_break_counter <= 0:
    #             break
    # if obs_counter < config.obstacle_max_count:

    # Switch sport and obstacle mode.
    # NOTE: this rewrites the global config lists on every call depending
    # on config.sporty_obstacles.
    if config.sporty_obstacles:
        config.obstacle_allowed_types = config.sport_obstacle_allowed_types
        config.obstacle_positions = config.sport_obstacle_positions
    else:
        config.obstacle_allowed_types = config.norm_obstacle_allowed_types
        config.obstacle_positions = config.norm_obstacle_positions

    # Add new obstacle with a random allowed type and height.
    rand_type = randint(0, len(config.obstacle_allowed_types) - 1)
    o_type = config.obstacle_allowed_types[rand_type]
    o_height = randint(1, config.max_obstacle_height)
    first_time += config.obstacle_time_gap[0]
    last_time -= config.obstacle_time_gap[1]

    duration = last_time - first_time
    # Target JSON shape for reference:
    # _obstacles":[{"_time":64.39733123779297,"_lineIndex":0,
    # "_type":0,"_duration":6.5,"_width":1}
    cur_obstacle = [first_time, position, o_type, duration, config.obstacle_width, o_height]
    obstacles[position].append(cur_obstacle)
    return obstacles


def check_obstacle_times(first_time, last_time):
    # True when the window is long enough for an obstacle: minimum
    # duration plus both configured time gaps.
    time_diff = 
last_time - first_time
    if time_diff <= config.obstacle_min_duration + sum(config.obstacle_time_gap):
        return False
    else:
        return True


def combine_obstacles(obstacles_all, times_empty):
    """Merge per-lane obstacle candidates into the final obstacle list.

    ``obstacles_all`` holds candidates for the four note lanes;
    ``times_empty`` lists timestamps without notes.  Overlapping windows
    of the two left lanes and of the two right lanes become single
    obstacles; in sporty mode, windows common to all four lanes become
    double walls or a crouch wall.
    """
    def found_obstacle(obst_temp, first_time, last_time, position, width=config.obstacle_width):
        # Emit obstacle(s) covering [first_time, last_time], splitting at
        # empty timestamps so a wall does not span a note-free break.
        if width > 2:
            o_type = 2  # only allow ceiling type for crouch walls
            o_height = 3
        else:
            rand_type = randint(0, len(config.obstacle_allowed_types) - 1)
            o_type = config.obstacle_allowed_types[rand_type]
            o_height = randint(1, config.max_obstacle_height)

        # NOTE(review): times_empty starts at 0 and is appended in
        # ascending order, so first_time >= t_empty holds on the very
        # first iteration and the loop breaks immediately — verify
        # whether `continue` (skip past empties) was intended here.
        for t_empty in times_empty:
            if first_time >= t_empty:
                break
            else:
                if t_empty < last_time:
                    dur_temp = round(t_empty - first_time - 0.1, 1)
                    cur_obstacle = [first_time, position, o_type, dur_temp, width, o_height]
                    obst_temp.append(cur_obstacle)
                    first_time = t_empty
        dur_temp = round(last_time - first_time, 1)
        # print("Error. Encountered negative duration for obstacles! Exclude.")
        if dur_temp > 0:
            cur_obstacle = [first_time, position, o_type, dur_temp, width, o_height]
            obst_temp.append(cur_obstacle)
        return obst_temp

    def check_saved_times(first, last, lt_save_first, lt_save_last, diff_min):
        # In sporty mode, enforce a minimum distance diff_min between this
        # window and previously emitted windows (saved first/last times).
        if len(lt_save_last) == 0 or not config.sporty_obstacles:
            return first, last
        first_last = [first, last]
        for i_save, lt_save in enumerate([lt_save_last, lt_save_first]):
            diff_array = np.asarray(lt_save) - first_last[i_save]
            # only check negative values as positive are not possible for sporty mode
            if i_save == 1:
                diff_array *= -1
            diff_array = diff_array[diff_array < 0]
            if len(diff_array) > 0:
                diff_real = round(abs(max(diff_array)), 1)
                if diff_real < diff_min:
                    if i_save == 0:
                        first += round(diff_min - diff_real, 1)
                    if i_save == 1:
                        last -= round(diff_min - diff_real, 1)
        return first, last

    step_size = 0.1
    obstacles = []
    time_list = [[], [], [], []]
    # Rasterize each candidate obstacle into 0.1-step ticks per lane.
    for idx in range(len(obstacles_all)):
        for obst in obstacles_all[idx]:
            start_time = round(obst[0], 1)
            duration = round(obst[3], 1)
            new_times = np.round(np.arange(start_time, start_time + duration + step_size, step_size), 1)
            time_list[idx].extend(list(new_times))
    # Ticks active in both left lanes (0&1) and both right lanes (2&3).
    common_val1 = list(set(time_list[0]).intersection(time_list[1]))
    common_val2 = list(set(time_list[2]).intersection(time_list[3]))
    common_val1.sort()
    common_val2.sort()

    if config.sporty_obstacles:
        # Ticks active on all four lanes -> handled in the both-sides pass.
        common_val3 = list(set(common_val1).intersection(common_val2))
        common_val3.sort()
    else:
        common_val3 = []

    last_time_saves = []
    first_time_saves = []
    diff_gap_min = round(max(config.obstacle_time_gap), 1)
    # One-sided pass: grow a window while ticks are contiguous, emit it
    # when a gap larger than two steps appears.
    for idx, common_val in enumerate([common_val1, common_val2]):
        t_first = -1
        t_last = 0
        for t in common_val:
            if t in common_val3:
                pass  # do later for both sides active
            elif t_first == -1:
                t_first = t
                t_last = t
            elif t > t_last + 2 * step_size:
                # Gap found: close the current window and emit it.
                rnd_pos = randint(0, len(config.obstacle_positions[idx]) - 1)
                t_first, t_last = check_saved_times(t_first, t_last, first_time_saves,
                                                    last_time_saves, diff_gap_min)
                obstacles = found_obstacle(obstacles, t_first, t_last,
                                           config.obstacle_positions[idx][rnd_pos])
                first_time_saves.append(t_first)
                last_time_saves.append(t_last)
                t_first = -1
            else:
                t_last = t
        # Flush the trailing window if it is long enough.
        if t_first > 0:
            if t_last - t_first >= config.obstacle_min_duration:
                rnd_pos = randint(0, len(config.obstacle_positions[idx]) - 1)
                t_first, t_last = check_saved_times(t_first, t_last, first_time_saves,
                                                    last_time_saves, diff_gap_min)
                obstacles = found_obstacle(obstacles, t_first, t_last,
                                           config.obstacle_positions[idx][rnd_pos])
                first_time_saves.append(t_first)
                last_time_saves.append(t_last)

    # Obstacles for both sides on
    t_first = -1
    t_last = 0
    diff_gap_min = round(max(config.obstacle_time_gap), 1)
    for t in common_val3:
        if t_first == -1:
            t_first = t
            t_last = t
        elif t > t_last + 2 * step_size:
            obstacle_pos = []
            for idx in range(2):
                rnd_pos = randint(0, len(config.obstacle_positions[idx]) - 1)
                obstacle_pos.append(config.obstacle_positions[idx][rnd_pos])
            if t_first < 10:
                # do not allow crouching at the beginning of song
                if obstacle_pos[0] == 1 and obstacle_pos[1] == 2:
                    obstacle_pos[1] = 3
            if obstacle_pos[0] == 1 and obstacle_pos[1] == 2:
                # crouch obstacle
                # t_first, t_last = check_saved_times(t_first, t_last, first_time_saves,
                #                                     last_time_saves, diff_gap_min)
                t_first += 0.1  # allow more time for crouching
                t_last -= 0.1
                obstacles = found_obstacle(obstacles, t_first, t_last, 0,
                                           config.obstacle_crouch_width)
            else:
                t_first, t_last = 
check_saved_times(t_first, t_last, first_time_saves, 180 | last_time_saves, diff_gap_min) 181 | obstacles = found_obstacle(obstacles, t_first, t_last, obstacle_pos[0]) 182 | obstacles = found_obstacle(obstacles, t_first, t_last, obstacle_pos[1]) 183 | t_first = -1 184 | else: 185 | t_last = t 186 | 187 | return obstacles 188 | 189 | 190 | def calculate_obstacles(notes, timings): # TODO: change config to obstacle_mode with 0 None and up to3 for full 191 | obstacles_all = [[], [], [], []] 192 | rows_last = [2, 2, 2, 2] 193 | times_empty = [0] 194 | for idx in range(len(notes)): 195 | if len(notes[idx]) == 0: 196 | times_empty.append(timings[idx]) 197 | else: 198 | cur_rows = notes[idx][::4] 199 | for n_row in cur_rows: 200 | # if n_row == 0 or n_row == 3: 201 | if check_obstacle_times(rows_last[n_row], timings[idx]): 202 | obstacles_all = add_obstacle(obstacles_all, n_row, rows_last[n_row], 203 | timings[idx]) 204 | rows_last[n_row] = timings[idx] 205 | 206 | obstacles = combine_obstacles(obstacles_all, times_empty) 207 | 208 | # sort by timings 209 | obstacles = np.asarray(obstacles) 210 | obstacles = obstacles[np.argsort(obstacles[:, 0])] 211 | 212 | return obstacles 213 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ################################################### 2 | # This file is needed to find the working directory 3 | ################################################### 4 | import os 5 | import shutil 6 | import time 7 | import tensorflow as tf 8 | # import sys 9 | 10 | from tools.config import paths, config 11 | import map_creation.gen_beats as beat_generator 12 | from bs_shift.export_map import check_music_files, shutil_copy_maps 13 | 14 | from tools.config.mapper_selection import update_model_file_paths 15 | from tools.utils.huggingface import model_download 16 | 17 | 18 | def main(use_model=None, diff=None, 
export_results_to_bs=True,
         gimme_more=None, quick_start=None, beat_intensity=None,
         random_factor=None, js_offset=None,
         allow_no_dir_flag=None, silence_factor=None,
         add_obstacles=None, sporty_obstacles=None,
         add_sliders=None, slider_start_time=None,
         slider_end_time=None, slider_probability=None,
         slider_movement_min=None, legacy_mode=None, single_mode=None,
         logger_callback=None):
    """Generate Beat Saber maps for every song in the prediction folder.

    Every keyword that is not None overrides the corresponding config
    value.  The AI model is downloaded when missing, each song is mapped,
    zipped for the online viewer and (optionally) copied into the
    BeatSaber directory.  ``logger_callback`` receives progress strings.
    """
    if use_model is not None:
        config.use_mapper_selection = use_model
    # change difficulty
    if diff is not None:
        config.max_speed = diff
        config.max_speed_orig = diff
    if gimme_more is not None:
        config.gimme_more_notes_prob = gimme_more
    if quick_start is not None:
        config.quick_start = quick_start
    if beat_intensity is not None:
        config.add_beat_intensity = beat_intensity
        config.add_beat_intensity_orig = beat_intensity
    if random_factor is not None:
        config.random_note_map_factor = random_factor
    if js_offset is not None:
        # offset is additive, applied on top of the configured base value
        config.jump_speed_offset += js_offset
        config.jump_speed_offset_orig += js_offset
    if allow_no_dir_flag is not None:
        config.allow_dot_notes = allow_no_dir_flag
    if silence_factor is not None:
        # multiplicative scaling of the configured silence threshold
        config.silence_threshold *= silence_factor
        config.silence_threshold_orig *= silence_factor
    if add_obstacles is not None:
        config.add_obstacle_flag = add_obstacles
    if sporty_obstacles is not None:
        config.sporty_obstacles = sporty_obstacles
    if add_sliders is not None:
        config.add_slider_flag = add_sliders
    if slider_start_time is not None:
        config.slider_time_gap[0] = slider_start_time
    if slider_end_time is not None:
        config.slider_time_gap[1] = slider_end_time
    if slider_probability is not None:
        config.slider_probability = slider_probability
    if slider_movement_min is not None:
        config.slider_movement_minimum = slider_movement_min
    if single_mode is not None:
        # single notes and emphasized beats are mutually exclusive
        if single_mode:
            config.emphasize_beats_flag = False
            config.single_notes_only_flag = True
        else:
            config.emphasize_beats_flag = True
            config.single_notes_only_flag = False
    if legacy_mode is not None:
        # legacy maps use the v2 Beat Saber map format, otherwise v3
        if legacy_mode:
            config.bs_mapping_version = 'v2'
        else:
            config.bs_mapping_version = 'v3'
    update_model_file_paths(check_model_exists=False)

    # Download AI Model from huggingface
    log_message = "Loading AI Model [if not yet done, ~1GB will be downloaded (see second tab)]"
    if logger_callback:
        logger_callback(log_message)
    model_download()
    log_message = "Model Found"
    if logger_callback:
        logger_callback(log_message)

    # # limit gpu ram usage
    # conf = tf.compat.v1.ConfigProto()
    # conf.gpu_options.allow_growth = True
    # sess = tf.compat.v1.Session(config=conf)
    # tf.compat.v1.keras.backend.set_session(sess)

    # MAP GENERATOR
    ###############
    song_list = os.listdir(paths.songs_pred)
    song_list = check_music_files(song_list, paths.songs_pred)
    print(f"Found {len(song_list)} songs. Iterating...")
    if len(song_list) == 0:
        print("No songs found!")
        log_message = "No songs found! Please go to first tab."
        if logger_callback:
            logger_callback(log_message)

    for i, song_name in enumerate(song_list):
        start_time = time.time()
        # strip the file extension (e.g. ".egg") from the song name
        song_name = song_name[:-4]
        print(f"Analyzing song: {song_name} ({i + 1} of {len(song_list)})")
        fail_flag = beat_generator.main([song_name])
        if fail_flag:
            print("Continue with next song")
            continue
        end_time = time.time()
        print(f"Time needed: {end_time - start_time:.1f}s")

        # create zip archive for online viewer
        shutil.make_archive(f'{paths.new_map_path}{config.max_speed_orig:.1f}_{song_name}',
                            'zip', f'{paths.new_map_path}1234_{config.max_speed_orig:.1f}_{song_name}')
        # export map to beat saber
        if export_results_to_bs:
            if shutil_copy_maps(f"{config.max_speed_orig:.1f}_{song_name}"):
                log_message = "Copied map(s) to BeatSaber directory."
                if logger_callback:
                    logger_callback(log_message)

    print("Finished map generator")
    log_message = "Finished map generator"
    if logger_callback:
        logger_callback(log_message)


# ############################################################

# ############################################################
# if fails, rerun train_bs_automapper with correct min/max_bps
# until training is started (cancel after data import)
##############################################################

# TRAINING
##########
# run bs_shift / shift.py
# run training / train_autoenc_music.py
# run training / train_bs_automapper.py
# run beat_prediction / ai_beat_gen.py
# run lighting_prediction / train_lighting.py

if __name__ == "__main__":
    # All overrides are read from environment variables set by the caller.
    use_model = os.environ.get('inferno_model')
    print(f"Running model: {use_model}")

    diff = os.environ.get('max_speed')
    if diff is not None:
        diff = float(diff)
        print(f"Set BPS difficulty to {diff}")
        diff = 
diff * 4  # calculate bps to max_speed
    else:
        print("Use default difficulty values")

    gm = os.environ.get('gimme_more')
    if gm is not None:
        gm = float(gm)
        print(f"Set density to {gm}")

    qs = os.environ.get('quick_start')
    if qs is not None:
        qs = float(qs)
        print(f"Set quick_start to {qs}")

    bi = os.environ.get('beat_intensity')
    if bi is not None:
        bi = float(bi)
        print(f"Set beat intensity to {bi}")

    rf = os.environ.get('random_factor')
    if rf is not None:
        rf = float(rf)
        print(f"Set random factor to {rf}")

    jso = os.environ.get('jump_speed_offset')
    if jso is not None:
        jso = float(jso)
        print(f"Set jump speed offset to {jso}")

    # boolean flags arrive as the strings 'True'/'False'
    ndf = os.environ.get('allow_no_direction_flag')
    if ndf is not None:
        if ndf == 'True':
            ndf = True
        else:
            ndf = False
        print(f"Set allow_no_direction_flag to {ndf}")

    sf = os.environ.get('silence_factor')
    if sf is not None:
        sf = float(sf)
        print(f"Set silence factor to {sf}")

    # obstacles
    aof = os.environ.get('add_obstacle_flag')
    if aof is not None:
        if aof == 'True':
            aof = True
        else:
            aof = False
        print(f"Set add_obstacle_flag to {aof}")

    sof = os.environ.get('sporty_obstacle_flag')
    if sof is not None:
        if sof == 'True':
            sof = True
        else:
            sof = False
        print(f"Set sporty_obstacle_flag to {sof}")

    # sliders
    asf = os.environ.get('add_slider_flag')
    if asf is not None:
        if asf == 'True':
            asf = True
        else:
            asf = False
        print(f"Set add_slider_flag to {asf}")

    sst = os.environ.get('slider_start_time')
    if sst is not None:
        sst = float(sst)
        print(f"Set slider_start_time to {sst}")

    se = os.environ.get('slider_end_time')
    if se is not None:
        se = float(se)
        print(f"Set slider_end_time to {se}")

    sp = os.environ.get('slider_probability')
    if sp is not None:
        sp = float(sp)
        print(f"Set slider_probability to {sp}")

    smm = os.environ.get('slider_movement_min')
    if smm is not None:
        smm = float(smm)
        print(f"Set slider_movement_min to {smm}")

    # NOTE: unlike the flags above this checks isinstance(str) and prints
    # nothing — kept as-is, but worth aligning with the other flags.
    lm = os.environ.get('legacy_mode')
    if isinstance(lm, str):
        if lm == 'True':
            lm = True
        else:
            lm = False

    # export to the BeatSaber directory is disabled for this entry point
    export_results_to_bs = False

    main(use_model, diff, export_results_to_bs, gm, qs, bi, rf, jso,
         ndf, sf, aof, sof, asf, sst, se, sp, smm, lm)
--------------------------------------------------------------------------------