├── code ├── loader │ ├── __init__.py │ ├── vocabulary.py │ ├── selector.py │ ├── ddcdataset.py │ ├── audiocharttwobeatsongdataset.py │ ├── audiotimingtwobeatsongdataset.py │ ├── audiotimingfinebeatdataset.py │ └── audiorandomshift.py ├── loss │ ├── __init__.py │ └── selector.py ├── model │ ├── __init__.py │ ├── weights │ │ ├── spectrogram_extractor.bin │ │ ├── spectrogram_normalizer.bin │ │ └── placement_cnn_ckpt_56000.bin │ ├── paths.py │ ├── constants.py │ ├── selector.py │ ├── gorstconditionininput.py │ └── gorstconditionininputtiming.py ├── optimizer │ ├── __init__.py │ └── selector.py ├── _scripts │ └── runtrain.sh ├── __init__.py ├── utils │ └── creator_factory.py └── counter │ ├── osu4kclstmcounter.py │ ├── vocabulary.py │ └── osu4kcounter.py ├── osu-to-ddc ├── __init__.py ├── osutoddc │ ├── __init__.py │ └── converter │ │ ├── util.py │ │ ├── __init__.py │ │ ├── parser_exceptions.py │ │ ├── converter.py │ │ └── osuparser.py ├── README.md └── .gitignore ├── 0_ddc ├── dataset │ ├── __init__.py │ ├── smdataset │ │ ├── __init__.py │ │ ├── abstime.py │ │ └── parse.py │ ├── util.py │ ├── preview_sm.py │ ├── dataset_json.py │ ├── preview_wav.py │ ├── analyze_json.py │ └── extract_json.py ├── scripts │ ├── smd_0_push.sh │ ├── sml_0_push.sh │ ├── var.sh │ ├── smd_4_analyze.sh │ ├── smd_1_extract.sh │ ├── all.sh │ ├── sml_onset_0_extract.sh │ ├── sml_sym_1_chart.sh │ ├── sml_onset_1_chart.sh │ ├── sml_sym_1_chart_audio.sh │ ├── sml_sym_2_mark.sh │ ├── smd_all.sh │ ├── smd_2_filter.sh │ ├── smd_3_dataset.sh │ ├── sml_onset_3_eval.sh │ ├── sml_onset_4_export.sh │ ├── sml_sym_2_train.sh │ └── sml_onset_2_train.sh ├── ddc_to_gorst │ ├── h5pyize_dataset.py │ ├── cache_similar_beat_index.py │ ├── generate_dataset_peripherals.py │ ├── make_similarity_matrix.py │ ├── run.sh │ ├── main.py │ ├── ddcjson_to_osujson.py │ └── make_split.py └── README.md ├── conf ├── optimizer │ ├── sgd_ddc.yaml │ ├── adam.yaml │ ├── adam_clstm.yaml │ └── adam_mel_finetune.yaml ├── loss │ ├── 
bceloss.yaml │ └── crossentropy.yaml ├── model │ ├── ddc_cnn.yaml │ ├── ddc_clstm.yaml │ ├── finediffdecoderprelnbipast.yaml │ ├── mel.yaml │ └── finediffdecoderprelnbipasttiming.yaml ├── experiment │ ├── ddc.yaml │ └── test.yaml ├── ts_dataset │ ├── beatfine.yaml │ ├── ddc.yaml │ ├── beatfine_itg.yaml │ ├── beatfine_itg_timingonly.yaml │ └── beatfine_fraxtil_timingonly.yaml ├── cv_dataset │ ├── ddc.yaml │ ├── ddc_expert_plus.yaml │ ├── ddc_easy.yaml │ ├── ddc_hard.yaml │ ├── ddc_expert.yaml │ ├── ddc_insane.yaml │ ├── ddc_normal.yaml │ ├── beatfine_8_100.yaml │ ├── melrandomshift.yaml │ ├── beatfine_timingonly.yaml │ ├── beatfine_itg.yaml │ ├── beatfine_itg_timingonly.yaml │ ├── beatfine_fraxtil.yaml │ └── beatfine_fraxtil_timingonly.yaml └── tr_dataset │ ├── ddc.yaml │ ├── beatfine_8_100.yaml │ ├── melrandomshift.yaml │ ├── beatfine_timingonly.yaml │ ├── beatfine_itg.yaml │ ├── beatfine_fraxtil.yaml │ ├── beatfine_itg_timingonly.yaml │ └── beatfine_fraxtil_timingonly.yaml ├── generated.zip ├── 1_preindex_similarity_matrix ├── strahv.png ├── cleanup.py ├── cache_similarity_matrix.py ├── cache_similar_beat_index.py └── make_similarity_matrix.py ├── scripts ├── ddc_cnn.sh ├── ddc_clstm.sh ├── ddc_eval_cnn.sh ├── melitg.sh ├── ddc_eval_clstm.sh ├── mel.sh ├── melrandomshift.sh ├── mel_timing.sh ├── melitg_timing.sh ├── melfraxtil_timing.sh ├── thres_tune_cnn.sh ├── thres_tune_clstm.sh ├── melitg_timing_finetune.sh └── melfraxtil_timing_finetune.sh ├── text_replacer.py ├── LICENSE ├── 2_generate_dataset ├── generate_dataset_peripherals.py └── h5pyize_dataset.py ├── generate_ref.py ├── .gitignore ├── metrics_cond_centered_AR.py ├── tune_thresholds.py ├── train_ddc.py ├── train_cond_centered.py ├── metrics_timing_AR.py └── ddc_eval.py /code/loader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/loss/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /osu-to-ddc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /0_ddc/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /osu-to-ddc/osutoddc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /0_ddc/dataset/smdataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /osu-to-ddc/osutoddc/converter/util.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /osu-to-ddc/osutoddc/converter/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/_scripts/runtrain.sh: -------------------------------------------------------------------------------- 1 | python ../train.py \ 2 | +loader=test -------------------------------------------------------------------------------- 
/0_ddc/scripts/smd_0_push.sh: -------------------------------------------------------------------------------- 1 | source var.sh 2 | 3 | pushd ../dataset 4 | -------------------------------------------------------------------------------- /0_ddc/scripts/sml_0_push.sh: -------------------------------------------------------------------------------- 1 | source var.sh 2 | 3 | pushd ../learn 4 | -------------------------------------------------------------------------------- /0_ddc/scripts/var.sh: -------------------------------------------------------------------------------- 1 | SM_DIR=~/0_ddc/ 2 | SMDATA_DIR=STEPMANIAFOLDER 3 | -------------------------------------------------------------------------------- /conf/optimizer/sgd_ddc.yaml: -------------------------------------------------------------------------------- 1 | name: sgd 2 | parameters: 3 | lr: 1e-1 4 | -------------------------------------------------------------------------------- /conf/loss/bceloss.yaml: -------------------------------------------------------------------------------- 1 | name: BCELoss 2 | parameters: 3 | reduction: 'mean' -------------------------------------------------------------------------------- /0_ddc/ddc_to_gorst/h5pyize_dataset.py: -------------------------------------------------------------------------------- 1 | ../../2_generate_dataset/h5pyize_dataset.py -------------------------------------------------------------------------------- /generated.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stet-stet/goct_ismir2023/HEAD/generated.zip -------------------------------------------------------------------------------- /conf/loss/crossentropy.yaml: -------------------------------------------------------------------------------- 1 | name: CrossEntropyLoss 2 | parameters: 3 | label_smoothing: 0.02 -------------------------------------------------------------------------------- /conf/model/ddc_cnn.yaml: 
-------------------------------------------------------------------------------- 1 | name: DDCCNN 2 | parameters: 3 | load_pretrained_weights: False 4 | -------------------------------------------------------------------------------- /conf/model/ddc_clstm.yaml: -------------------------------------------------------------------------------- 1 | name: DDCCLSTM 2 | parameters: 3 | load_pretrained_weights: False 4 | -------------------------------------------------------------------------------- /conf/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | name: adam 2 | parameters: 3 | lr: 2e-4 4 | beta1: 0.9 5 | beta2: 0.999 6 | -------------------------------------------------------------------------------- /0_ddc/ddc_to_gorst/cache_similar_beat_index.py: -------------------------------------------------------------------------------- 1 | ../../1_preindex_similarity_matrix/cache_similar_beat_index.py -------------------------------------------------------------------------------- /0_ddc/ddc_to_gorst/generate_dataset_peripherals.py: -------------------------------------------------------------------------------- 1 | ../../2_generate_dataset/generate_dataset_peripherals.py -------------------------------------------------------------------------------- /0_ddc/ddc_to_gorst/make_similarity_matrix.py: -------------------------------------------------------------------------------- 1 | ../../1_preindex_similarity_matrix/make_similarity_matrix.py -------------------------------------------------------------------------------- /conf/optimizer/adam_clstm.yaml: -------------------------------------------------------------------------------- 1 | name: adam 2 | parameters: 3 | lr: 2e-4 4 | beta1: 0.9 5 | beta2: 0.999 6 | -------------------------------------------------------------------------------- /conf/optimizer/adam_mel_finetune.yaml: -------------------------------------------------------------------------------- 1 | 
name: adam 2 | parameters: 3 | lr: 2e-5 4 | beta1: 0.9 5 | beta2: 0.999 6 | -------------------------------------------------------------------------------- /0_ddc/scripts/smd_4_analyze.sh: -------------------------------------------------------------------------------- 1 | source smd_0_push.sh 2 | 3 | python analyze_json.py \ 4 | ${SMDATA_DIR}/json_filt/${1}.txt ${2} 5 | -------------------------------------------------------------------------------- /1_preindex_similarity_matrix/strahv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stet-stet/goct_ismir2023/HEAD/1_preindex_similarity_matrix/strahv.png -------------------------------------------------------------------------------- /conf/experiment/ddc.yaml: -------------------------------------------------------------------------------- 1 | epochs: 12 2 | batch_size: 3 | tr: 64 4 | cv: 64 5 | ts: 64 6 | num_workers: 8 7 | checkpoint_path: ckpt.pth 8 | -------------------------------------------------------------------------------- /conf/experiment/test.yaml: -------------------------------------------------------------------------------- 1 | epochs: 12 2 | batch_size: 3 | tr: 32 4 | cv: 32 5 | ts: 1 6 | num_workers: 5 7 | checkpoint_path: ckpt.pth 8 | -------------------------------------------------------------------------------- /osu-to-ddc/README.md: -------------------------------------------------------------------------------- 1 | # osu-to-ddc 2 | 3 | turn osu files to ddc-compatible .json files, with clearly denoted beat numbers and timings. 
4 | -------------------------------------------------------------------------------- /code/model/weights/spectrogram_extractor.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stet-stet/goct_ismir2023/HEAD/code/model/weights/spectrogram_extractor.bin -------------------------------------------------------------------------------- /code/model/weights/spectrogram_normalizer.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stet-stet/goct_ismir2023/HEAD/code/model/weights/spectrogram_normalizer.bin -------------------------------------------------------------------------------- /code/model/weights/placement_cnn_ckpt_56000.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stet-stet/goct_ismir2023/HEAD/code/model/weights/placement_cnn_ckpt_56000.bin -------------------------------------------------------------------------------- /0_ddc/scripts/smd_1_extract.sh: -------------------------------------------------------------------------------- 1 | source smd_0_push.sh 2 | 3 | python extract_json.py \ 4 | ${SMDATA_DIR}/raw/${1} \ 5 | ${SMDATA_DIR}/json_raw/${1} \ 6 | ${2} 7 | -------------------------------------------------------------------------------- /scripts/ddc_cnn.sh: -------------------------------------------------------------------------------- 1 | python train_ddc.py +tr_dataset=ddc +cv_dataset=ddc +model=ddc_cnn +optimizer=adam_clstm +loss=BCELoss +experiment=ddc +ckpt_path=ckpts/ddc_cnn/ 2 | -------------------------------------------------------------------------------- /0_ddc/scripts/all.sh: -------------------------------------------------------------------------------- 1 | for COLL in fraxtil itg 2 | do 3 | echo "Executing ${1} for ${COLL}" 4 | ${1} ${COLL} 5 | echo "--------------------------------------------" 6 | done 7 | 
-------------------------------------------------------------------------------- /scripts/ddc_clstm.sh: -------------------------------------------------------------------------------- 1 | python train_ddc.py +tr_dataset=ddc +cv_dataset=ddc +model=ddc_clstm +optimizer=adam_clstm +loss=BCELoss +experiment=ddc +ckpt_path=ckpts/ddc_clstm/ 2 | 3 | -------------------------------------------------------------------------------- /code/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader.selector import DatasetSelector 2 | from .loss.selector import LossSelector 3 | from .model.selector import ModelSelector 4 | from .optimizer.selector import OptimizerSelector -------------------------------------------------------------------------------- /scripts/ddc_eval_cnn.sh: -------------------------------------------------------------------------------- 1 | python ddc_eval.py \ 2 | +ts_dataset=ddc \ 3 | +model=ddc_cnn \ 4 | +ckpt_path=/mnt/c/Users/manym/Desktop/gorst/gorst/ckpts/ddc_cnn/2e4_and_64/ckpt_epoch_1.pth \ 5 | +experiment=ddc 6 | -------------------------------------------------------------------------------- /conf/ts_dataset/beatfine.yaml: -------------------------------------------------------------------------------- 1 | name: AudioChartTwoBeatSongDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/test.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: OSUFOLDER/test.h5 6 | -------------------------------------------------------------------------------- /scripts/melitg.sh: -------------------------------------------------------------------------------- 1 | 2 | python train_cond_centered.py \ 3 | +tr_dataset=beatfine_itg \ 4 | +cv_dataset=beatfine_itg \ 5 | +experiment=test \ 6 | +model=mel \ 7 | +loss=crossentropy \ 8 | +optimizer=adam 9 | -------------------------------------------------------------------------------- /scripts/ddc_eval_clstm.sh: 
-------------------------------------------------------------------------------- 1 | python ddc_eval.py \ 2 | +ts_dataset=ddc \ 3 | +model=ddc_clstm \ 4 | +ckpt_path="/mnt/c/Users/manym/Desktop/gorst/gorst/ckpts/ddc_clstm/0.0002_and_64/ckpt_epoch_0.pth" \ 5 | +experiment=ddc 6 | -------------------------------------------------------------------------------- /code/utils/creator_factory.py: -------------------------------------------------------------------------------- 1 | def makeClassMaker(the_class, *args, **kwargs): 2 | def makeClass(*other_args, **other_kwargs): 3 | return the_class(*other_args, *args, **kwargs, **other_kwargs) 4 | return makeClass -------------------------------------------------------------------------------- /conf/cv_dataset/ddc.yaml: -------------------------------------------------------------------------------- 1 | name: DDCDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | music_hdf5_path: OSUFOLDER/all_ddc.h5 6 | onset_hdf5_path: OSUFOLDER/all_onset.h5 7 | -------------------------------------------------------------------------------- /conf/tr_dataset/ddc.yaml: -------------------------------------------------------------------------------- 1 | name: DDCDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/train.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | music_hdf5_path: OSUFOLDER/all_ddc.h5 6 | onset_hdf5_path: OSUFOLDER/all_onset.h5 7 | -------------------------------------------------------------------------------- /0_ddc/scripts/sml_onset_0_extract.sh: -------------------------------------------------------------------------------- 1 | source sml_0_push.sh 2 | 3 | python extract_feats.py \ 4 | ${SMDATA_DIR}/json_filt/${1}.txt \ 5 | --out_dir=${SMDATA_DIR}/feats/${1}/mel80hop441 \ 6 | --nhop=441 \ 7 | --nffts=1024,2048,4096 \ 8 | --log_scale 9 | -------------------------------------------------------------------------------- /scripts/mel.sh: 
-------------------------------------------------------------------------------- 1 | 2 | python train_cond_centered.py \ 3 | +tr_dataset=beatfine_8_100 \ 4 | +cv_dataset=beatfine_8_100 \ 5 | +experiment=test \ 6 | +model=mel \ 7 | +loss=crossentropy \ 8 | +optimizer=adam \ 9 | +ckpt_path=ckpts/mel/ 10 | -------------------------------------------------------------------------------- /conf/ts_dataset/ddc.yaml: -------------------------------------------------------------------------------- 1 | name: DDCDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/test.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | music_hdf5_path: OSUFOLDER/all_ddc.h5 6 | onset_hdf5_path: OSUFOLDER/all_onset.h5 7 | overlaps: False 8 | -------------------------------------------------------------------------------- /scripts/melrandomshift.sh: -------------------------------------------------------------------------------- 1 | python train_cond_centered.py \ 2 | +tr_dataset=melrandomshift \ 3 | +cv_dataset=melrandomshift \ 4 | +experiment=test \ 5 | +model=mel \ 6 | +loss=crossentropy \ 7 | +optimizer=adam \ 8 | +ckpt_path=ckpts/mel_shift/ 9 | -------------------------------------------------------------------------------- /conf/ts_dataset/beatfine_itg.yaml: -------------------------------------------------------------------------------- 1 | name: AudioChartTwoBeatSongDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/itg/test.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/itg/test.h5 6 | 7 | -------------------------------------------------------------------------------- /conf/cv_dataset/ddc_expert_plus.yaml: -------------------------------------------------------------------------------- 1 | name: DDCDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | music_hdf5_path: OSUFOLDER/all_ddc.h5 6 | onset_hdf5_path: OSUFOLDER/all_onset.h5 7 | min_diff: 6.5 
8 | -------------------------------------------------------------------------------- /conf/ts_dataset/beatfine_itg_timingonly.yaml: -------------------------------------------------------------------------------- 1 | name: AudioTimingTwoBeatSongDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/itg/test.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/itg/test.h5 6 | 7 | -------------------------------------------------------------------------------- /conf/cv_dataset/ddc_easy.yaml: -------------------------------------------------------------------------------- 1 | name: DDCDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | music_hdf5_path: OSUFOLDER/all_ddc.h5 6 | onset_hdf5_path: OSUFOLDER/all_onset.h5 7 | min_diff: 0.0 8 | max_diff: 1.99 9 | -------------------------------------------------------------------------------- /conf/cv_dataset/ddc_hard.yaml: -------------------------------------------------------------------------------- 1 | name: DDCDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | music_hdf5_path: OSUFOLDER/all_ddc.h5 6 | onset_hdf5_path: OSUFOLDER/all_onset.h5 7 | min_diff: 2.7 8 | max_diff: 3.99 9 | -------------------------------------------------------------------------------- /conf/cv_dataset/ddc_expert.yaml: -------------------------------------------------------------------------------- 1 | name: DDCDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | music_hdf5_path: OSUFOLDER/all_ddc.h5 6 | onset_hdf5_path: OSUFOLDER/all_onset.h5 7 | min_diff: 5.3 8 | max_diff: 6.49 9 | -------------------------------------------------------------------------------- /conf/cv_dataset/ddc_insane.yaml: -------------------------------------------------------------------------------- 1 | name: DDCDataset 
2 | parameters: 3 | split_json_path: OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | music_hdf5_path: OSUFOLDER/all_ddc.h5 6 | onset_hdf5_path: OSUFOLDER/all_onset.h5 7 | min_diff: 4.0 8 | max_diff: 5.29 9 | -------------------------------------------------------------------------------- /conf/cv_dataset/ddc_normal.yaml: -------------------------------------------------------------------------------- 1 | name: DDCDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | music_hdf5_path: OSUFOLDER/all_ddc.h5 6 | onset_hdf5_path: OSUFOLDER/all_onset.h5 7 | min_diff: 2.0 8 | max_diff: 2.69 9 | -------------------------------------------------------------------------------- /conf/ts_dataset/beatfine_fraxtil_timingonly.yaml: -------------------------------------------------------------------------------- 1 | name: AudioTimingTwoBeatSongDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/fraxtil/test.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/fraxtil/test.h5 6 | 7 | -------------------------------------------------------------------------------- /scripts/mel_timing.sh: -------------------------------------------------------------------------------- 1 | 2 | python train_cond_centered.py \ 3 | +tr_dataset=beatfine_timingonly \ 4 | +cv_dataset=beatfine_timingonly \ 5 | +experiment=test \ 6 | +model=mel \ 7 | +loss=crossentropy \ 8 | +optimizer=adam \ 9 | +ckpt_path=ckpts/mel_timingonly/ 10 | -------------------------------------------------------------------------------- /osu-to-ddc/osutoddc/converter/parser_exceptions.py: -------------------------------------------------------------------------------- 1 | class NotManiaFileException(Exception): 2 | pass 3 | 4 | class Not4KException(Exception): 5 | pass 6 | 7 | class MissingAttributeException(Exception): 8 | pass 9 | 10 | class NotV14Exception(Exception): 11 | pass 
-------------------------------------------------------------------------------- /scripts/melitg_timing.sh: -------------------------------------------------------------------------------- 1 | 2 | python train_cond_centered.py \ 3 | +tr_dataset=beatfine_itg_timingonly \ 4 | +cv_dataset=beatfine_itg_timingonly \ 5 | +experiment=test \ 6 | +model=mel \ 7 | +loss=crossentropy \ 8 | +optimizer=adam \ 9 | +ckpt_path=ckpts/itg_mel_timingonly 10 | -------------------------------------------------------------------------------- /0_ddc/scripts/sml_sym_1_chart.sh: -------------------------------------------------------------------------------- 1 | source sml_0_push.sh 2 | 3 | python create_charts.py \ 4 | ${SMDATA_DIR}/json_filt/${1}_train.txt \ 5 | ${SMDATA_DIR}/json_filt/${1}_valid.txt \ 6 | ${SMDATA_DIR}/json_filt/${1}_test.txt \ 7 | --out_dir=${SMDATA_DIR}/chart_sym/${1}/symbolic \ 8 | --chart_type=sym 9 | -------------------------------------------------------------------------------- /code/model/paths.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | LIB_DIR = pathlib.Path(__file__).resolve().parent 4 | _REPO_DIR = LIB_DIR.parent if LIB_DIR.parent.name == "ddc_onset" else None 5 | 6 | WEIGHTS_DIR = pathlib.Path(LIB_DIR, "weights") 7 | TEST_DATA_DIR = None if _REPO_DIR is None else pathlib.Path(_REPO_DIR, "test") 8 | -------------------------------------------------------------------------------- /scripts/melfraxtil_timing.sh: -------------------------------------------------------------------------------- 1 | 2 | python train_cond_centered.py \ 3 | +tr_dataset=beatfine_fraxtil_timingonly \ 4 | +cv_dataset=beatfine_fraxtil_timingonly \ 5 | +experiment=test \ 6 | +model=mel \ 7 | +loss=crossentropy \ 8 | +optimizer=adam \ 9 | +ckpt_path=ckpts/fraxtil_mel_timingonly/ 10 | -------------------------------------------------------------------------------- /conf/cv_dataset/beatfine_8_100.yaml: 
-------------------------------------------------------------------------------- 1 | name: AudioChartFineBeatDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: OSUFOLDER/valid.h5 6 | token_length: 200 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 92 -------------------------------------------------------------------------------- /conf/tr_dataset/beatfine_8_100.yaml: -------------------------------------------------------------------------------- 1 | name: AudioChartFineBeatDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/train.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: OSUFOLDER/train.h5 6 | token_length: 200 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 92 -------------------------------------------------------------------------------- /conf/cv_dataset/melrandomshift.yaml: -------------------------------------------------------------------------------- 1 | name: AudioRandomShiftDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: OSUFOLDER/valid.h5 6 | token_length: 200 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 92 11 | -------------------------------------------------------------------------------- /conf/model/finediffdecoderprelnbipast.yaml: -------------------------------------------------------------------------------- 1 | name: GorstFineDiffInDecoderInput 2 | parameters: 3 | encoder_hidden_size: 256 4 | decoder_hidden_size: 256 5 | encoder_layers: 3 6 | decoder_layers: 3 7 | decoder_max_length: 200 8 | condition_dim: 48 9 | norm_first: True 10 | bidirectional_past_context: True -------------------------------------------------------------------------------- /conf/tr_dataset/melrandomshift.yaml: -------------------------------------------------------------------------------- 1 | name: 
AudioRandomShiftDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/train.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: OSUFOLDER/train.h5 6 | token_length: 200 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 92 11 | -------------------------------------------------------------------------------- /scripts/thres_tune_cnn.sh: -------------------------------------------------------------------------------- 1 | if [ $# -eq 0 ]; then 2 | echo "no args" 3 | exit 1 4 | fi 5 | diff=$1 6 | python tune_thresholds.py \ 7 | +cv_dataset=ddc_${diff} \ 8 | +model=ddc_cnn \ 9 | +ckpt_path=/mnt/c/Users/manym/Desktop/gorst/gorst/ckpts/ddc_cnn/2e4_and_64_try2/ckpt_epoch_1.pth \ 10 | +experiment=ddc 11 | -------------------------------------------------------------------------------- /scripts/thres_tune_clstm.sh: -------------------------------------------------------------------------------- 1 | if [ $# -eq 0 ]; then 2 | echo "no args" 3 | exit 1 4 | fi 5 | diff=$1 6 | python tune_thresholds.py \ 7 | +cv_dataset=ddc_${diff} \ 8 | +model=ddc_clstm \ 9 | +ckpt_path="/mnt/c/Users/manym/Desktop/gorst/gorst/ckpts/ddc_clstm/0.0002_and_64/ckpt_epoch_0.pth" \ 10 | +experiment=ddc 11 | -------------------------------------------------------------------------------- /0_ddc/ddc_to_gorst/run.sh: -------------------------------------------------------------------------------- 1 | DIR=STEPMANIAFOLDER/json_filt 2 | 3 | python main.py clean ${DIR} 4 | python main.py do ${DIR} 5 | python make_split.py do ${DIR}/itg 6 | python h5pyize_dataset.py ${DIR}/itg/test.json 7 | python h5pyize_dataset.py ${DIR}/itg/valid.json 8 | python h5pyize_dataset.py ${DIR}/itg/train.json 9 | 10 | -------------------------------------------------------------------------------- /conf/cv_dataset/beatfine_timingonly.yaml: -------------------------------------------------------------------------------- 1 | name: AudioTimingFineBeatDataset 2 | parameters: 3 | split_json_path: 
OSUFOLDER/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: OSUFOLDER/valid.h5 6 | token_length: 100 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 46 11 | -------------------------------------------------------------------------------- /conf/model/mel.yaml: -------------------------------------------------------------------------------- 1 | name: GorstFineDiffInDecoderInput 2 | parameters: 3 | encoder_hidden_size: 256 4 | decoder_hidden_size: 256 5 | encoder_layers: 3 6 | decoder_layers: 3 7 | decoder_max_length: 200 8 | condition_dim: 48 9 | norm_first: True 10 | initialize_method: "trunc_normal" 11 | initialize_std: 0.1 12 | -------------------------------------------------------------------------------- /conf/tr_dataset/beatfine_timingonly.yaml: -------------------------------------------------------------------------------- 1 | name: AudioTimingFineBeatDataset 2 | parameters: 3 | split_json_path: OSUFOLDER/train.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: OSUFOLDER/train.h5 6 | token_length: 100 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 46 11 | -------------------------------------------------------------------------------- /conf/model/finediffdecoderprelnbipasttiming.yaml: -------------------------------------------------------------------------------- 1 | name: GorstFineDiffInDecoderInputTiming 2 | parameters: 3 | encoder_hidden_size: 256 4 | decoder_hidden_size: 256 5 | encoder_layers: 3 6 | decoder_layers: 3 7 | decoder_max_length: 200 8 | condition_dim: 48 9 | norm_first: True 10 | bidirectional_past_context: True 11 | -------------------------------------------------------------------------------- /conf/cv_dataset/beatfine_itg.yaml: -------------------------------------------------------------------------------- 1 | name: AudioChartFineBeatDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/itg/valid.json 4 | #split_hdf5_path: 
/home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/itg/valid.h5 6 | token_length: 200 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 92 11 | -------------------------------------------------------------------------------- /conf/tr_dataset/beatfine_itg.yaml: -------------------------------------------------------------------------------- 1 | name: AudioChartFineBeatDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/itg/train.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/itg/train.h5 6 | token_length: 200 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 92 11 | -------------------------------------------------------------------------------- /conf/cv_dataset/beatfine_itg_timingonly.yaml: -------------------------------------------------------------------------------- 1 | name: AudioTimingFineBeatDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/itg/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/itg/valid.h5 6 | token_length: 100 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 46 11 | -------------------------------------------------------------------------------- /conf/cv_dataset/beatfine_fraxtil.yaml: -------------------------------------------------------------------------------- 1 | name: AudioChartFineBeatDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/fraxtil/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/fraxtil/valid.h5 6 | token_length: 200 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 92 11 | -------------------------------------------------------------------------------- /conf/tr_dataset/beatfine_fraxtil.yaml: -------------------------------------------------------------------------------- 1 | name: AudioChartFineBeatDataset 
2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/fraxtil/train.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/fraxtil/train.h5 6 | token_length: 200 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 92 11 | -------------------------------------------------------------------------------- /conf/tr_dataset/beatfine_itg_timingonly.yaml: -------------------------------------------------------------------------------- 1 | name: AudioTimingFineBeatDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/itg/train.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/itg/train.h5 6 | token_length: 100 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 46 11 | -------------------------------------------------------------------------------- /conf/cv_dataset/beatfine_fraxtil_timingonly.yaml: -------------------------------------------------------------------------------- 1 | name: AudioTimingFineBeatDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/fraxtil/valid.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/fraxtil/valid.h5 6 | token_length: 100 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 46 11 | -------------------------------------------------------------------------------- /conf/tr_dataset/beatfine_fraxtil_timingonly.yaml: -------------------------------------------------------------------------------- 1 | name: AudioTimingFineBeatDataset 2 | parameters: 3 | split_json_path: STEPMANIAFOLDER/json_filt/fraxtil/train.json 4 | #split_hdf5_path: /home/stetstet/valid.h5 5 | split_hdf5_path: STEPMANIAFOLDER/json_filt/fraxtil/train.h5 6 | token_length: 100 7 | encoder: twotwo 8 | center: True 9 | logspec: True 10 | truncate: 46 11 | -------------------------------------------------------------------------------- 
/scripts/melitg_timing_finetune.sh: -------------------------------------------------------------------------------- 1 | 2 | python train_cond_centered.py \ 3 | +tr_dataset=beatfine_itg_timingonly \ 4 | +cv_dataset=beatfine_itg_timingonly \ 5 | +experiment=test \ 6 | +model=mel \ 7 | +loss=crossentropy \ 8 | +optimizer=adam_mel_finetune \ 9 | +ckpt_path=ckpts/itg_mel_timingonly_finetune \ 10 | +load_from=ckpts/mel_timingonly/0.0002/ckpt_epoch_9.pth 11 | -------------------------------------------------------------------------------- /0_ddc/scripts/sml_onset_1_chart.sh: -------------------------------------------------------------------------------- 1 | source sml_0_push.sh 2 | 3 | python create_charts.py \ 4 | ${SMDATA_DIR}/json_filt/${1}_train.txt \ 5 | ${SMDATA_DIR}/json_filt/${1}_valid.txt \ 6 | ${SMDATA_DIR}/json_filt/${1}_test.txt \ 7 | --out_dir=${SMDATA_DIR}/chart_onset/${1}/mel80hop441 \ 8 | --chart_type=onset \ 9 | --frame_rate=44100,441 \ 10 | --feats_dir=${SMDATA_DIR}/feats/${1}/mel80hop441 11 | -------------------------------------------------------------------------------- /scripts/melfraxtil_timing_finetune.sh: -------------------------------------------------------------------------------- 1 | 2 | python train_cond_centered.py \ 3 | +tr_dataset=beatfine_fraxtil_timingonly \ 4 | +cv_dataset=beatfine_fraxtil_timingonly \ 5 | +experiment=test \ 6 | +model=mel \ 7 | +loss=crossentropy \ 8 | +optimizer=adam_mel_finetune \ 9 | +ckpt_path=ckpts/fraxtil_mel_timingonly_finetune \ 10 | +load_from=ckpts/mel_timingonly/0.0002/ckpt_epoch_9.pth 11 | -------------------------------------------------------------------------------- /0_ddc/scripts/sml_sym_1_chart_audio.sh: -------------------------------------------------------------------------------- 1 | source sml_0_push.sh 2 | 3 | python create_charts.py \ 4 | ${SMDATA_DIR}/json_filt${3}/${1}_train.txt \ 5 | ${SMDATA_DIR}/json_filt${3}/${1}_valid.txt \ 6 | ${SMDATA_DIR}/json_filt${3}/${1}_test.txt \ 7 | 
--out_dir=${SMDATA_DIR}/chart_sym/${1}/${2}${3} \ 8 | --chart_type=sym \ 9 | --frame_rate=44100,441 \ 10 | --feats_dir=${SMDATA_DIR}/feats/${1}/${2} 11 | -------------------------------------------------------------------------------- /0_ddc/scripts/sml_sym_2_mark.sh: -------------------------------------------------------------------------------- 1 | source sml_0_push.sh 2 | 3 | TRAIN_DIR=/tmp/ngram 4 | rm -rf ${TRAIN_DIR} 5 | mkdir -p ${TRAIN_DIR} 6 | 7 | python ngram.py \ 8 | ${SM_DIR}/data/json_filt/${1}_train.txt \ 9 | ${TRAIN_DIR}/model_${2}.pkl \ 10 | --k=${2} \ 11 | --task=train 12 | 13 | python ngram.py \ 14 | ${SM_DIR}/data/json_filt/${1}_test.txt \ 15 | ${TRAIN_DIR}/model_${2}.pkl \ 16 | --k=${2} \ 17 | --task=eval 18 | -------------------------------------------------------------------------------- /0_ddc/scripts/smd_all.sh: -------------------------------------------------------------------------------- 1 | source var.sh 2 | 3 | rm -rf ${SMDATA_DIR}/json_* 4 | mkdir ${SMDATA_DIR}/json_raw 5 | mkdir ${SMDATA_DIR}/json_filt 6 | 7 | for COLL in fraxtil itg 8 | do 9 | ./smd_1_extract.sh ${COLL} --itg 10 | done 11 | 12 | for COLL in speirs sudzi 13 | do 14 | ./smd_1_extract.sh ${COLL} --itg 15 | done 16 | 17 | for COLL in fraxtil itg speirs sudzi 18 | do 19 | ./smd_2_filter.sh ${COLL} 20 | ./smd_3_dataset.sh ${COLL} filt 21 | done 22 | -------------------------------------------------------------------------------- /0_ddc/scripts/smd_2_filter.sh: -------------------------------------------------------------------------------- 1 | source smd_0_push.sh 2 | 3 | python filter_json.py \ 4 | ${SMDATA_DIR}/json_raw/${1} \ 5 | ${SMDATA_DIR}/json_filt${2}/${1} \ 6 | --chart_types=dance-single \ 7 | --chart_difficulties=Beginner,Easy,Medium,Hard,Challenge \ 8 | --min_chart_feet=1 \ 9 | --max_chart_feet=-1 \ 10 | --substitutions=M,0,4,2 \ 11 | --arrow_types=1,2,3 \ 12 | --max_jump_size=-1 \ 13 | --remove_zeros \ 14 | --permutations=0123 #0123, 3120, 0213, 3210 in 
def LossSelector(args):
    """Build a deferred constructor for the configured loss function.

    :params: args: the loss section of the config (args.loss); must carry
        ``name`` and ``parameters`` fields.
    :returns: a zero-arg factory (via makeClassMaker) producing the torch loss.
    :raises NotImplementedError: if ``name`` does not match a known loss.
    """
    print("loss\t:",args)
    b = OmegaConf.to_container(args)
    if args.name == "CrossEntropyLoss":
        return makeClassMaker(nn.CrossEntropyLoss, **b['parameters'])
    if args.name == "BCELoss":
        return makeClassMaker(nn.BCELoss, **b['parameters'])
    # Fail loudly on an unknown name instead of silently returning None
    # (consistent with ModelSelector, which raises NotImplementedError).
    raise NotImplementedError("Invalid Loss")
def for_one_file(inthisfile,replacethis,withthis):
    """Replace every occurrence of `replacethis` with `withthis` in one file, in place."""
    with open(inthisfile) as f:
        contents = f.read()
    contents = contents.replace(replacethis,withthis)
    with open(inthisfile,'w') as f:
        f.write(contents)

def iterate_folders(replacethis,withthis,root='.'):
    """Apply the replacement to every .json/.sh/.yaml/.py file under `root`.

    `root` defaults to the current directory, preserving the original behavior.
    """
    # endswith accepts a tuple of suffixes: one check instead of an `or` chain.
    extensions = (".json", ".sh", ".yaml", ".py")
    for dirpath, _dirs, files in os.walk(root):
        for filename in files:
            if filename.endswith(extensions):
                for_one_file(os.path.join(dirpath, filename), replacethis, withthis)

if __name__=="__main__":
    # Validate arguments up front; the original raised IndexError when they
    # were missing, which reads as a crash rather than a usage error.
    if len(sys.argv) != 3:
        print("usage: python text_replacer.py <replace-this> <with-this>")
        sys.exit(1)
    iterate_folders(sys.argv[1],sys.argv[2])
"""
cleans up after a cache_similar_beat_index.py run.

given a folder of folders, deletes folders that do not have equal numbers of
".osu.json" files and ".osu.json.beat.json" files.
"""
import os
import sys
import shutil

def one_folder(dir):
    """Delete `dir` if its ".osu.json" and ".osu.json.beat.json" counts differ."""
    entries = os.listdir(dir)
    # The suffixes are disjoint: "*.osu.json.beat.json" does not end with ".osu.json".
    n_charts = len([e for e in entries if e.endswith(".osu.json")])
    n_beats = len([e for e in entries if e.endswith(".osu.json.beat.json")])
    if n_charts != n_beats:
        print(f"delete {dir}!")
        shutil.rmtree(dir)


def do(dir_of_dirs):
    """Run one_folder on every subdirectory of `dir_of_dirs`.

    Bug fix: the original filtered entries with a bare `os.path.join(...)`,
    which is always truthy, so plain files passed through and crashed
    os.listdir. Filter with os.path.isdir instead.
    """
    dirs = [os.path.join(dir_of_dirs, e) for e in os.listdir(dir_of_dirs)
            if os.path.isdir(os.path.join(dir_of_dirs, e))]
    for dir in dirs:
        one_folder(dir)

if __name__=="__main__":
    do(sys.argv[1])
OptimizerSelector(args): 6 | """ 7 | :params: args: actually args.optimizer 8 | """ 9 | b = OmegaConf.to_container(args) 10 | if args.name == "adam" or args.name=="Adam": 11 | return makeClassMaker(torch.optim.Adam, 12 | lr=b['parameters']['lr'], 13 | betas=(b['parameters']['beta1'],b['parameters']['beta2']) ) 14 | if args.name == "AdamW" or args.name=="adamw": 15 | return makeClassMaker(torch.optim.AdamW, 16 | lr=b['parameters']['lr'], 17 | betas=(b['parameters']['beta1'],b['parameters']['beta2']), 18 | weight_decay=b['parameters']['weight_decay']) 19 | if args.name == "sgd": 20 | return makeClassMaker(torch.optim.SGD, lr=b['parameters']['lr']) 21 | -------------------------------------------------------------------------------- /0_ddc/scripts/sml_onset_4_export.sh: -------------------------------------------------------------------------------- 1 | source sml_0_push.sh 2 | 3 | SM_DIR=/work/03860/cdonahue/maverick/sm 4 | 5 | EXP_DIR=${SM_DIR}/trained/onset/17_02_05_00_fraxnew_cnn_export 6 | 7 | python onset_train.py \ 8 | --train_txt_fp=${SM_DIR}/data/chart_onset/fraxtil/mel80nfft3/fraxtil_train.txt \ 9 | --valid_txt_fp=${SM_DIR}/data/chart_onset/fraxtil/mel80nfft3/fraxtil_valid.txt \ 10 | --z_score \ 11 | --test_txt_fp=${SM_DIR}/data/chart_onset/fraxtil/mel80nfft3/fraxtil_test.txt \ 12 | --model_ckpt=${EXP_DIR}/onset_net_early_stop_auprc-312000 \ 13 | --export_feat_name=cnn_1 \ 14 | --audio_context_radius=7 \ 15 | --audio_nbands=80 \ 16 | --audio_nchannels=3 \ 17 | --audio_select_channels=0,1,2 \ 18 | --cnn_filter_shapes=7,3,10,3,3,20 \ 19 | --cnn_pool=1,3,1,3 \ 20 | --rnn_cell_type=lstm \ 21 | --rnn_size=200 \ 22 | --rnn_nlayers=0 \ 23 | --rnn_nunroll=1 \ 24 | --rnn_keep_prob=0.5 \ 25 | --dnn_sizes=256,128 \ 26 | --dnn_keep_prob=0.5 \ 27 | --batch_size=256 \ 28 | --experiment_dir=${EXP_DIR} 29 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jayeon Yi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
def ModelSelector(args):
    """
    :params: args: actually either args.tr_loader, args.cv_loader args.ts_loaser
    """
    print("model\t:",args)
    cfg = OmegaConf.to_container(args)
    # Dispatch table: config name -> model class.
    registry = {
        "GorstFineDiffInDecoderInputTiming": GorstFineDiffInDecoderInputTiming,
        "GorstFineDiffInDecoderInput": GorstFineDiffInDecoderInput,  # keep
        "DDCCNN": PlacementCNN,    # keep
        "DDCCLSTM": PlacementCLSTM,  # keep
    }
    model_cls = registry.get(args.name)
    if model_cls is None:
        raise NotImplementedError("Invalid Model")
    return makeClassMaker(model_cls, **cfg['parameters'])
class Osu4kCLSTMCounter(OneSeventyEightCounter):
    """Confusion-count accumulator for thresholded binary onset predictions.

    Compares a probability map against a fixed decision threshold and adds
    to the inherited true_positive / false_positive / false_negative tallies.
    """

    def __init__(self, threshold):
        super().__init__()
        # probabilities strictly above this value count as positive predictions
        self.threshold = threshold

    def update(self, pred, ref):
        """Accumulate confusion counts for one batch.

        pred: probability values between 0 and 1.
        ref: ground-truth array of 1s and 0s.
        Both must be of identical shape.
        """
        predicted = pred > self.threshold
        actual = ref == 1
        self.true_positive += float((actual & predicted).sum())
        self.false_positive += float((predicted & ~actual).sum())
        self.false_negative += float((actual & ~predicted).sum())


def test():
    """Smoke-test precision/recall/f1 on a tiny hand-made batch."""
    counter = Osu4kCLSTMCounter(0.5)
    counter.update(torch.tensor([[0, 0, 0, 1], [1, 1, 1, 1]]),
                   torch.tensor([[0, 0, 0, 1], [1, 0, 1, 1]]))
    print(counter.precision())
    print(counter.recall())
    print(counter.f1())


if __name__=="__main__":
    test()
class SixtyTwoVocabulary:
    """62-token vocabulary: 48 beat-subdivision positions, a beat marker,
    per-lane tap/hold/release tokens for 4 lanes, and a bar marker."""
    def __init__(self):
        self.beat_token = 48
        self.time_token = dict(enumerate(range(48)))  # ids 0-47 map to themselves
        self.tap_token = {lane: 48 + lane for lane in (1, 2, 3, 4)}      # 49-52
        self.hold_token = {lane: 52 + lane for lane in (1, 2, 3, 4)}     # 53-56
        self.release_token = {lane: 56 + lane for lane in (1, 2, 3, 4)}  # 57-60
        self.bar_token = 61

def make_possible_combinations():
    """List the 80 non-empty 4-lane states, in base-3 counting order.

    Each character is one lane's state (0/1/2); the all-zero "0000" state
    is excluded, leaving 3**4 - 1 = 80 combinations.
    """
    combos = []
    for n in range(1, 81):
        a, rest = divmod(n, 27)
        b, rest = divmod(rest, 9)
        c, d = divmod(rest, 3)
        combos.append(f"{a}{b}{c}{d}")
    return combos

class OneThirtyVocabulary:
    """130-token vocabulary: 48 subdivisions + beat + 80 chart states + bar."""
    def __init__(self):
        self.beat_token = 48
        self.time_token = dict(enumerate(range(48)))
        self.chart_token = {combo: 49 + i for i, combo in enumerate(make_possible_combinations())}  # 80 states: ids 49-128
        self.bar_token = 129


class OneSeventyEightVocabulary:
    """178-token vocabulary: 96 subdivisions + beat + 80 chart states + bar."""
    def __init__(self):
        self.beat_token = 96
        self.time_token = dict(enumerate(range(96)))
        self.chart_token = {combo: 97 + i for i, combo in enumerate(make_possible_combinations())}  # 80 states: ids 97-176
        self.bar_token = 177
ret.append(combination) 20 | return ret 21 | 22 | class OneThirtyVocabulary: 23 | def __init__(self): 24 | self.beat_token = 48 25 | self.time_token = {i:i for i in range(48)} 26 | self.chart_token = {c:49+n for n,c in enumerate(make_possible_combinations())} # 81 total 27 | self.bar_token = 129 28 | 29 | 30 | class OneSeventyEightVocabulary: 31 | def __init__(self): 32 | self.beat_token = 96 33 | self.time_token = {i:i for i in range(96)} 34 | self.chart_token = {c:97+n for n,c in enumerate(make_possible_combinations())} # 81 total 35 | self.bar_token = 177 -------------------------------------------------------------------------------- /0_ddc/scripts/sml_onset_2_train.sh: -------------------------------------------------------------------------------- 1 | source sml_0_push.sh 2 | 3 | EXP_DIR=/tmp/train 4 | rm -rf ${EXP_DIR} 5 | mkdir -p ${EXP_DIR} 6 | 7 | python onset_train.py \ 8 | --train_txt_fp=${SM_DIR}/data/chart_onset/${1}/mel80hop441/${1}_train.txt \ 9 | --valid_txt_fp=${SM_DIR}/data/chart_onset/${1}/mel80hop441/${1}_valid.txt \ 10 | --z_score \ 11 | --audio_context_radius=7 \ 12 | --audio_nbands=80 \ 13 | --audio_nchannels=3 \ 14 | --audio_select_channels=0,1,2 \ 15 | --feat_diff_coarse_to_id_fp=${SM_DIR}/labels/${1}/diff_coarse_to_id.txt \ 16 | --cnn_filter_shapes=7,3,10,3,3,20 \ 17 | --cnn_pool=1,3,1,3 \ 18 | --rnn_cell_type=lstm \ 19 | --rnn_size=200 \ 20 | --rnn_nlayers=0 \ 21 | --rnn_nunroll=1 \ 22 | --rnn_keep_prob=0.5 \ 23 | --dnn_nonlin=sigmoid \ 24 | --dnn_sizes=256,128 \ 25 | --dnn_keep_prob=0.5 \ 26 | --batch_size=256 \ 27 | --weight_strategy=rect \ 28 | --nobalanced_class \ 29 | --exclude_onset_neighbors=2 \ 30 | --exclude_pre_onsets \ 31 | --exclude_post_onsets \ 32 | --grad_clip=5.0 \ 33 | --opt=sgd \ 34 | --lr=0.1 \ 35 | --lr_decay_rate=1.0 \ 36 | --lr_decay_delay=0 \ 37 | --nbatches_per_ckpt=4000 \ 38 | --nbatches_per_eval=4000 \ 39 | --nepochs=128 \ 40 | --experiment_dir=${EXP_DIR} \ 41 | --eval_window_type=hamming \ 42 | 
def DatasetSelector(args):
    """Build a deferred constructor for the configured dataset.

    :params: args: actually either args.tr_loader, args.cv_loader args.ts_loaser;
        must carry ``name`` and ``parameters`` fields.
    :returns: a zero-arg factory (via makeClassMaker) producing the dataset.
    :raises NotImplementedError: if ``name`` does not match a known dataset.
    """
    print("dataset\t:",args)
    b = OmegaConf.to_container(args)
    if args.name == "AudioChartFineBeatDataset": # **mel
        return makeClassMaker(AudioChartFineBeatDataset, **b['parameters'])
    elif args.name == "AudioTimingFineBeatDataset":
        return makeClassMaker(AudioTimingFineBeatDataset, **b['parameters']) # timingonly
    elif args.name == "AudioRandomShiftDataset": # mel, random shift
        return makeClassMaker(AudioRandomShiftDataset, **b['parameters'])
    elif args.name == "AudioChartTwoBeatSongDataset": # **mel, eval
        return makeClassMaker(AudioChartTwoBeatSongDataset, **b['parameters'])
    elif args.name == "AudioTimingTwoBeatSongDataset": # fraxtil, itg timingonly, eval
        return makeClassMaker(AudioTimingTwoBeatSongDataset, **b['parameters'])
    elif args.name == "DDCDataset":
        return makeClassMaker(DDCDataset, **b['parameters'])
    # Fail loudly on an unknown name instead of silently returning None
    # (consistent with ModelSelector, which raises NotImplementedError).
    raise NotImplementedError("Invalid Dataset")
def cleanup(biggest_folder):
    """Remove every generated .osu.json / .osu.json.beat.json below biggest_folder."""
    for root, dirs, files in os.walk(biggest_folder):
        for file in files:
            # endswith with a tuple checks both generated suffixes in one call
            if file.endswith((".osu.json", ".osu.json.beat.json")):
                os.remove(os.path.join(root,file))

def run_on_all(biggest_folder):
    """Convert every DDC-style .json below biggest_folder to .osu.json files."""
    for root, dirs, files in os.walk(biggest_folder):
        print(root)
        for file in tqdm(files):
            # NOTE(review): a stray *.beat.json would also match this filter;
            # assumed the tree is cleaned beforehand — confirm.
            if file.endswith(".json") and not file.endswith(".osu.json"):
                file_base = os.path.join(root,file[:-5])
                ddcjson_to_osujson(file_base)

def run_matrix_on_all(list_of_folder_of_folders):
    """Run the similar-beat indexer on each folder-of-folders in the list."""
    for folder_of_folders in list_of_folder_of_folders:
        do(folder_of_folders)

if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("==HELP==")
        print("Supply me with two or more arguments.")
        # fixed: help text omitted the supported "matrix" mode
        print("python main.py (clean / do / matrix) (path to json_filt)")
        sys.exit(0)

    if sys.argv[1] == "clean":
        cleanup(sys.argv[2])
    elif sys.argv[1] == "do":
        the_path = sys.argv[2]
        print("converting ddc.json to osu.jsons...")
        run_on_all(the_path)
        print("fraxtil")
        do(os.path.join(the_path,'fraxtil'))
        print("itg")
        # do(os.path.join(the_path,'itg'))  # NOTE(review): itg indexing disabled here; the 'matrix' mode covers it
    elif sys.argv[1] == "matrix":
        the_path = sys.argv[2]
        print("fraxtil")
        do(os.path.join(the_path,'fraxtil'))
        print("itg")
        do(os.path.join(the_path,'itg'))
def get_48th_index(num):
    """Map a fractional beat position to its 1/48-subdivision index (0-47)."""
    scaled = round(num * 48)
    return scaled % 48

def ddcjson_to_osujson(ddcjson_file_base):
    """Split one DDC-style .json into one .osu.json file per chart.

    Differences handled:
    - .ddc.json has multiple charts in one file; .osu.json has each chart in a separate file
    - offset sign is backwards (and .osu.json offsets are in milliseconds)
    - four beats must equal 192, so each beat[0][2] entry becomes a multiple-of-48 index
    - the extra "bar_to_beat" field of .osu.json is not needed yet, so it is skipped;
      if needed later, be sure to start from the original .sm files.
    """
    with open(f"{ddcjson_file_base}.json") as file:
        source = json.load(file)

    # Shared metadata: everything except the chart list, with the offset
    # sign flipped and converted to milliseconds.
    template = copy.deepcopy(source)
    del template['charts']
    template["offset"] = - (template["offset"] * 1000)

    for n, chart in enumerate(source['charts']):
        converted = copy.deepcopy(template)
        converted.update({"charts": [chart]})

        # Snap each note's beat fraction to its 1/48 subdivision index.
        for note in converted['charts'][0]['notes']:
            note[0][2] = get_48th_index(note[1])

        with open(f"{ddcjson_file_base}_{n}.osu.json",'w') as out:
            out.write(json.dumps(converted,indent=4))


if __name__ == "__main__":
    ddcjson_to_osujson("examples/example/example_ddc")
12 | # gather all txt files with the name "*_train.txt", "*_valid.txt", "*_test.txt" 13 | l = [e for e in os.listdir(dir_of_dirs) if e.endswith(".txt")] 14 | train, valid, test = [],[],[] 15 | trainjson, validjson, testjson = {}, {}, {} 16 | for textfile in l: 17 | with open(os.path.join(dir_of_dirs, textfile)) as file: 18 | b = [e.strip() for e in file.readlines()] 19 | if "_train.txt" in textfile: train.extend(b) 20 | elif "_valid.txt" in textfile: valid.extend(b) 21 | elif "_test.txt" in textfile: test.extend(b) 22 | 23 | with open(os.path.join(dir_of_dirs,"summary.json")) as file: 24 | summary = json.load(file) 25 | 26 | for key in summary: # this is slow and I know that, sorry 27 | to_look_for = "_".join(key.split('_')[:-1]) + ".json" 28 | for split, jsonfile in [(train, trainjson),(valid, validjson),(test,testjson)]: 29 | if to_look_for in split: 30 | jsonfile.update({key: summary[key]}) 31 | break 32 | 33 | for name, jsonfile in [("train.json",trainjson),("valid.json",validjson),("test.json",testjson)]: 34 | with open(os.path.join(dir_of_dirs, name),'w') as file: 35 | file.write(json.dumps(jsonfile,indent=4)) 36 | 37 | 38 | if __name__=="__main__": 39 | if len(sys.argv) < 3: 40 | print("==HELP==") 41 | print("Supply me with two or more arguments.") 42 | print(f"python {sys.argv[1]} (clean / do) (path to json_filt/(itg or fraxtil))") 43 | exit(0) 44 | 45 | elif sys.argv[1] == "summarize": 46 | the_path = sys.argv[2] 47 | make_summary(the_path) 48 | 49 | elif sys.argv[1] == "do": 50 | the_path = sys.argv[2] 51 | make_summary(the_path) 52 | translate_splits(the_path) 53 | -------------------------------------------------------------------------------- /0_ddc/dataset/preview_sm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | _TEMPL = """\ 5 | #TITLE:{title}; 6 | #ARTIST:{artist}; 7 | #MUSIC:{music_fp}; 8 | #OFFSET:0.0; 9 | #BPMS:0.0={bpm}; 10 | #STOPS:; 11 | {charts}\ 12 | """ 13 | 14 | 
def meta_to_sm(meta):
    """Render a DDC chart-metadata dict back into StepMania .sm text.

    NOTE(review): this module is Python 2 code (uses ``xrange``); see the
    accompanying README which pins Python 2.7 for the 0_ddc tooling.
    """
    # Fixed quantization grid: 64 timesteps per measure, one timestep per
    # 512 audio samples at 44.1 kHz.
    subdiv = 64
    dt = 512.0 / 44100.0
    # seconds per minute * timesteps per second * measures per subdivision * beats per measure
    bpm = 60 * (1.0 / dt) * (1.0 / float(subdiv)) * 4.0

    charts = []
    for chart in meta['charts']:
        ctype = chart['type']
        cversion = 'Stepnet'
        ccoarse = chart['difficulty_coarse']
        cfine = chart['difficulty_fine']
        cnotes = chart['notes']

        measures = []
        # Map quantized timestep -> note code string (e.g. '1000').
        # assumes cnotes rows unpack as (_, time_seconds, code) — TODO confirm
        timestep_to_code = {int(round(t / dt)) : code for _, t, code in cnotes}
        # Pad 15 s past the final note, then round up to a whole measure.
        max_s = cnotes[-1][1] + 15.0
        max_timestep = int(round(max_s / dt))
        if max_timestep % subdiv != 0:
            max_timestep += subdiv - (max_timestep % subdiv)

        # "No note" code with the same arrow count as the chart's notes.
        null_code = '0' * len(cnotes[0][2])
        timesteps = [timestep_to_code.get(i, null_code) for i in xrange(max_timestep)]
        measures = [timesteps[i:i+subdiv] for i in xrange(0, max_timestep, subdiv)]
        measures_txt = '\n,\n'.join(['\n'.join(measure) for measure in measures])

        # cgroove is accepted but unused: the groove row is a literal in
        # _CHART_TEMPL; str.format silently ignores extra keyword args.
        chart_txt = _CHART_TEMPL.format(
            ctype=ctype,
            cversion=cversion,
            ccoarse=ccoarse,
            cfine=cfine,
            cgroove='',
            measures=measures_txt)

        charts.append(chart_txt)

    return _TEMPL.format(
        title=meta['title'],
        artist=meta['artist'],
        music_fp=meta['music_fp'],
        bpm=bpm,
        charts='\n'.join(charts))
def append_similar_beat_index(folder):
    """For every DDC chart json in `folder`, write a companion
    ``<chart>.osu.json.beat.json`` file listing, per beat, the indices of the
    (up to) 5 most self-similar beats at or before that beat.

    Output schema: {"beat_to_sample_map": [...], "use_bars": {row_id: [idx, ...]}}
    """
    # assert folder_is_sane(folder) <- not necessarily true for DDC.
    # we make index for all charts there is.
    contents = os.listdir(folder)
    ddc_jsons = [os.path.join(folder, e) for e in contents if e.endswith('.osu.json')]
    output_filenames = [f"{e}.beat.json" for e in ddc_jsons]

    for ddc_json, output_file in zip(ddc_jsons, output_filenames):
        sim_matrix, beat_to_sample_map = make_beatwise_similarity_matrix_and_beat_to_sample_map(ddc_json)
        length = sim_matrix.shape[0]  # == sim_matrix.shape[1]
        ret = {"beat_to_sample_map": beat_to_sample_map, "use_bars": None}
        the_index = {}
        # BUG FIX: np.argpartition(a, -5) raises ValueError when the chart has
        # fewer than 5 beats; clamp k to the matrix size (identical behavior
        # for length >= 5).
        k = min(5, length)
        for row_id in range(length):
            topk_ind = [int(e) for e in np.argpartition(sim_matrix[row_id], -k)[-k:]]
            if row_id not in topk_ind:
                topk_ind.append(row_id)
            # keep only indices at or before the current beat
            topk_ind = [e for e in topk_ind if e <= row_id]
            the_index[int(row_id)] = topk_ind
        ret['use_bars'] = the_index
        with open(output_file, 'w') as file:
            file.write(json.dumps(ret, indent=4))
do(sys.argv[1]) -------------------------------------------------------------------------------- /0_ddc/README.md: -------------------------------------------------------------------------------- 1 | # Data preprocessing code from DDC 2 | 3 | ## Building Datsaet (part 1) 4 | 5 | 1. Install Python 2.7. 6 | ``` 7 | # for conda or mamba 8 | conda create -n ddc python=2.7 9 | mamba create -n ddc python=2.7 10 | ``` 11 | 2. Make a directory wherever you'd like, and change `scripts/var.sh`. The variable `SMDATA_DIR` must point to the directory you just made. 12 | 3. Under this directory, make directories `raw`, `json_raw` and `json_filt` 13 | 4. Under `data/raw` make directories `fraxtil` and `itg` 14 | 5. Under `data/raw/fraxtil`, download and unzip: 15 | - [Tsunamix III](https://fra.xtil.net/simfiles/data/tsunamix/III/Tsunamix%20III%20%5BSM5%5D.zip) 16 | - [Fraxtil's Arrow Arrangements](https://fra.xtil.net/simfiles/data/arrowarrangements/Fraxtil's%20Arrow%20Arrangements%20%5BSM5%5D.zip) 17 | - [Fraxtil's Beast Beats](https://fra.xtil.net/simfiles/data/beastbeats/Fraxtil's%20Beast%20Beats%20%5BSM5%5D.zip) 18 | 6. Under `data/raw/itg`, download and unzip: 19 | - [In the Groove](https://search.stepmaniaonline.net/link/In%20The%20Groove%201.zip) 20 | - [In the Groove 2](https://search.stepmaniaonline.net/link/In%20The%20Groove%202.zip) 21 | 7. Navigate to `scripts/` 22 | 8. Run the following. 23 | ``` 24 | ./all.sh ./smd_1_extract.sh # parsing .sm files to .json 25 | ./all.sh ./smd_2_filter.sh # filter (removing mines, etc.) 26 | ./all.sh ./smd_3_dataset.sh # split dataset 80/10/10 27 | ./smd_4_analyze.sh fraxtil # analyzes dataset 28 | ``` 29 | 30 | ## Building Dataset (part 2) 31 | 32 | The above used the code included with [DDC](https://github.com/chrisdonahue/ddc). What follows afterwards are the procedures to make the dataset compatible to our system. 33 | 34 | 1. Now switch to an environment with python 3. (Our codes have been tested in version 3.9.16.) 35 | 2. 
navigate to `ddc_to_gorst` 36 | 3. Locate the folder where you had put the data. Locate the `json_filt`, where all the filtered jsons should be placed. Let this directory be `DIR/json_filt`. Run: 37 | ``` 38 | python main.py do "DIR/json_filt" # makes dataset peripherals needed to generate h5py files. 39 | python make_split.py do "DIR/json_filt/itg" # translates the splits made by DDC into our json-like format 40 | python h5pyize_dataset.py "DIR/json_filt/itg/test.json" 41 | python h5pyize_dataset.py "DIR/json_filt/itg/valid.json" 42 | python h5pyize_dataset.py "DIR/json_filt/itg/train.json" 43 | ``` 44 | 4. And now you have the h5 files in `DIR/json_filt/itg/`! now navigate to the top, go to `conf/tr_dataset/` and modify `beatfine_itg` accordingly. 45 | 46 | -------------------------------------------------------------------------------- /2_generate_dataset/generate_dataset_peripherals.py: -------------------------------------------------------------------------------- 1 | """ 2 | given dataset folder, generates the following: 3 | - the list of all files with an ".osu.json.beat.json" extension, a.k.a. ones to be used for training 4 | - the length (in beats) of each file 5 | - the total number of existing files 6 | 7 | This information will be used to init the DataLoader. 
def get_length_in_beats(osujsonbeatjson_filename):
    """Length of a chart in beats: the rounded beat number of the final
    entry of ``beat_to_sample_map`` in a ``.osu.json.beat.json`` file."""
    with open(osujsonbeatjson_filename, 'r') as handle:
        beat_json = json.load(handle)
    last_entry = beat_json['beat_to_sample_map'][-1]
    return round(last_entry[0])
def generate_2bar_greedy(tokens):
    """
    ! invariants
    tokens: (1, length_in_beats//2 - 1, 100). includes charts from beat 2. beats [2-4, 4-6, 6-8, ....]
    """
    # NOTE(review): this reference generator only logs the token window for
    # each two-beat slice; nothing is ever appended to `generated`, so the
    # function always returns an empty list. Presumably the decoding loop
    # lives in another script — confirm before relying on the return value.
    global trunc_how_many

    generated = []
    how_many_twobeats = tokens.shape[1]

    for idx_to_generate_rn in range(0,how_many_twobeats):
        logger.info(f"{idx_to_generate_rn}\t\t{tokens[:,idx_to_generate_rn,:]}")

    # precision/recall/f1 are placeholders (always None) in this script.
    logger.info(f"precision {None} \trecall {None} \tf1 {None}")
    return generated
def calc_abs_for_beat(offset, bpms, stops, segment_lengths, beat):
    """Absolute time in seconds of `beat`, given (beat, bpm) segments,
    (beat, seconds) stops, and precomputed per-segment durations.

    Uses the same 1e-6 tolerance as the module-level _EPSILON (inlined here).
    """
    eps = 1e-6  # same tolerance as module-level _EPSILON

    # Index of the bpm segment containing `beat`: count the leading
    # segments whose start beat is <= beat (within tolerance).
    seg = -1
    for start_beat, _ in bpms:
        if beat + eps > start_beat:
            seg += 1
        else:
            break

    # Total stop time strictly before `beat`; a stop landing exactly on
    # `beat` does not count toward its own timing. Stops are iterated in
    # order and accumulation ends at the first stop at/after `beat`.
    stopped = 0.0
    for stop_beat, stop_len in stops:
        delta = beat - stop_beat
        if abs(delta) < eps or delta < 0:
            break
        stopped += stop_len

    # Seconds-per-beat of the containing segment (inlined bpm_to_spb),
    # applied to the partial distance into that segment.
    spb = 60.0 / bpms[seg][1]
    within_segment = spb * (beat - bpms[seg][0])

    return sum(segment_lengths[:seg]) + within_segment - offset + stopped
calc_note_beats_and_abs_times(offset, bpms, stops, note_data): 40 | segment_lengths = calc_segment_lengths(bpms) 41 | 42 | # copy bpms 43 | bpms = bpms[:] 44 | inc = None 45 | inc_prev = None 46 | time = offset 47 | 48 | # beat loop 49 | note_beats_abs_times = [] 50 | beat_times = [] 51 | for measure_num, measure in enumerate(note_data): 52 | ppm = len(measure) 53 | for i, code in enumerate(measure): 54 | beat = measure_num * 4.0 + 4.0 * (float(i) / ppm) 55 | # TODO: This could be much more efficient but is not the bottleneck for the moment. 56 | beat_abs = calc_abs_for_beat(offset, bpms, stops, segment_lengths, beat) 57 | note_beats_abs_times.append(((measure_num, ppm, i), beat, beat_abs, code)) 58 | beat_times.append(beat_abs) 59 | 60 | # handle negative stops 61 | beat_time_prev = float('-inf') 62 | del_idxs = [] 63 | for i, beat_time in enumerate(beat_times): 64 | if beat_time_prev > beat_time: 65 | del_idxs.append(i) 66 | else: 67 | beat_time_prev = beat_time 68 | for del_idx in sorted(del_idxs, reverse=True): 69 | del note_beats_abs_times[del_idx] 70 | del beat_times[del_idx] 71 | 72 | #TODO: remove when stable 73 | assert sorted(beat_times) == beat_times 74 | 75 | return note_beats_abs_times 76 | -------------------------------------------------------------------------------- /0_ddc/dataset/dataset_json.py: -------------------------------------------------------------------------------- 1 | if __name__ == '__main__': 2 | import argparse 3 | import os 4 | import random 5 | from util import get_subdirs 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('json_dir', type=str, help='Input JSON dir') 9 | parser.add_argument('--dataset_dir', type=str, help='If specified, use different output dir otherwise JSON dir') 10 | parser.add_argument('--rel', dest='abs', action='store_false', help='If set, output relative paths') 11 | parser.add_argument('--splits', type=str, help='CSV list of split values for datasets (e.g. 
0.8,0.1,0.1)') 12 | parser.add_argument('--splitnames', type=str, help='CSV list of split names for datasets (e.g. train,test,eval)') 13 | parser.add_argument('--shuffle', dest='shuffle', action='store_true', help='If set, shuffle dataset before split') 14 | parser.add_argument('--shuffle_seed', type=int, help='If set, use this seed for shuffling') 15 | parser.add_argument('--choose', dest='choose', action='store_true', help='If set, choose from list of packs') 16 | 17 | parser.set_defaults( 18 | dataset_dir='', 19 | abs=True, 20 | splits='1', 21 | splitnames='', 22 | shuffle=False, 23 | shuffle_seed=0, 24 | choose=False) 25 | 26 | args = parser.parse_args() 27 | 28 | splits = [float(x) for x in args.splits.split(',')] 29 | split_names = [x.strip() for x in args.splitnames.split(',')] 30 | assert len(splits) == len(split_names) 31 | 32 | out_dir = args.dataset_dir if args.dataset_dir else args.json_dir 33 | if not os.path.isdir(out_dir): 34 | os.mkdir(out_dir) 35 | 36 | pack_names = get_subdirs(args.json_dir, args.choose) 37 | 38 | for pack_name in pack_names: 39 | pack_dir = os.path.join(args.json_dir, pack_name) 40 | sub_fps = sorted(os.listdir(pack_dir)) 41 | 42 | if args.shuffle: 43 | random.seed(args.shuffle_seed) 44 | random.shuffle(sub_fps) 45 | 46 | if args.abs: 47 | sub_fps = [os.path.abspath(os.path.join(pack_dir, sub_fp)) for sub_fp in sub_fps] 48 | 49 | if len(splits) == 0: 50 | splits = [1.0] 51 | else: 52 | splits = [x / sum(splits) for x in splits] 53 | 54 | split_ints = [int(len(sub_fps) * split) for split in splits] 55 | split_ints[0] += len(sub_fps) - sum(split_ints) 56 | 57 | split_fps = [] 58 | for split_int in split_ints: 59 | split_fps.append(sub_fps[:split_int]) 60 | sub_fps = sub_fps[split_int:] 61 | 62 | for split, splitname in zip(split_fps, split_names): 63 | out_name = '{}{}.txt'.format(pack_name, '_' + splitname if splitname else '') 64 | out_fp = os.path.join(out_dir, out_name) 65 | with open(out_fp, 'w') as f: 66 | 
f.write('\n'.join(split)) 67 | -------------------------------------------------------------------------------- /0_ddc/dataset/preview_wav.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from scipy.signal import fftconvolve 5 | from scipy.io.wavfile import write as wavwrite 6 | 7 | def _wav_write(wav_fp, fs, wav_f, normalize=False): 8 | if normalize: 9 | wav_f_max = wav_f.max() 10 | if wav_f_max != 0.0: 11 | wav_f /= wav_f.max() 12 | wav_f = np.clip(wav_f, -1.0, 1.0) 13 | wav = (wav_f * 32767.0).astype(np.int16) 14 | wavwrite(wav_fp, fs, wav) 15 | 16 | # (length, val) pairs 17 | def _linterp(val_start, pts, env_len): 18 | pt_lens = [pt[0] for pt in pts] 19 | pt_vals = [pt[1] for pt in pts] 20 | pt_lens = [int(env_len * (pt_len / sum(pt_lens))) for pt_len in pt_lens] 21 | pt_lens[-1] -= sum(pt_lens) - env_len 22 | env = [] 23 | val_curr = val_start 24 | for pt_len, pt_val in zip(pt_lens, pt_vals): 25 | env.append(np.linspace(val_curr, pt_val, pt_len, endpoint=False)) 26 | val_curr = pt_val 27 | return np.concatenate(env) 28 | 29 | def write_preview_wav(wav_fp, note_beats_and_abs_times, wav_fs=11025.0): 30 | wav_len = int(wav_fs * (note_beats_and_abs_times[-1][1] + 0.05)) 31 | dt = 1.0 / wav_fs 32 | 33 | note_type_to_idx = {} 34 | idx = 0 35 | for _, beat, time, note_type in note_beats_and_abs_times: 36 | if note_type == '0' * len(note_type): 37 | continue 38 | if note_type not in note_type_to_idx: 39 | note_type_to_idx[note_type] = idx 40 | idx += 1 41 | num_note_types = len(note_type_to_idx) 42 | 43 | pulse_f = np.zeros((num_note_types, wav_len)) 44 | 45 | for _, beat, time, note_type in note_beats_and_abs_times: 46 | sample = int(time * wav_fs) 47 | if sample > 0 and sample < wav_len and note_type in note_type_to_idx: 48 | pulse_f[note_type_to_idx[note_type]][sample] = 1.0 49 | 50 | scale = [440.0, 587.33, 659.25, 783.99] 51 | freqs = [scale[i % 4] * math.pow(2.0, (i // 4) - 1) for 
def make_new_file(filename, exist_ok=True):
    """Create (or truncate) an empty file at `filename`.

    Raises FileExistsError when `exist_ok` is False and a regular file is
    already present. Note that with exist_ok=True an existing file is
    truncated, matching the original behavior.
    """
    if not exist_ok and os.path.isfile(filename):
        raise FileExistsError()
    open(filename, 'w').close()
def one_folder(input, output, keep_osu=True):
    """Convert every .osu file directly inside `input` to a ddc-style json
    written into `output` as ``<name>.osu.json``.

    If `input` and `output` differ, the whole folder is copied first so the
    audio and other assets travel with the charts. With keep_osu=False the
    source .osu files are deleted after conversion.
    """
    if not os.path.isdir(input):
        raise ValueError("Not a valid directory")
    os.makedirs(output, exist_ok=True)

    if input != output:
        shutil.copytree(input, output, dirs_exist_ok=True)

    osu_files = [os.path.join(output, filename) for filename in os.listdir(output)
                 if filename.endswith(".osu")]
    for filename in osu_files:
        d = osu_to_ddc(filename)
        d = json.dumps(d, indent=4)
        # BUG FIX: the output path was a literal placeholder string, so every
        # chart overwrote a single "(unknown).json" file. Write each chart to
        # "<chart>.osu.json" next to its source — the ".osu.json" suffix is
        # what the downstream tooling globs for.
        with open(f"{filename}.json", 'w') as file:
            file.write(d)
        if not keep_osu:
            os.remove(filename)
run(args.path, args.output_path,keep_osu=args.keep_osu) -------------------------------------------------------------------------------- /2_generate_dataset/h5pyize_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | make every song into a melspectrogram with 80 bins, and then packs them into one .h5 file 3 | care is taken so that access time can be optimized. 4 | (turns out transposing the mel before saving is about 10x faster, so we're sticking with it (see below).) 5 | 6 | written by Jayeon Yi (stet-stet) 7 | """ 8 | import h5py 9 | import os 10 | import sys 11 | import scipy 12 | import json 13 | from tqdm import tqdm 14 | import numpy as np 15 | import soundfile as sf 16 | import librosa 17 | 18 | 19 | def make_mel_for_whole_file(osujson, fft_size=512, n_mels=80): 20 | beatjson = osujson + ".beat.json" 21 | osujson = beatjson.replace(".osu.json.beat.json",'.osu.json') 22 | beatmapset_path = os.path.dirname(osujson) 23 | with open(osujson) as file: 24 | osujson = json.load(file) 25 | with open(beatjson) as file: 26 | beatjson = json.load(file) 27 | beat_to_sample_map = beatjson['beat_to_sample_map'] 28 | if osujson['music_fp'].startswith("/") or osujson['music_fp'][1] == "C:": 29 | music_fp = osujson['music_fp'] 30 | else: 31 | music_fp = os.path.join(beatmapset_path, osujson['music_fp']) 32 | y, sr = sf.read(music_fp) 33 | try: 34 | y = y.mean(axis=1) 35 | except np.AxisError as e: 36 | pass # already in 1-dim 37 | # OFFSET! 
def pack_dset_into_h5py(split_json, transpose=False):
    """Compute the mel spectrogram of every chart listed in `split_json`
    and pack them all into one ``<splitname>.h5`` next to the split file.

    Dataset names are the chart paths made relative to the split file's
    folder. With transpose=True the mels are stored transposed so that the
    training-time [a:b] slices are contiguous on disk.
    """
    with open(split_json) as file:
        split = json.load(file)
    # Portable replacement for splitext(...)[0].split('/')[-1]: take the
    # basename first so this also works with Windows path separators.
    split_name = os.path.splitext(os.path.basename(split_json))[0]
    dir_of_dirs = os.path.dirname(split_json)
    h5f = h5py.File(os.path.join(dir_of_dirs, f'{split_name}.h5'), 'w')

    pbar = tqdm(list(split.keys()))
    for osujson in pbar:
        pbar.set_description(os.path.basename(osujson))
        dataset_name = osujson.replace(dir_of_dirs + '/', '')
        mel = np.float32(make_mel_for_whole_file(osujson))
        pbar.set_description(dataset_name)
        if transpose:
            # [:, a:b] are training samples; we want this segment to be consecutive
            h5f.create_dataset(dataset_name, data=mel.T)
        else:
            h5f.create_dataset(dataset_name, data=mel)
    h5f.close()
def round_to_nearest_48th(num):
    """Snap `num` to the closest multiple of 1/48 (the finest beat
    subdivision handled by this loader)."""
    scaled = round(num * 48)
    return scaled / 48
56 | ret = str(os.path.dirname(audiopath)) 57 | ret = ret.replace(str(self.biggest_dir), "") 58 | if ret[0] == '/': ret= ret[1:] 59 | return ret 60 | 61 | def _dset_name_for_onset(self, osujsonpath): 62 | ret = str(osujsonpath).replace(self.biggest_dir,'') 63 | if ret[0] == '/': 64 | ret= ret[1:] 65 | return ret 66 | 67 | def __getitem__(self,idx): 68 | """ 69 | what to return: 70 | - the music for this beats 71 | - the onsets i.e. the goals 72 | - fine diff 73 | """ 74 | # TODO randomize 75 | if idx > self.total_length: 76 | raise ValueError("idx over total length") 77 | osujson_fn, frame = self.index[idx] 78 | with open(osujson_fn) as file: 79 | osujson = json.load(file) 80 | 81 | # fine diff 82 | try: 83 | fine_difficulty = float(osujson['charts'][0]['difficulty_fine']) 84 | except Exception as e: 85 | print(osujson_fn) 86 | raise e 87 | 88 | music_dset_name = self._dset_name_for_audio(osujson_fn) # nasty, but marginally acceptable. 89 | mel = self.music_hdf5[music_dset_name][frame:frame+self.frames, :, :] # (112, 80, 3) 90 | mel_to_return = mel 91 | #mel_to_return = mel[:].transpose((2,0,1)) # (3, 80, 112) 92 | 93 | # import onsets 94 | 95 | onset_dset_name = self._dset_name_for_onset(osujson_fn) 96 | onsets = self.onset_hdf5[onset_dset_name][frame:frame+self.frames] 97 | 98 | return mel_to_return, onsets, fine_difficulty 99 | 100 | def size_test(split="train",**params): 101 | print(f"=============================== split = {split}") 102 | split_json_path = f"OSUFOLDER/{split}.json" 103 | music_hdf5_path = f"OSUFOLDER/all_ddc.h5" 104 | onset_hdf5_path = f"OSUFOLDER/all_onset.h5" 105 | 106 | dset = DDCDataset(split_json_path=split_json_path, music_hdf5_path=music_hdf5_path, onset_hdf5_path=onset_hdf5_path, **params) 107 | sz = len(dset) 108 | print(dset[0]) 109 | for i in tqdm(range(sz)): 110 | dset[i] 111 | 112 | 113 | if __name__=="__main__": 114 | size_test() -------------------------------------------------------------------------------- 
/osu-to-ddc/.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to 
include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | ### VisualStudioCode ### 177 | .vscode/* 178 | !.vscode/settings.json 179 | !.vscode/tasks.json 180 | !.vscode/launch.json 181 | !.vscode/extensions.json 182 | !.vscode/*.code-snippets 183 | 184 | # Local History for Visual Studio Code 185 | .history/ 186 | 187 | # Built Visual Studio Code Extensions 188 | *.vsix 189 | 190 | ### VisualStudioCode Patch ### 191 | # Ignore all local history of files 192 | .history 193 | .ionide 194 | 195 | # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode -------------------------------------------------------------------------------- /0_ddc/dataset/analyze_json.py: -------------------------------------------------------------------------------- 1 | if __name__ == '__main__': 2 | import argparse 3 | from collections import Counter, defaultdict 4 | import json 5 | import os 6 | from util import get_subdirs 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('dataset_fps', type=str, nargs='+', help='List of dataset filepaths to analyze') 10 | parser.add_argument('--diff', type=str, help='If provided, only analyze charts of this difficulty') 11 | 12 | parser.set_defaults( 13 | diff='') 14 | 15 | args = parser.parse_args() 16 | 17 | json_fps = [] 18 | for dataset_fp in args.dataset_fps: 19 | with open(dataset_fp, 'r') as f: 20 | json_fps += f.read().splitlines() 21 | 22 | chart_types = Counter() 23 | chart_diff_coarse = Counter() 24 | chart_feet = Counter() 25 | chart_coarse_to_stream = {} 26 | chart_freetexts = Counter() 27 | beat_phases = Counter() 28 | vocab = Counter() 29 | songs_time_annotated = 0.0 30 | charts_time_annotated = 0.0 31 | stream_total = 0.0 32 | feet_total = 0.0 33 | arrows_total = 0 34 | 
chart_coarse_to_superset = defaultdict(list) 35 | for json_fp in json_fps: 36 | with open(json_fp, 'r') as f: 37 | song_meta = json.loads(f.read()) 38 | 39 | max_time_annotated = -1.0 40 | coarse_to_beats = defaultdict(set) 41 | for chart_meta in song_meta['charts']: 42 | if args.diff and chart_meta['difficulty_coarse'] != args.diff: 43 | continue 44 | 45 | coarse = chart_meta['difficulty_coarse'] 46 | feet = chart_meta['difficulty_fine'] 47 | feet_total += feet 48 | 49 | chart_types[chart_meta['type']] += 1 50 | chart_diff_coarse[coarse] += 1 51 | chart_feet[feet] += 1 52 | chart_freetexts[chart_meta['desc_or_author']] += 1 53 | 54 | num_arrows = 0 55 | for _, beat, time, arrow in chart_meta['notes']: 56 | beat_phase = beat - int(beat) 57 | beat_phase = int(beat_phase * 100.0) / 100.0 58 | beat_phases[beat_phase] += 1 59 | vocab[arrow] += 1 60 | if arrow != '0' * len(arrow): 61 | num_arrows += 1 62 | coarse_to_beats[coarse].add(beat) 63 | arrows_total += num_arrows 64 | 65 | chart_time_annotated = chart_meta['notes'][-1][2] - chart_meta['notes'][0][2] 66 | if chart_time_annotated > max_time_annotated: 67 | max_time_annotated = chart_time_annotated 68 | charts_time_annotated += chart_time_annotated 69 | 70 | stream = num_arrows / chart_time_annotated 71 | stream_total += stream 72 | if feet not in chart_coarse_to_stream: 73 | chart_coarse_to_stream[coarse] = [] 74 | chart_coarse_to_stream[coarse].append(stream) 75 | 76 | songs_time_annotated += max_time_annotated 77 | 78 | coarses = ['Beginner', 'Easy', 'Medium', 'Hard', 'Challenge'] 79 | for i, coarse in enumerate(coarses): 80 | for coarse_next in coarses: 81 | beats = coarse_to_beats[coarse] 82 | beats_next = coarse_to_beats[coarse_next] 83 | chart_coarse_to_superset[(coarse, coarse_next)].append(len(beats & beats_next) / float(len(beats))) 84 | 85 | chart_coarse_to_stream = {k: sum(l) / len(l) for k, l in chart_coarse_to_stream.items()} 86 | chart_coarse_to_superset = {k: (reduce(lambda x, y: x + y, l) / len(l)) 
for k, l in chart_coarse_to_superset.items()} 87 | 88 | nsongs = len(json_fps) 89 | ncharts = sum(chart_feet.values()) 90 | print ','.join(args.dataset_fps) 91 | print 'Num songs: {}'.format(nsongs) 92 | print 'Total music annotated (s): {}'.format(songs_time_annotated) 93 | print 'Avg song length (s): {}'.format(songs_time_annotated / nsongs) 94 | 95 | print 'Num charts: {}'.format(ncharts) 96 | print 'Avg num charts per song: {}'.format(float(ncharts) / nsongs) 97 | print 'Total chart time annotated (s): {}'.format(charts_time_annotated) 98 | print 'Avg chart length (s): {}'.format(charts_time_annotated / ncharts) 99 | print 'Avg chart length (steps): {}'.format(float(arrows_total) / ncharts) 100 | 101 | print 'Chart types: {}'.format(chart_types) 102 | print 'Chart coarse difficulties: {}'.format(chart_diff_coarse) 103 | print 'Chart feet: {}'.format(chart_feet) 104 | print 'Chart coarse avg arrows per second: {}'.format(chart_coarse_to_stream) 105 | print 'Chart coarse avg superset: {}'.format(chart_coarse_to_superset) 106 | print 'Chart freetext fields: {}'.format(chart_freetexts) 107 | print 'Chart vocabulary (size={}): {}'.format(len(vocab), vocab) 108 | print 'Beat phases: {}'.format(beat_phases) 109 | 110 | print 'Avg feet: {}'.format(feet_total / ncharts) 111 | print 'Avg arrows per second: {}'.format(stream_total / ncharts) 112 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,jupyternotebooks 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,jupyternotebooks 3 | 4 | ### JupyterNotebooks ### 5 | # gitignore template for Jupyter Notebooks 6 | # website: http://jupyter.org/ 7 | 8 | .ipynb_checkpoints 9 | */.ipynb_checkpoints/* 10 | 11 | # IPython 12 | profile_default/ 13 | ipython_config.py 14 | 15 | # 
Remove previous ipynb_checkpoints 16 | # git rm -r .ipynb_checkpoints/ 17 | 18 | ### Python ### 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | cover/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | .pybuilder/ 94 | target/ 95 | 96 | # Jupyter Notebook 97 | 98 | # IPython 99 | 100 | # pyenv 101 | # For a library or package, you might want to ignore these files since the code is 102 | # intended to run in multiple environments; otherwise, check them in: 103 | # .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # poetry 113 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 114 | # This is especially recommended for binary packages to ensure reproducibility, and is more 115 | # commonly ignored for libraries. 116 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 117 | #poetry.lock 118 | 119 | # pdm 120 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 121 | #pdm.lock 122 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 123 | # in version control. 124 | # https://pdm.fming.dev/#use-with-ide 125 | .pdm.toml 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | ENV/ 143 | env.bak/ 144 | venv.bak/ 145 | 146 | # Spyder project settings 147 | .spyderproject 148 | .spyproject 149 | 150 | # Rope project settings 151 | .ropeproject 152 | 153 | # mkdocs documentation 154 | /site 155 | 156 | # mypy 157 | .mypy_cache/ 158 | .dmypy.json 159 | dmypy.json 160 | 161 | # Pyre type checker 162 | .pyre/ 163 | 164 | # pytype static type analyzer 165 | .pytype/ 166 | 167 | # Cython debug symbols 168 | cython_debug/ 169 | 170 | # PyCharm 171 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 172 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 173 | # and can be added to the global gitignore or 
merged into this file. For a more nuclear 174 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 175 | #.idea/ 176 | 177 | ### Python Patch ### 178 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 179 | poetry.toml 180 | 181 | # ruff 182 | .ruff_cache/ 183 | 184 | # LSP config files 185 | pyrightconfig.json 186 | 187 | ### VisualStudioCode ### 188 | .vscode/* 189 | !.vscode/settings.json 190 | !.vscode/tasks.json 191 | !.vscode/launch.json 192 | !.vscode/extensions.json 193 | !.vscode/*.code-snippets 194 | 195 | # Local History for Visual Studio Code 196 | .history/ 197 | 198 | # Built Visual Studio Code Extensions 199 | *.vsix 200 | 201 | ### VisualStudioCode Patch ### 202 | # Ignore all local history of files 203 | .history 204 | .ionide 205 | 206 | ### h5py 207 | # h5 files 208 | *.h5 209 | dataset_validity.ipynb 210 | outputs/ 211 | ckpts/ 212 | arxivckpts/ 213 | ckpt*/ 214 | *.pth 215 | wandb/ 216 | .vscode/ 217 | # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,jupyternotebooks 218 | *.zip -------------------------------------------------------------------------------- /1_preindex_similarity_matrix/make_similarity_matrix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import copy 5 | import scipy 6 | # torch dependency is removed 7 | import numpy as np 8 | import soundfile as sf 9 | 10 | def read_ddc_file(filename): 11 | if not filename.endswith(".json"): 12 | raise ValueError("specified file is not ddc file") 13 | with open(filename,'r',errors='replace') as file: 14 | d = json.loads(file.read()) 15 | return d 16 | 17 | def mspb(bpm): 18 | return 60000 / bpm 19 | 20 | def ms_per_48th_beat(bpm): 21 | return 60000 / bpm / 48 22 | 23 | def round_to_nearest_48th(number): 24 | return round(number * 48)/48 25 | 26 | def increment_by_48th(number): 27 | 
def time_to_sample(time_in_ms, sr, time_mode="sample"):
    """Convert *time_in_ms* (milliseconds) according to *time_mode*.

    time_mode="sample": nearest sample index at sample rate *sr*.
    time_mode="time":   pass the millisecond value through unchanged.
    Raises ValueError for any other mode.
    """
    if time_mode == "sample":
        return round(time_in_ms * sr / 1000)
    elif time_mode == "time":
        return time_in_ms
    else:
        raise ValueError("invalid mode")

def get_beat_to_sample_map(bpms=None, sr=44100, make_until=1000, time_mode="sample"):
    """
    bpms is given as a List of two-element lists: [beat, bpm]

    Walks the chart in 1/48th-beat steps, accumulating elapsed time under the
    piecewise-constant BPM schedule, and maps every 48th-beat grid point up to
    beat *make_until* to a sample index (or to milliseconds when
    time_mode="time").

    returns: a list of tuples: (beat(48th), sample num)
    """
    # BUG FIX: previously only `bpms is None` was rejected; an empty list fell
    # through and crashed with IndexError at bpms[0] below instead of raising
    # the intended ValueError.
    if not bpms:
        raise ValueError("param 'bpms' cannot be empty")

    ret = [(0.,0)]
    current_beat = 0
    current_bpm = bpms[0][-1]
    current_time = 0.

    # Append a sentinel "change" just before make_until so the walk below
    # always covers the full requested range; drop changes beyond the chart.
    bpms = copy.deepcopy(bpms[1:]) + [[float(make_until)-1/48,bpms[-1][-1]]]
    bpms = [e for e in bpms if e[0] < make_until] # For cases when chart is shorter than the bpm changes denoted.
    for bpm in bpms:
        delimiter_beat, next_bpm = bpm
        delimiter_beat = round_to_nearest_48th(delimiter_beat)
        while current_beat < delimiter_beat:
            if increment_by_48th(current_beat) < delimiter_beat:
                current_beat = increment_by_48th(current_beat)
                current_time += ms_per_48th_beat(current_bpm)
                ret.append((current_beat, time_to_sample(current_time, sr, time_mode=time_mode)))
            else:
                # BPM change lands inside this 1/48th step: advance time
                # proportionally under the old BPM and then the new one.
                current_percentage = ((delimiter_beat) - (current_beat)) / (1/48)
                next_percentage = 1 - current_percentage
                current_beat = increment_by_48th(current_beat)
                current_time += current_percentage * ms_per_48th_beat(current_bpm) + \
                    next_percentage * ms_per_48th_beat(next_bpm)
                ret.append((current_beat, time_to_sample(current_time, sr,time_mode=time_mode)))

        current_bpm = next_bpm
    # TODO: catch corner cases where two BPM changes are b2b less than 1/48th beat away
    # (we removed them when making dataset, but might be worthwhile to keep them)
    return ret
def make_indexes_from_beat_to_sample_map(beat_to_sample_map, fft_size=512):
    # One FFT window (fft_size consecutive sample indices) per mapped beat.
    return [list(range(e,e+fft_size)) for _,e in beat_to_sample_map]

def make_beatwise_similarity_matrix_and_beat_to_sample_map(json_filepath, fft_size=512, no_matrix=False):
    """Load a ddc-json chart and return (similarity matrix, beat-to-sample map).

    The "similarity matrix" is currently all zeros (see NOTE below); only the
    beat-to-sample map is meaningful. `no_matrix` is accepted but unused --
    presumably kept for call-site compatibility; TODO confirm.
    """
    # NOTE
    # there used to be a code that makes similarity matrix here, but
    # since the LBD does not use this (yet), we removed the code.
    #
    d = read_ddc_file(json_filepath)
    json_path = os.path.dirname(json_filepath)
    # music_fp may be absolute (POSIX or Windows drive) or relative to the json.
    if d["music_fp"].startswith("/") or d["music_fp"].startswith("C:") :
        audio_filepath = d["music_fp"]
    else:
        audio_filepath = os.path.join(json_path, d["music_fp"])
    y, sr = sf.read(audio_filepath)
    # downmix to mono when multichannel
    try:
        y = y.mean(axis=1)
    except np.AxisError as e:
        pass # this is already 1-dim.
    # account for offset
    # left pad aligns beat 0 with sample 0; right pad leaves room for one FFT
    # window plus four beats at the final BPM past the last note.
    left_pad = abs(round(d['offset']*sr/1000))
    right_pad = round(fft_size + sr * 4 * mspb(d['bpms'][-1][-1]) /1000)
    y = np.pad(y, (left_pad,right_pad))

    max_beat = max([a for _,a,_,_ in d['charts'][0]['notes']])
    length_in_beats = int(max_beat) + 1
    beat_to_sample_map = get_beat_to_sample_map(bpms=d['bpms'], sr=sr, make_until=length_in_beats)

    return np.zeros([length_in_beats,length_in_beats]), beat_to_sample_map


############################################## FOR TESTING ##############################################
def test_bpms_to_time(json_filepath, fft_size=512):
    """Manual check: print per-beat millisecond increments for a chart."""
    d = read_ddc_file(json_filepath)
    json_path = os.path.dirname(json_filepath)
    audio_filepath = os.path.join(json_path, d["music_fp"])
    y, sr = sf.read(audio_filepath)
    y = y.mean(axis=1)
    # account for offset
    left_pad = abs(round(d['offset']*sr/1000))
    right_pad = round(fft_size + sr * 4 * mspb(d['bpms'][-1][-1]) /1000)
    y = np.pad(y, (left_pad,right_pad))

    max_beat = max([a for _,a,_,_ in d['charts'][0]['notes']])
    length_in_beats = int(max_beat) + 1
    beat_to_sample_map = 
        get_beat_to_sample_map(bpms=d['bpms'], sr=sr, make_until=length_in_beats, time_mode="time")
    # keep only whole-number beats, then print the ms delta between them
    beat_to_sample_map = [e for e in beat_to_sample_map if e[0] == float(int(e[0]))]
    beat_to_sample_increments = []
    for past,future in zip(beat_to_sample_map[:-1],beat_to_sample_map[1:]):
        beat_to_sample_increments.append((future[0], future[1]-past[1]))
    from pprint import pprint
    pprint(beat_to_sample_increments)
    # examined soflanchan output. no problems anywhere...


if __name__=="__main__":
    # make_beatwise_similarity_matrix(sys.argv[1])
    test_bpms_to_time(sys.argv[1])
    def update(self, ref, pred):
        """Score one two-beat window at onset (note / no-note) granularity.

        *ref* and *pred* are token sequences. Each is replayed onto a 96-slot
        board (one slot per 48th of a beat over two beats); a time token
        selects the current slot, a chart token writes its deciphered 4-lane
        string into that slot, and the bar token ends the sequence. The two
        boards are then compared slot by slot.
        """
        board_ref = [ "0000" for _ in range(96) ]
        board_pred = [ "0000" for _ in range(96)]
        mode = 0
        for board, tokens in [(board_ref, ref),(board_pred, pred)]:
            if len(tokens) == 0:
                continue
            for token in tokens:
                token = int(token)
                assert token != self.voca.beat_token
                if token == self.voca.bar_token:
                    break  # bar token terminates the window
                elif token in self.voca.time_token.keys():
                    mode = int(token)  # time token selects the active slot
                elif token in self.voca.chart_token.values():
                    board[mode] = self.decipher_token(int(token))
        # NOTE(review): `mode` is not reset between the ref and pred passes, so
        # a pred sequence that opens with a chart token inherits ref's last
        # slot -- confirm sequences always begin with a time token.

        for keys_ref, keys_pred in zip(board_ref, board_pred):
            assert len(keys_ref) == len(keys_pred)
            # a slot is "on" iff any of its four lane digits is non-zero
            note_ref, note_pred = int(keys_ref), int(keys_pred)
            if note_ref == 0 and note_pred == 0:
                continue
            elif note_ref == 0 and note_pred != 0:
                self.false_positive += 1
            elif note_ref != 0 and note_pred == 0:
                self.false_negative += 1
            elif note_ref != 0 and note_pred != 0:
                self.true_positive += 1

class Osu4kTwoBeatTimingCounter(OneSeventyEightCounter):
    """Timing-only variant: a slot is 'on' iff any time token selects it;
    chart tokens are ignored entirely."""
    def __init__(self):
        super().__init__()

    def update(self, ref, pred):
        """Score one two-beat window on timing alone (did an event land on
        this 48th-beat slot), ignoring which lanes were pressed."""
        board_ref = [ 0 for _ in range(96) ]
        board_pred = [ 0 for _ in range(96)]
        mode = 0  # never read in this variant; kept from the onset counter
        for board, tokens in [(board_ref, ref),(board_pred, pred)]:
            if len(tokens) == 0:
                continue
            for token in tokens:
                token = int(token)
                assert token != self.voca.beat_token
                if token == self.voca.bar_token:
                    break
                elif token in self.voca.time_token.keys():
                    board[int(token)] = 1

        for note_ref, note_pred in zip(board_ref, board_pred):
            if note_ref == 0 and note_pred == 0:
                continue
            elif note_ref == 0 and note_pred != 0:
                self.false_positive += 1
            elif note_ref != 0 and note_pred == 0:
                self.false_negative += 1
            elif note_ref != 0 and note_pred != 0:
                self.true_positive += 1

class Osu4kTwoBeatNotesCounter(OneSeventyEightCounter):
def __init__(self): 104 | super().__init__() 105 | 106 | def update(self, ref, pred): 107 | board_ref = [ "0000" for _ in range(96) ] 108 | board_pred = [ "0000" for _ in range(96)] 109 | mode = 0 110 | for board, tokens in [(board_ref, ref),(board_pred, pred)]: 111 | if len(tokens) == 0: 112 | continue 113 | for token in tokens: 114 | token = int(token) 115 | assert token != self.voca.beat_token 116 | if token == self.voca.bar_token: 117 | break 118 | elif token in self.voca.time_token.keys(): 119 | mode = int(token) 120 | elif token in self.voca.chart_token.values(): 121 | board[mode] = self.decipher_token(int(token)) 122 | 123 | for keys_ref, keys_pred in zip(board_ref, board_pred): 124 | assert len(keys_ref) == len(keys_pred) 125 | for note_ref, note_pred in zip(keys_ref, keys_pred): 126 | note_ref, note_pred = int(note_ref), int(note_pred) 127 | if note_ref == 0 and note_pred == 0: 128 | continue 129 | elif note_ref == 0 and note_pred != 0: 130 | self.false_positive += 1 131 | elif note_ref != 0 and note_pred == 0: 132 | self.false_negative += 1 133 | elif note_ref != 0 and note_pred != 0: 134 | self.true_positive += 1 135 | 136 | 137 | 138 | 139 | def test(): 140 | counter = Osu4kTwoBeatCounter() 141 | counter.update([0, 97, 16, 98, 48, 123, 72, 167], [0, 97, 16, 98, 48, 124, 72, 173]) 142 | print(counter.precision()) 143 | print(counter.recall()) 144 | print(counter.f1()) 145 | 146 | if __name__=="__main__": 147 | test() 148 | -------------------------------------------------------------------------------- /metrics_cond_centered_AR.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file partly used to make Table 3. For the numbers, generated/millin_and_anmillin.ipynb was used to make them. 
"""
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import time
import hydra
import numpy as np
from omegaconf import OmegaConf
import os
import logging
from pprint import pprint
from code.counter.osu4kcounter import Osu4kTwoBeatOnsetCounter, Osu4kTwoBeatTimingCounter

logger = logging.getLogger(__file__)

# Globals set by run(): how many leading pad tokens to truncate from the
# decoder context window, and the corpus-wide metric accumulator.
trunc_how_many = None
global_counter = Osu4kTwoBeatOnsetCounter()

def generate_2bar_greedy(model, data, tokens, fine_difficulty, max_gen_length=200):
    """
    ! invariants
    model, data, tokens are already in cuda device
    data: (1, 80, 96T)
    tokens: (1, length_in_beats//2 - 1, 100). includes charts from beat 2. beats [2-4, 4-6, 6-8, ....]
    fine_diff: (1,1)
    data has dim of (80, L)
    """
    global trunc_how_many
    global global_counter

    local_counter = Osu4kTwoBeatOnsetCounter()

    # Decoder context: pads on both sides of a double bar marker.
    # 177 is presumably the PAD token and 96 the bar token of the 178-token
    # vocabulary -- TODO confirm against the vocabulary module.
    decoder_macro_input = [177 for _ in range(max_gen_length//2-1)] + [96,96] + [177 for _ in range(max_gen_length//2-1)]
    decoder_macro_input = decoder_macro_input[trunc_how_many:]
    decoder_macro_input = torch.Tensor(decoder_macro_input).to(torch.long).to('cuda:0').unsqueeze(0)
    generated = []
    how_many_twobeats = tokens.shape[1]

    for idx_to_generate_rn in range(0,how_many_twobeats):
        # position right after the bar marker, where generation resumes
        retrieve_this = (max_gen_length // 2) - trunc_how_many
        where_is_96 = int(retrieve_this)
        # two-beat audio window: 96 mel frames per beat-pair step
        audio_this_cycle = data[: , :, 96*(idx_to_generate_rn):96*(idx_to_generate_rn+2)]
        tokens_this_cycle = decoder_macro_input[:,:]
        if audio_this_cycle.shape[2] == 0: break
        # Greedy token-by-token decoding: feed the model, take argmax at the
        # current slot, append, repeat until PAD (177) is produced.
        while retrieve_this < 199 - trunc_how_many: # nothing's gonna reach 100 anyways, so cutting it short should be fine...right?
            model_output = model(audio_this_cycle, tokens_this_cycle, fine_difficulty)
            created_tokens = torch.argmax(model_output, 2).to(torch.long)
            if int(created_tokens[0,retrieve_this]) == 177:
                break # input_this_cycle is the answer for this cycle.
            else:
                tokens_this_cycle[0,retrieve_this+1] = created_tokens[0,retrieve_this]
                retrieve_this += 1
        # NOTE(review): if the while loop body never runs (loop condition false
        # on entry), `created_tokens` is unbound here -- confirm max_gen_length
        # and trunc_how_many always permit at least one iteration.
        # split tokens by 96
        created_tokens = created_tokens[:, where_is_96:retrieve_this]
        local_counter.update(tokens[0,idx_to_generate_rn,:],created_tokens[0,:])
        global_counter.update(tokens[0,idx_to_generate_rn,:],created_tokens[0,:])
        generated.append(created_tokens)
        logger.info(f"{idx_to_generate_rn}\t\t{created_tokens}")
        created_tokens_length = created_tokens.shape[1]

        # Rebuild the context for the next window: left-pad, bar, the tokens
        # just generated, bar, then right pads up to max_gen_length.
        decoder_macro_input = torch.cat((
            torch.Tensor([[177 for _ in range(max_gen_length//2-1-created_tokens_length)]]).to(torch.long).to('cuda'),
            torch.Tensor([[96]]).to(torch.long).to('cuda'),
            created_tokens,
            torch.Tensor([[96]]).to(torch.long).to('cuda'),
            torch.Tensor([[177 for _ in range(max_gen_length//2-1)]]).to(torch.long).to('cuda')
        ),axis=1)
        how_many_pad_on_right = max_gen_length - decoder_macro_input.shape[1]
        decoder_macro_input = F.pad(decoder_macro_input, (0, how_many_pad_on_right), 'constant', 177)
        decoder_macro_input = decoder_macro_input[:,trunc_how_many:]

    logger.info(f"precision {local_counter.precision()} \trecall {local_counter.recall()} \tf1 {local_counter.f1()}")
    assert len(generated) == how_many_twobeats or len(generated) == how_many_twobeats-1
    return generated



def run(args):
    """Evaluate a checkpoint over the test split with greedy AR generation."""
    from code import DatasetSelector, ModelSelector
    global trunc_how_many
    global global_counter
    ts_loader = DataLoader(
        DatasetSelector(args.ts_dataset)(),
        batch_size=1,
        shuffle=False
    )
    model = ModelSelector(args.model)().to('cuda:0')
    # --- continuation of run() (def is in the preceding chunk) ---
    model.load_state_dict(torch.load(args.ckpt_path))

    model.eval()

    # Drop 92 of the leading pad tokens from the decoder context window
    # (see generate_2bar_greedy); value chosen for this evaluation setup.
    trunc_how_many = 92

    ts_pbar = tqdm(ts_loader)

    # `abridge` cuts the run to a few songs; absent config key means False.
    abridge = False
    try:
        if args.abridge is True:
            abridge = True
    except Exception: #omegaconf.errors.ConfigAttributeError:
        abridge = False

    for n, (data, tokens, fine_difficulty,songname) in enumerate(ts_pbar):
        data = data[:].to('cuda:0') #(1, 80, 96T)
        tokens= tokens.to(torch.long).to('cuda:0') # (1, tokenlength)
        fine_difficulty = fine_difficulty.to('cuda:0').to(torch.float).unsqueeze(1) # (1, 1)

        logger.info(songname)
        generated = generate_2bar_greedy(model, data, tokens, fine_difficulty, 200)
        if abridge and n > 2:
            break

    print("====subtotl====")
    logger.info(f"precision {global_counter.precision()} \trecall {global_counter.recall()} \tf1 {global_counter.f1()}")


def _main(args):
    # Resolve __file__ to an absolute path since hydra changes the CWD.
    global __file__
    __file__ = hydra.utils.to_absolute_path(__file__)
    logger.info("For logs, checkpoints and samples check %s", os.getcwd())
    logger.info(args)
    run(args)

print(__file__)
this_script_dir = os.path.dirname(__file__)

@hydra.main(version_base=None, config_path=os.path.join(this_script_dir,'conf'))
def main(args):
    # Top-level boundary: log the traceback, then hard-exit so hydra's
    # multirun does not keep going after a failure.
    try:
        _main(args)
    except Exception:
        logger.exception("some error happened")
        os._exit(1)

if __name__=="__main__":
    main()

input: one-bar-long mel-spectrogram, shaped as follows -> (time 6*48, freq 80)
output(target): tokenized one-bar-long chart
loss: plain ol' cross-entropy
"""
import torch
import torch.nn as nn



class GorstFineDiffInDecoderInput(nn.Module):
    """Transformer encoder-decoder chart generator.

    The scalar fine-difficulty condition is lifted by a small MLP and
    concatenated onto every decoder input embedding (hence "condition in
    input"). NOTE(review): devices are hard-coded to 'cuda:0' below, so this
    module cannot run on CPU as written.
    """
    def __init__(self,
            encoder_hidden_size=512,
            decoder_hidden_size=512,
            encoder_layers=6,
            decoder_layers=6,
            decoder_max_length=60,
            condition_dim=64,
            vocab_size=178,
            norm_first=False,
            bidirectional_past_context=False,
            dropout=0.1,
            initialize_method="uniform", initialize_std=0.1):
        super().__init__()
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        assert self.encoder_hidden_size == self.decoder_hidden_size
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.decoder_max_length = decoder_max_length
        self.vocab_size=vocab_size
        self.condition_dim = condition_dim
        self.bidirectional_past_context = bidirectional_past_context
        self.dropout = dropout

        # small MLP lifting the scalar condition to condition_dim
        self.decoder_conditions = nn.Sequential(
            nn.Linear(1,condition_dim*2),
            nn.ReLU(),
            nn.Linear(condition_dim*2,condition_dim),
            nn.ReLU()
        )

        # 1x1 conv = per-frame linear projection of the 80 mel bins
        self.encoder_embedding = nn.Conv1d(80, self.encoder_hidden_size, 1)
        # leave room for the condition vector concatenated in forward()
        self.decoder_embedding = nn.Embedding(self.vocab_size, self.decoder_hidden_size - self.condition_dim)

        # learned positional embeddings; encoder side is fixed at 4*48 frames
        self.encoder_positional_embedding = nn.Parameter(data=torch.ones([1, 4*48, self.encoder_hidden_size], dtype=torch.float32))
        self.decoder_positional_embedding_generator = nn.Embedding(self.decoder_max_length, self.decoder_hidden_size)
        self.decoder_positional_embedding_seeds = torch.arange(self.decoder_max_length,device='cuda:0')

        self.final = nn.Conv1d(self.decoder_hidden_size, self.vocab_size, 1)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.encoder_hidden_size, nhead=8, batch_first=True, norm_first=norm_first, dropout=dropout),
            num_layers=self.encoder_layers
        )
        # kept as a ModuleList (not nn.TransformerDecoder) so forward() can
        # iterate the layers manually
        self.decoder = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model=self.decoder_hidden_size, nhead=8, batch_first=True,norm_first=norm_first, dropout=dropout)
            for _ in range(self.decoder_layers)
        ])

        self.initialize(initialize_method, initialize_std)

    def initialize(self, init_method, std):
        """Re-initialize the embedding tables per *init_method*.

        NOTE(review): an unrecognized init_method silently keeps the default
        initialization -- confirm this is intended rather than an error.
        """
        if init_method == "uniform":
            initrange = std
            self.decoder_embedding.weight.data.uniform_(-initrange,initrange)
            self.decoder_positional_embedding_generator.weight.data.uniform_(-initrange,initrange)
            torch.nn.init.uniform_(self.encoder_positional_embedding, a=-initrange, b=initrange)
        elif init_method == "normal":
            self.decoder_embedding.weight.data.normal_(std=std)
            self.decoder_positional_embedding_generator.weight.data.normal_(std=std)
            torch.nn.init.normal_(self.encoder_positional_embedding, mean=0.0,std=std)
        elif init_method == "trunc_normal":
            torch.nn.init.trunc_normal_(self.decoder_embedding.weight, std=std,a=-2*std,b=2*std)
            torch.nn.init.trunc_normal_(self.decoder_positional_embedding_generator.weight, std=std,a=-2*std,b=2*std)
            torch.nn.init.trunc_normal_(self.encoder_positional_embedding, mean=0.0,std=std,a=-2*std,b=2*std)

    def generate_square_subsequent_mask(self, sz):
        """Build the decoder self-attention mask (0 = attend, -inf = blocked)."""
        if not self.bidirectional_past_context:
            # standard causal mask: position i may attend to positions j <= i
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            return mask
        else:
            mask = (torch.triu(torch.ones(sz,sz))==1).transpose(0,1)
            # all but the last 100 positions attend bidirectionally among
            # themselves -- presumably 100 is the generation window length of
            # the AR evaluation scripts; TODO confirm
            mask[:-100,:-100] = True
            #
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            return mask

    def forward(self,spec,tgt,cond):
        """
        params
        spec: (B, 80, 4*48)
        tgt: (B, T)
        cond: (B, 1)
        """
        # project mel frames to hidden size and add positions: (B, 4*48, H)
        encoded = self.encoder_embedding(spec).permute((0,2,1))
        encoded += self.encoder_positional_embedding

        # broadcast the condition vector across all target positions
        cond = self.decoder_conditions(cond) # (B, self.condition_dim)
        cond = cond.unsqueeze(1)
        cond = cond.repeat(1,tgt.shape[1],1)

        tgt_in = torch.cat((self.decoder_embedding(tgt),cond),dim=2)
        target_length = tgt.shape[1]
        tgt_in += self.decoder_positional_embedding_generator(self.decoder_positional_embedding_seeds[:target_length]).unsqueeze(0)
        contexts = self.encoder(encoded)
        decoder_output = tgt_in
        square_mask = self.generate_square_subsequent_mask(target_length).to('cuda:0')
        for n, decoder_layer in enumerate(self.decoder):
            decoder_output = decoder_layer(tgt=decoder_output, memory=contexts, tgt_mask=square_mask)

        # (B, T, H) -> (B, vocab, T) via 1x1 conv -> (B, T, vocab) logits
        logits = self.final(decoder_output.permute(0,2,1)).permute((0,2,1))

        return logits

if __name__=="__main__":
    # quick shape-only smoke test (requires CUDA)
    net = GorstFineDiffInDecoderInput(bidirectional_past_context=True).to('cuda')
    input = torch.ones([32,80,4*48]).to('cuda')
    target = torch.ones([32,24],dtype=torch.long).to('cuda')
    condition = torch.ones([32,1]).to('cuda')
    print(net(input,target,condition).shape)
"""
GORST: "generation of rhythm-game charts via song-to-token mapping".

Transformer encoder-decoder (timing variant).

input: one-bar-long mel-spectrogram, shaped as follows -> (time 4*48, freq 80)
output(target): tokenized one-bar-long chart
loss: plain ol' cross-entropy
"""
import torch
import torch.nn as nn


class GorstFineDiffInDecoderInputTiming(nn.Module):
    """Autoregressive chart-token decoder conditioned on a one-bar mel-spectrogram.

    The scalar difficulty condition is embedded and concatenated onto every
    decoder input token ("condition in input").

    Fixes vs. the original:
      * all hard-coded ``'cuda:0'`` references removed -- the positional index
        tensor is now a (non-persistent) buffer and the attention mask is
        built on ``tgt``'s device, so the module also runs on CPU or on a GPU
        other than 0;
      * the module docstring said the input is ``6*48`` frames, but the
        learned encoder positional embedding is sized for ``4*48`` frames,
        which is what ``forward`` actually requires.
    """

    def __init__(self,
                 encoder_hidden_size=512,
                 decoder_hidden_size=512,
                 encoder_layers=6,
                 decoder_layers=6,
                 decoder_max_length=60,
                 condition_dim=64,
                 vocab_size=178,
                 norm_first=False,
                 bidirectional_past_context=False,
                 dropout=0.1,
                 initialize_method="uniform", initialize_std=0.1):
        super().__init__()
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        # encoder memory feeds the decoder cross-attention directly,
        # so the two widths must match
        assert self.encoder_hidden_size == self.decoder_hidden_size
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.decoder_max_length = decoder_max_length
        self.vocab_size = vocab_size
        self.condition_dim = condition_dim
        self.bidirectional_past_context = bidirectional_past_context
        self.dropout = dropout

        # scalar difficulty -> condition embedding, concatenated to every token
        self.decoder_conditions = nn.Sequential(
            nn.Linear(1, condition_dim * 2),
            nn.ReLU(),
            nn.Linear(condition_dim * 2, condition_dim),
            nn.ReLU()
        )

        self.encoder_embedding = nn.Conv1d(80, self.encoder_hidden_size, 1)
        # token embedding leaves room for the concatenated condition vector
        self.decoder_embedding = nn.Embedding(self.vocab_size, self.decoder_hidden_size - self.condition_dim)

        self.encoder_positional_embedding = nn.Parameter(data=torch.ones([1, 4 * 48, self.encoder_hidden_size], dtype=torch.float32))
        self.decoder_positional_embedding_generator = nn.Embedding(self.decoder_max_length, self.decoder_hidden_size)
        # index seeds as a non-persistent buffer so .to()/.cuda() moves them
        # with the module; non-persistent keeps old checkpoints loadable
        # (originally a plain tensor hard-pinned to cuda:0)
        self.register_buffer("decoder_positional_embedding_seeds",
                             torch.arange(self.decoder_max_length),
                             persistent=False)

        self.final = nn.Conv1d(self.decoder_hidden_size, self.vocab_size, 1)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.encoder_hidden_size, nhead=8, batch_first=True, norm_first=norm_first, dropout=dropout),
            num_layers=self.encoder_layers
        )
        self.decoder = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model=self.decoder_hidden_size, nhead=8, batch_first=True, norm_first=norm_first, dropout=dropout)
            for _ in range(self.decoder_layers)
        ])

        self.initialize(initialize_method, initialize_std)

    def initialize(self, init_method, std):
        """Initialize the embedding tables with the configured scheme
        ("uniform", "normal" or "trunc_normal"); any other value leaves
        torch's defaults untouched."""
        if init_method == "uniform":
            initrange = std
            self.decoder_embedding.weight.data.uniform_(-initrange, initrange)
            self.decoder_positional_embedding_generator.weight.data.uniform_(-initrange, initrange)
            torch.nn.init.uniform_(self.encoder_positional_embedding, a=-initrange, b=initrange)
        elif init_method == "normal":
            self.decoder_embedding.weight.data.normal_(std=std)
            self.decoder_positional_embedding_generator.weight.data.normal_(std=std)
            torch.nn.init.normal_(self.encoder_positional_embedding, mean=0.0, std=std)
        elif init_method == "trunc_normal":
            torch.nn.init.trunc_normal_(self.decoder_embedding.weight, std=std, a=-2 * std, b=2 * std)
            torch.nn.init.trunc_normal_(self.decoder_positional_embedding_generator.weight, std=std, a=-2 * std, b=2 * std)
            torch.nn.init.trunc_normal_(self.encoder_positional_embedding, mean=0.0, std=std, a=-2 * std, b=2 * std)

    def generate_square_subsequent_mask(self, sz):
        """Float attention mask: 0.0 where attending is allowed, -inf elsewhere.

        In the bidirectional-past-context variant, all positions except the
        final 50 may attend to each other freely; only the last 50 positions
        keep the causal constraint.
        """
        if not self.bidirectional_past_context:
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            return mask
        else:
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            # open up full attention within the "past" block
            mask[:-50, :-50] = True
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            return mask

    def forward(self, spec, tgt, cond):
        """
        params
            spec: (B, 80, 4*48) mel-spectrogram for one bar
            tgt:  (B, T) token ids (teacher-forcing decoder input)
            cond: (B, 1) scalar difficulty condition
        returns
            logits over the vocabulary, shape (B, T, vocab_size)
        """
        encoded = self.encoder_embedding(spec).permute((0, 2, 1))
        encoded = encoded + self.encoder_positional_embedding

        cond = self.decoder_conditions(cond)  # (B, self.condition_dim)
        cond = cond.unsqueeze(1)
        cond = cond.repeat(1, tgt.shape[1], 1)

        tgt_in = torch.cat((self.decoder_embedding(tgt), cond), dim=2)
        target_length = tgt.shape[1]
        tgt_in = tgt_in + self.decoder_positional_embedding_generator(
            self.decoder_positional_embedding_seeds[:target_length]).unsqueeze(0)
        contexts = self.encoder(encoded)
        decoder_output = tgt_in
        # build the mask on the inputs' device (was hard-coded 'cuda:0')
        square_mask = self.generate_square_subsequent_mask(target_length).to(tgt.device)
        for decoder_layer in self.decoder:
            decoder_output = decoder_layer(tgt=decoder_output, memory=contexts, tgt_mask=square_mask)

        logits = self.final(decoder_output.permute(0, 2, 1)).permute((0, 2, 1))

        return logits


if __name__ == "__main__":
    # smoke test; runs on CPU when no GPU is present
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = GorstFineDiffInDecoderInputTiming(bidirectional_past_context=True).to(device)
    input = torch.ones([32, 80, 4 * 48]).to(device)
    target = torch.ones([32, 24], dtype=torch.long).to(device)
    condition = torch.ones([32, 1]).to(device)
    print(net(input, target, condition).shape)
def convert_outputs(outputs):
    """Turn per-frame onset scores into candidate peaks.

    params
        outputs: (B, 112) array-like of scores (anything supporting len()
                 and [i][j] indexing, e.g. a numpy array or list of lists).
    returns
        per batch item, a list of (frame_index, score) strict local maxima.

    FIX: the interior-frame branch was mangled in the original source
    ("elif 0 outputs[i][j+1] and ..." -- the comparison operators were
    stripped, presumably by an HTML export); reconstructed here as the
    standard strict local-maximum test for 0 < j < 111, matching the
    j == 0 and j == 111 edge branches.
    """
    B = len(outputs)  # generalized from outputs.shape[0]; works for arrays too
    ret = [[] for _ in range(B)]
    for i in range(B):
        for j in range(112):
            if j == 0 and outputs[i][j] > outputs[i][j + 1]:
                ret[i].append((j, outputs[i][j]))
            elif j == 111 and outputs[i][j] > outputs[i][j - 1]:
                ret[i].append((j, outputs[i][j]))
            elif 0 < j < 111 and outputs[i][j] > outputs[i][j + 1] and outputs[i][j] > outputs[i][j - 1]:
                ret[i].append((j, outputs[i][j]))
    return ret

def convert_targets(targets):
    """Materialize the (B, 112) target array as a list of per-sample rows."""
    return list(targets)

def get_fn_fp_tp(one_output, one_target):
    """Count false negatives / false positives / true positives with a
    +/- 2 frame tolerance window.

    params
        one_output: list of predicted peak frames, ints from [0, 112)
        one_target: length-112 truthy/falsy array of ground-truth onsets
    returns
        (fn, fp, tp)
    """
    fn, fp, tp = 0, 0, 0
    predicted = set(one_output)  # O(1) membership instead of list scans

    # a ground-truth onset with no prediction within 2 frames is a miss
    for i in range(112):
        if not one_target[i]:
            continue
        for j in [i - 2, i - 1, i, i + 1, i + 2]:
            if j in predicted:
                break
            if j == i + 2:
                fn += 1

    # a prediction with no ground-truth onset within 2 frames is a false alarm
    for peak in one_output:
        matched = False
        for j in [peak - 2, peak - 1, peak, peak + 1, peak + 2]:
            if j < 0:
                continue
            elif j >= 112:
                break
            if one_target[j]:
                tp += 1
                matched = True
                break
        if not matched:
            fp += 1
    return fn, fp, tp


def tune_thres(outputs, targets):
    """Grid-search a peak-score threshold maximizing tolerance-windowed F1.

    params
        targets: (B*1000, 112)
        outputs: list, len(outputs) = B*1000; each element is a
                 variable-length list of (frame, score) peaks.
    returns
        (best_threshold, best_f1) over thresholds 0.00 .. 0.29

    NOTE: the tqdm progress-bar wrapper was dropped so this pure routine no
    longer depends on a UI library; the computation is unchanged.
    """
    B = len(outputs)
    print(f"{B} samples to tune with")
    max_f1 = 0
    max_thres = 0
    for thres in [0.01 * i for i in range(30)]:
        fn, fp, tp = 0, 0, 0
        # keep only peaks whose score clears the threshold
        oo = [[e[0] for e in ll if e[1] > thres] for ll in outputs]
        for idx in range(B):
            fnt, fpt, tpt = get_fn_fp_tp(oo[idx], targets[idx])
            fn += fnt
            fp += fpt
            tp += tpt
        precision = tp / (tp + fp + 1e-9)
        recall = tp / (tp + fn + 1e-9)
        f1 = 2 * precision * recall / (precision + recall + 1e-9)
        if f1 > max_f1:
            max_f1 = f1
            max_thres = thres
    return max_thres, max_f1

# Tuned results recorded in the original source:
# cnn, .../ckpts/ddc_cnn/2e4_and_64_try2/ckpt_epoch_1.pth
#   easy 0.19, 0.745 | normal 0.17, 0.822 | hard 0.17, 0.837
#   insane 0.17, 0.840 | expert 0.17, 0.812 | expert+ 0.13, 0.813
# clstm, .../ckpts/ddc_clstm/0.0002_and_64/ckpt_epoch_0.pth
#   easy 0.17, 0.766 | normal 0.16, 0.831 | hard 0.18, 0.839
#   insane 0.19, 0.844 | expert 0.17, 0.822 | expert+ 0.13, 0.814
fine_difficulty in valid_pbar:
            valid_step += 1
            # NOTE(review): device is hard-coded to cuda:0 throughout this
            # evaluation script -- presumably a single-GPU setup; confirm.
            data = data.to('cuda:0') # (B, 112, 80, 3)
            target = target.to(torch.bool) # (B, 112)
            fine_difficulty = fine_difficulty.to('cuda:0').to(torch.float) # (B, 1)

            # print(data.shape, target.shape, fine_difficulty.shape)

            data_norm = normalizer(data)
            #print(data_norm.shape)

            model_output = model(data_norm, fine_difficulty)
            model_output = model_output.cpu().detach().numpy() # (B, 112)
            target = target.cpu().detach().numpy()

            # accumulate ground truth and predicted peaks for threshold tuning
            targets.extend(convert_targets(target))
            preds.extend(convert_outputs(model_output))

            if valid_step >= 2000: # B * 1000 should be enough
                break

    print(len(targets))
    print(len(preds))
    tuned_thres,f1 = tune_thres(preds, targets)

    print(tuned_thres,f1)


def _main(args):
    # Resolve __file__ to an absolute path because hydra changes the CWD.
    global __file__
    __file__ = hydra.utils.to_absolute_path(__file__)
    logger.info("For logs, checkpoints and samples check %s", os.getcwd())
    logger.info(args)
    run(args)

print(__file__)
this_script_dir = os.path.dirname(__file__)

@hydra.main(version_base=None, config_path=os.path.join(this_script_dir,'conf'))
def main(args):
    # Top-level boundary: log any failure and hard-exit so hydra does not
    # swallow the error status.
    try:
        _main(args)
    except Exception:
        logger.exception("some error happened")
        os._exit(1)

if __name__=="__main__":
    main()
-------------------------------------------------------------------------------- /osu-to-ddc/osutoddc/converter/osuparser.py: --------------------------------------------------------------------------------
import os
import sys
import shutil
from parser_exceptions import NotManiaFileException, NotV14Exception, MissingAttributeException, Not4KException

# https://osu.ppy.sh/wiki/en/Client/File_formats/Osu_%28file_format%29

def parserFactory(action):
    def parser(string):
        string = string.strip()
        if
# https://osu.ppy.sh/wiki/en/Client/File_formats/Osu_%28file_format%29
#
# Parsers for .osu (v14) beatmap files. Each file section gets a section
# parser; scalar field parsers are wrapped so blank values become None.

def parserFactory(action):
    """Wrap *action* so blank/whitespace-only strings parse to None."""
    def parser(string):
        string = string.strip()
        if len(string) == 0:  # already stripped; the original re-stripped here
            return None
        else:
            return action(string)
    return parser

int_parser = parserFactory(int)
str_parser = parserFactory(str)
float_parser = parserFactory(float)
bool_parser = parserFactory(lambda x: x.strip() == "1")

def dictParserFactory(fields, parsers):
    """Build a parser for one comma-separated record.

    Maps each field name to the value produced by its positional parser.
    Raises ValueError when the field count does not match the record.
    (A stray debug ``print(string)`` before the raise was removed; the
    exception message already carries the offending input.)
    """
    if len(fields) != len(parsers):
        raise ValueError("fields and parsers have unequal lengths")
    def parser(string):
        string = string.strip()
        if len(string) == 0:
            return None
        strings = string.split(',')
        if len(strings) != len(fields):
            raise ValueError(f"invalid string {string}")
        ret = {}
        for s, f, p in zip(strings, fields, parsers):
            ret[f] = p(s)
        return ret
    return parser


###############
############### General, Editor, Metadata, Difficulty.
###############

fields_of_note = {
    "AudioFilename": str_parser,
    "AudioLeadIn": int_parser,
    "Title": str_parser,
    "Artist": str_parser,
    "Mode": int_parser,
    "Version": str_parser,
    "BeatmapID": int_parser,
    "BeatmapSetID": int_parser,
    "CircleSize": float_parser  # if this is not 4, raise Not4KException
}

def linebyline_section_parser(lines):
    """Parse "Attribute:Value" lines into a dict.

    Values are re-joined on ':' so they may themselves contain colons;
    attributes not listed in ``fields_of_note`` fall back to ``str_parser``.
    """
    ret = {}
    for line in lines:
        line = line.strip()
        if len(line) == 0:
            continue

        entries = line.split(':')
        attribute = entries[0]
        value = ":".join(entries[1:])

        if attribute in fields_of_note:
            ret[attribute] = fields_of_note[attribute](value)
        else:
            ret[attribute] = str_parser(value)
    return ret

###############
############### Events, Colours
###############

def no_parser(_section_lines):  # param renamed from `str`: don't shadow the builtin
    """Ignore a section entirely."""
    return None

###############
############### TimingPoints
###############

timing_points_parser = dictParserFactory(
    ["time", "beatLength", "meter", "sampleSet", "sampleIndex", "volume", "uninherited", "effects"],
    [float_parser, float_parser, int_parser, int_parser, int_parser, int_parser, bool_parser, int_parser]
)

def timing_points_section_parser(lines):
    lines = [e for e in lines if len(e.strip()) > 0]
    return [timing_points_parser(e) for e in lines]

###############
############### HitObjects
###############

def take_first_entry_from_semicolon_separated_list(sslist):
    return int(sslist.split(':')[0])

def note_type_parser(number):
    """Decode the hit-object type bitfield (bits 0/1/3/7)."""
    number = int(number)
    return {
        'hit': bool(number & 1),
        'slider': bool(number & 2),
        'spinner': bool(number & 8),
        'hold': bool(number & 128)
    }


# osu! x coordinate runs 0..512; quantize into the 4 mania columns
x_parser = parserFactory(lambda x: min(3, max(0, int(x) * 4 // 512)))
hold_endpoint_parser = parserFactory(take_first_entry_from_semicolon_separated_list)

hit_objects_parser = dictParserFactory(
    ["x", "y", "time", "type", "hitSound", "holdEnd"],
    [x_parser, int_parser, int_parser, note_type_parser, int_parser, hold_endpoint_parser]
)

def hit_objects_section_parser(lines):
    lines = [e for e in lines if len(e.strip()) > 0]
    return [hit_objects_parser(e) for e in lines]

###############
############### FIELDS
###############

fields = {"[General]": linebyline_section_parser,
          "[Editor]": no_parser,
          "[Metadata]": linebyline_section_parser,
          "[Difficulty]": linebyline_section_parser,
          "[Events]": no_parser,
          "[TimingPoints]": timing_points_section_parser,
          "[Colours]": no_parser,
          "[HitObjects]": hit_objects_section_parser}  # These Sections appear in order.

mandatory_fields = ["[General]",
                    # "[Editor]",
                    "[Metadata]",
                    "[Difficulty]",
                    # "[Events]",
                    "[TimingPoints]",
                    # "[Colours]",
                    "[HitObjects]"]


###############
############### all together...
###############

def parse_osu_file(filename):
    """Parse an .osu v14 beatmap file into {section_name: parsed_content}.

    raises
        NotV14Exception: first line is not the v14 header
        MissingAttributeException: a mandatory section is absent
        Not4KException: CircleSize != 4 (not a 4-key chart)
        NotManiaFileException: Mode != 3 (not an osu!mania map)
    """
    ret = {}

    with open(filename, 'r', errors='replace') as file:
        lines = [line.strip() for line in file.readlines()]

    # correctness check
    # NOTE(review): uses `in` rather than equality -- presumably to tolerate
    # a BOM or leading junk on the first line; confirm.
    if "osu file format v14" not in lines[0]:
        raise NotV14Exception("This file is not in osu file format v14.")

    extant_fields = []
    fields_index = []

    for n, line in enumerate(lines):
        for field in fields.keys():
            if line.startswith(field):
                extant_fields.append(field)
                fields_index.append(n)

    for field in mandatory_fields:
        if field not in extant_fields:
            raise MissingAttributeException(f"(unknown) missing mandatory field {field}.")

    # slice each section: from just after its header to the next header
    field_contents = [lines[a + 1:b] for a, b in zip(fields_index[:-1], fields_index[1:])] + [lines[fields_index[-1] + 1:]]

    for field, field_content in zip(extant_fields, field_contents):
        ret[field] = fields[field](field_content)

    # Correctness Check
    if ret["[Difficulty]"]["CircleSize"] != 4:
        raise Not4KException("This beatmap is not a 4K")
    if ret["[General]"]["Mode"] != 3:
        raise NotManiaFileException("This beatmap is not an Osu! Mania Beatmap")

    return ret

if __name__ == "__main__":
    import sys
    from pprint import pprint
    pprint(parse_osu_file(sys.argv[1]))
        # NOTE(review): this is a Python 2 script (see the `print` statement
        # at the bottom); run with a py2 interpreter.
        pack_eznames.add(pack_ezname)

        if len(pack_sm_fps) > 0:
            pack_outdir = os.path.join(args.json_dir, pack_ezname)
            if not os.path.isdir(pack_outdir):
                os.mkdir(pack_outdir)

        sm_eznames = set()
        for sm_fp in pack_sm_fps:
            # sanitized per-song name; conflicts after sanitization are fatal
            sm_name = os.path.split(os.path.split(sm_fp)[0])[1]
            sm_ezname = ez_name(sm_name)
            if sm_ezname in sm_eznames:
                raise ValueError('Song name conflict: {}'.format(sm_ezname))
            sm_eznames.add(sm_ezname)

            with open(sm_fp, 'r') as sm_f:
                sm_txt = sm_f.read()

            # parse file
            try:
                sm_attrs = parse_sm_txt(sm_txt)
            except ValueError as e:
                # malformed chart: log and skip this song
                smlog.error('{} in\n{}'.format(e, sm_fp))
                continue
            except Exception as e:
                smlog.critical('Unhandled parse exception {}'.format(traceback.format_exc()))
                raise e

            # check required attrs
            try:
                for attr_name in _ATTR_REQUIRED:
                    if attr_name not in sm_attrs:
                        raise ValueError('Missing required attribute {}'.format(attr_name))
            except ValueError as e:
                smlog.error('{}'.format(e))
                continue

            # handle missing music
            root = os.path.abspath(os.path.join(sm_fp, '..'))
            music_fp = os.path.join(root, sm_attrs.get('music', ''))
            if 'music' not in sm_attrs or not os.path.exists(music_fp):
                music_names = []
                sm_prefix = os.path.splitext(sm_name)[0]

                # check directory files for reasonable substitutes
                for filename in os.listdir(root):
                    prefix, ext = os.path.splitext(filename)
                    if ext.lower()[1:] in ['mp3', 'ogg']:
                        music_names.append(filename)

                try:
                    # handle errors: only an unambiguous single candidate is accepted
                    if len(music_names) == 0:
                        raise ValueError('No music files found')
                    elif len(music_names) == 1:
                        sm_attrs['music'] = music_names[0]
                    else:
                        raise ValueError('Multiple music files {} found'.format(music_names))
                except ValueError as e:
                    smlog.error('{}'.format(e))
                    continue

                music_fp = os.path.join(root, sm_attrs['music'])

            bpms = sm_attrs['bpms']
            offset = sm_attrs['offset']
            if args.itg:
                # Many charters add 9ms of delay to their stepfiles to account for ITG r21/r23 global delay
                # see http://r21freak.com/phpbb3/viewtopic.php?f=38&t=12750
                offset -= 0.009
            stops = sm_attrs.get('stops', [])

            out_json_fp = os.path.join(pack_outdir, '{}_{}.json'.format(pack_ezname, sm_ezname))
            out_json = OrderedDict([
                ('sm_fp', os.path.abspath(sm_fp)),
                ('music_fp', os.path.abspath(music_fp)),
                ('pack', pack_name),
                ('title', sm_attrs.get('title')),
                ('artist', sm_attrs.get('artist')),
                ('offset', offset),
                ('bpms', bpms),
                ('stops', stops),
                ('charts', [])
            ])

            # one entry per chart; sm_notes is the 6-tuple produced by notes_parser
            for idx, sm_notes in enumerate(sm_attrs['notes']):
                note_beats_and_abs_times = calc_note_beats_and_abs_times(offset, bpms, stops, sm_notes[5])
                notes = {
                    'type': sm_notes[0],
                    'desc_or_author': sm_notes[1],
                    'difficulty_coarse': sm_notes[2],
                    'difficulty_fine': sm_notes[3],
                    'notes': note_beats_and_abs_times,
                }
                out_json['charts'].append(notes)

            with open(out_json_fp, 'w') as out_f:
                try:
                    out_f.write(json.dumps(out_json))
                except UnicodeDecodeError:
                    smlog.error('Unicode error in {}'.format(sm_fp))
                    continue

            print 'Parsed {} - {}: {} charts'.format(pack_name, sm_name, len(out_json['charts']))
-------------------------------------------------------------------------------- /0_ddc/dataset/smdataset/parse.py: --------------------------------------------------------------------------------
import logging
import re
import traceback

parlog = logging

VALID_PULSES = set([4, 8, 12, 16, 24, 32, 48, 64, 96, 192])

int_parser = lambda x: int(x.strip()) if x.strip() else None
bool_parser = lambda x: True if x.strip() == 'YES' else False
str_parser = lambda x: x.strip() if
 x.strip() else None
# NOTE(review): Python 2 module (it is consumed by extract_json.py, which
# uses the py2 `print` statement). Several constructs below rely on py2
# semantics -- see the notes at `filter(...)` and in parse_sm_txt.
float_parser = lambda x: float(x.strip()) if x.strip() else None
def kv_parser(k_parser, v_parser):
    # Parse "key=value" using the supplied sub-parsers; empty input -> (None, None).
    def parser(x):
        if not x:
            return (None, None)
        k, v = x.split('=', 1)
        return k_parser(k), v_parser(v)
    return parser
def list_parser(x_parser):
    # Parse a comma-separated list with x_parser applied to each element.
    def parser(l):
        l_strip = l.strip()
        if len(l_strip) == 0:
            return []
        else:
            return [x_parser(x) for x in l_strip.split(',')]
    return parser
def bpms_parser(x):
    """Parse and validate the #BPMS list: must start at beat 0, have positive
    BPMs, and strictly ascending beats (duplicates keep the last listed)."""
    bpms = list_parser(kv_parser(float_parser, float_parser))(x)

    if len(bpms) == 0:
        raise ValueError('No BPMs found in list')
    if bpms[0][0] != 0.0:
        raise ValueError('First beat in BPM list is {}'.format(bpms[0][0]))

    # make sure changes are nonnegative, take last for equivalent
    beat_last = -1.0
    bpms_cleaned = []
    for beat, bpm in bpms:
        if beat == None or bpm == None:
            raise ValueError('Empty BPM found')
        if bpm <= 0.0:
            raise ValueError('Non positive BPM found {}'.format(bpm))
        if beat == beat_last:
            bpms_cleaned[-1] = (beat, bpm)
            continue
        bpms_cleaned.append((beat, bpm))
        if beat <= beat_last:
            raise ValueError('Descending list of beats in BPM list')
        beat_last = beat
    if len(bpms) != len(bpms_cleaned):
        parlog.warning('One or more (beat, BPM) pairs begin on the same beat, using last listed')

    return bpms_cleaned
def stops_parser(x):
    """Parse and validate the #STOPS list: nonnegative beats, ascending order;
    zero-length stops are tolerated."""
    stops = list_parser(kv_parser(float_parser, float_parser))(x)

    beat_last = -1.0
    for beat, stop_len in stops:
        if beat == None or stop_len == None:
            raise ValueError('Bad stop formatting')
        if beat < 0.0:
            raise ValueError('Bad beat in stop')
        if stop_len == 0.0:
            continue
        if beat <= beat_last:
            raise ValueError('Nonascending list of beats in stops')
        beat_last = beat
    return stops
def notes_parser(x):
    """Parse one #NOTES section (6 colon-separated parts) into the tuple
    (chart_type, desc/author, difficulty_coarse, difficulty_fine,
    groove_radar, measures)."""
    pattern = r'([^:]*):' * 5 + r'([^;:]*)'
    notes_split = re.findall(pattern, x)
    if len(notes_split) != 1:
        raise ValueError('Bad formatting of notes section')
    notes_split = notes_split[0]
    if (len(notes_split) != 6):
        raise ValueError('Bad formatting within notes section')

    # parse/clean measures
    measures = [measure.splitlines() for measure in notes_split[5].split(',')]
    measures_clean = []
    for measure in measures:
        # py2: filter() returns a list here; under py3 the len() calls below
        # would fail on a filter object
        measure_clean = filter(lambda pulse: not pulse.strip().startswith('//') and len(pulse.strip()) > 0, measure)
        measures_clean.append(measure_clean)
    if len(measures_clean) > 0 and len(measures_clean[-1]) == 0:
        measures_clean = measures_clean[:-1]

    # check measure lengths
    for measure in measures_clean:
        if len(measure) == 0:
            raise ValueError('Found measure with 0 notes')
        if not len(measure) in VALID_PULSES:
            parlog.warning('Nonstandard subdivision {} detected, allowing'.format(len(measure)))

    chart_type = str_parser(notes_split[0])
    if chart_type not in ['dance-single', 'dance-double', 'dance-couple', 'lights-cabinet']:
        raise ValueError('Nonstandard chart type {} detected'.format(chart_type))

    return (str_parser(notes_split[0]),
            str_parser(notes_split[1]),
            str_parser(notes_split[2]),
            int_parser(notes_split[3]),
            list_parser(float_parser)(notes_split[4]),
            measures_clean
    )

def unsupported_parser(attr_name):
    # Always raises: used to mark attributes we refuse to handle (e.g. warps).
    def parser(x):
        raise ValueError('Unsupported attribute: {} with value {}'.format(attr_name, x))
        return None
    return parser

ATTR_NAME_TO_PARSER = {
    'title': str_parser,
    'subtitle': str_parser,
    'artist': str_parser,
    'titletranslit': str_parser,
    'subtitletranslit': str_parser,
    'artisttranslit': str_parser,
    'genre': str_parser,
    'credit': str_parser,
    'banner': str_parser,
    'background': str_parser,
    'lyricspath': str_parser,
    'cdtitle': str_parser,
    'music': str_parser,
    'offset': float_parser,
    'bpms': bpms_parser,
    'stops': stops_parser,
    'samplestart': float_parser,
    'samplelength': float_parser,
    'displaybpm': str_parser,
    'selectable': bool_parser,
    'bgchanges': str_parser,
    'bgchanges2': str_parser,
    'fgchanges': str_parser,
    'keysounds': str_parser,
    'musiclength': float_parser,
    'musicbytes': int_parser,
    'attacks': str_parser,
    'timesignatures': list_parser(kv_parser(float_parser, kv_parser(int_parser, int_parser))),
    'warps': unsupported_parser('warps'),
    'notes': notes_parser
}
ATTR_MULTI = ['notes']

def parse_sm_txt(sm_txt):
    """Parse a whole .sm file body ("#NAME:VALUE;" attributes) into a dict.
    Attributes in ATTR_MULTI accumulate into lists; duplicate scalar
    attributes are only allowed if their parsed values are identical."""
    attrs = {attr_name: [] for attr_name in ATTR_MULTI}

    for attr_name, attr_val in re.findall(r'#([^:]*):([^;]*);', sm_txt):
        attr_name = attr_name.lower()

        if attr_name not in ATTR_NAME_TO_PARSER:
            parlog.warning('Found unexpected attribute {}:{}, ignoring'.format(attr_name, attr_val))
            continue

        attr_val_parsed = ATTR_NAME_TO_PARSER[attr_name](attr_val)
        if attr_name in attrs:
            if attr_name not in ATTR_MULTI:
                if attr_val_parsed == attrs[attr_name]:
                    continue
                else:
                    raise ValueError('Attribute {} defined multiple times'.format(attr_name))
            attrs[attr_name].append(attr_val_parsed)
        else:
            attrs[attr_name] = attr_val_parsed

    # py2: items() returns a list, so deleting during this loop is safe;
    # under py3 this would raise RuntimeError
    for attr_name, attr_val in attrs.items():
        if attr_val == None or attr_val == []:
            del attrs[attr_name]

    return attrs
-------------------------------------------------------------------------------- /train_ddc.py: --------------------------------------------------------------------------------
import os
import logging
import hydra
import time
import torch
import wandb
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
from omegaconf import OmegaConf 11 | 12 | from code.model.cnn import SpectrogramNormalizer, PlacementCNN 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | class Tracker(): 17 | def __init__(self) -> None: 18 | self.initialize() 19 | 20 | def initialize(self): 21 | self.count = 0 22 | self.sum = 0. 23 | 24 | def __str__(self): 25 | if self.count == 0: 26 | return str(0.) 27 | else: 28 | return str(float(self.sum / self.count)) 29 | 30 | def insert(self, num): 31 | self.sum += float(num) 32 | self.count += 1 33 | 34 | def value(self): 35 | return float(self.sum / self.count) 36 | 37 | class RunningMean(): 38 | def __init__(self, howmany=1000) -> None: 39 | self.howmany=howmany 40 | self.count = 0 41 | self.arr = [1. for _ in range(self.howmany)] 42 | 43 | def __str__(self): 44 | return str(float(sum(self.arr) / len(self.arr))) 45 | 46 | def insert(self, num): 47 | self.arr[self.count % self.howmany] = float(num) 48 | self.count += 1 49 | 50 | def value(self): 51 | return float(sum(self.arr) / len(self.arr)) 52 | 53 | def run(args): 54 | from code import DatasetSelector, ModelSelector, LossSelector, OptimizerSelector 55 | from tqdm import tqdm 56 | 57 | #torch.manual_seed(10101010) 58 | wandb.login() 59 | run = wandb.init( 60 | # Set the project where this run will be logged 61 | project="osu-mania-ddc", 62 | # Track hyperparameters and run metadata 63 | config={ 64 | "learning_rate": 1e-3, 65 | "epochs": 5, 66 | "testrun": 0, 67 | } 68 | ) 69 | 70 | tr_loader = DataLoader( 71 | DatasetSelector(args.tr_dataset)(), 72 | batch_size=args.experiment.batch_size.tr, 73 | num_workers=args.experiment.num_workers, 74 | shuffle=True 75 | ) 76 | cv_loader = DataLoader( 77 | DatasetSelector(args.cv_dataset)(), 78 | batch_size=args.experiment.batch_size.cv, 79 | num_workers=args.experiment.num_workers, 80 | shuffle=True 81 | ) 82 | model = ModelSelector(args.model)().to('cuda:0') 83 | loss = LossSelector(args.loss)().to('cuda:0') 84 | optimizer = 
OptimizerSelector(args.optimizer)(model.parameters()) 85 | 86 | normalizer = SpectrogramNormalizer().to('cuda:0') 87 | 88 | ckpt_path = str(args.ckpt_path) 89 | ckpt_path = os.path.join(ckpt_path, f"{str(args.optimizer.parameters.lr)}_and_{str(args.experiment.batch_size.tr)}") 90 | os.makedirs(ckpt_path, exist_ok=True) 91 | torch.save(model.state_dict(), f"{ckpt_path}/asdf.pth") 92 | for epoch in range(20): 93 | st = time.time() 94 | tr_pbar = tqdm(tr_loader, position=0,leave=True) 95 | tracker = Tracker() 96 | train_step, valid_step = 0, 0 97 | tr_running_mean = RunningMean(500) 98 | cv_running_mean = RunningMean(300) 99 | print("=================================epoch ",epoch) 100 | for data, target, fine_difficulty in tr_pbar: 101 | train_step += 1 102 | data = data.to('cuda:0') # (B, 112, 80, 3) 103 | target = target.to(torch.float).to('cuda:0') # (B, 112) 104 | fine_difficulty = fine_difficulty.to('cuda:0').to(torch.float) # (B) 105 | 106 | data_norm = normalizer(data) 107 | 108 | model_output = model(data_norm, fine_difficulty) 109 | 110 | 111 | 112 | l = loss(model_output, target) 113 | l.backward() 114 | loss_value = float(l) 115 | tracker.insert(l) 116 | tr_running_mean.insert(l) 117 | optimizer.step() 118 | optimizer.zero_grad() 119 | if tr_running_mean.count % 160 == 159: 120 | tr_pbar.set_description(str(loss_value)) 121 | wandb.log({ 122 | "training_loss": tr_running_mean.value() 123 | }) 124 | if train_step >= 20000: # 5000 steps per epoch 125 | break 126 | print("av.training loss for epoch: ", tracker) 127 | print(f"train time: {time.time()-st}") 128 | valid_pbar = tqdm(cv_loader, position=0,leave=True) 129 | tracker = Tracker() 130 | with torch.no_grad(): 131 | for data, target, fine_difficulty in valid_pbar: 132 | valid_step += 1 133 | data = data.to('cuda:0') # (B, 112, 80, 3) 134 | target = target.to(torch.float).to('cuda:0') # (B, 112) 135 | fine_difficulty = fine_difficulty.to('cuda:0').to(torch.float) # (B, 1) 136 | 137 | 138 | data_norm = 
normalizer(data) 139 | 140 | model_output = model(data_norm, fine_difficulty) 141 | 142 | l = loss(model_output, target) 143 | loss_value = float(l) 144 | tr_pbar.set_description(str(loss_value)) 145 | tracker.insert(l) 146 | cv_running_mean.insert(l) 147 | if tracker.count % 300 == 0: 148 | print(f"========================== {tracker.count // 300}") 149 | print("GT") 150 | print(target) 151 | print("PRED") 152 | print(model_output) 153 | if valid_step > 1600: break 154 | end = time.time() 155 | print(f"epoch time: {end-st}") 156 | print("av.valid loss for epoch: ", tracker) 157 | wandb.log({ 158 | "validation_loss": tracker.value() 159 | }) 160 | print("end epoch. saving....") 161 | torch.save(model.state_dict(), f"{ckpt_path}/ckpt_epoch_{epoch}.pth") 162 | print("...saved.") 163 | print(f"workers: {args.experiment.num_workers}") 164 | 165 | 166 | def _main(args): 167 | global __file__ 168 | __file__ = hydra.utils.to_absolute_path(__file__) 169 | logger.info("For logs, checkpoints and samples check %s", os.getcwd()) 170 | logger.info(args) 171 | run(args) 172 | 173 | print(__file__) 174 | this_script_dir = os.path.dirname(__file__) 175 | 176 | @hydra.main(version_base=None, config_path=os.path.join(this_script_dir,'conf')) 177 | def main(args): 178 | try: 179 | _main(args) 180 | except Exception: 181 | logger.exception("some error happened") 182 | os._exit(1) 183 | 184 | if __name__=="__main__": 185 | main() -------------------------------------------------------------------------------- /train_cond_centered.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import hydra 4 | import time 5 | import torch 6 | import wandb 7 | import numpy as np 8 | from torch.utils.data import DataLoader 9 | import torch.nn.functional as F 10 | from omegaconf import OmegaConf 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | class Tracker(): 15 | def __init__(self) -> None: 16 | self.initialize() 17 | 18 | 
def initialize(self): 19 | self.count = 0 20 | self.sum = 0. 21 | 22 | def __str__(self): 23 | if self.count == 0: 24 | return str(0.) 25 | else: 26 | return str(float(self.sum / self.count)) 27 | 28 | def insert(self, num): 29 | self.sum += float(num) 30 | self.count += 1 31 | 32 | def value(self): 33 | return float(self.sum / self.count) 34 | 35 | class RunningMean(): 36 | def __init__(self, howmany=1000) -> None: 37 | self.howmany=howmany 38 | self.count = 0 39 | self.arr = [1. for _ in range(self.howmany)] 40 | 41 | def __str__(self): 42 | return str(float(sum(self.arr) / len(self.arr))) 43 | 44 | def insert(self, num): 45 | self.arr[self.count % self.howmany] = float(num) 46 | self.count += 1 47 | 48 | def value(self): 49 | return float(sum(self.arr) / len(self.arr)) 50 | 51 | def run(args): 52 | from code import DatasetSelector, ModelSelector, LossSelector, OptimizerSelector 53 | from tqdm import tqdm 54 | 55 | wandb.login() 56 | run = wandb.init( 57 | # Set the project where this run will be logged 58 | project="osu-mania-4k", 59 | # Track hyperparameters and run metadata (if you want to) 60 | config={ 61 | "learning_rate": 2e-4, 62 | "epochs": 10, 63 | "testrun": 0, 64 | } 65 | ) 66 | 67 | tr_loader = DataLoader( 68 | DatasetSelector(args.tr_dataset)(), 69 | batch_size=args.experiment.batch_size.tr, 70 | num_workers=args.experiment.num_workers, 71 | shuffle=True 72 | ) 73 | cv_loader = DataLoader( 74 | DatasetSelector(args.cv_dataset)(), 75 | batch_size=args.experiment.batch_size.cv, 76 | num_workers=args.experiment.num_workers, 77 | shuffle=True 78 | ) 79 | 80 | model = ModelSelector(args.model)().to('cuda:0') 81 | loss = LossSelector(args.loss)().to('cuda:0') 82 | optimizer = OptimizerSelector(args.optimizer)(model.parameters()) 83 | 84 | ckpt_path = str(args.ckpt_path) 85 | ckpt_path = os.path.join(ckpt_path, f"{str(args.optimizer.parameters.lr)}_and_{str(args.experiment.batch_size.tr)}") 86 | if "load_from" in args: 87 | load_from = str(args.load_from) 
88 | model.load_state_dict(torch.load(load_from)) 89 | 90 | os.makedirs(ckpt_path, exist_ok=True) 91 | for epoch in range(10): 92 | st = time.time() 93 | tr_pbar = tqdm(tr_loader, position=0,leave=True) 94 | tracker = Tracker() 95 | tr_running_mean = RunningMean(500) 96 | cv_running_mean = RunningMean(300) 97 | print("=================================epoch ",epoch) 98 | for data, tokens, mask, fine_difficulty in tr_pbar: 99 | data = data.to('cuda:0') 100 | tokens= tokens.to(torch.long).to('cuda:0') 101 | fine_difficulty = fine_difficulty.to('cuda:0').to(torch.float).unsqueeze(1) 102 | mask = mask.unsqueeze(1).to('cuda:0') #(B,1,else) 103 | 104 | model_output = model(data, tokens, fine_difficulty) 105 | #model_output = model(data,tokens) 106 | goal = torch.cat((tokens[:,1:], tokens[:,-1:]), axis=1) 107 | goal_onehot = F.one_hot(goal,num_classes=178).to(torch.float32) # (B, else, C) 108 | model_output = model_output.permute((0,2,1))[:,:,-50:] #(B,C,else) 109 | goal_onehot = goal_onehot.permute((0,2,1))[:,:,-50:] #(B,C,else) 110 | 111 | l = loss(model_output,goal_onehot) # (B, C, else) 112 | optimizer.zero_grad() 113 | l.backward() 114 | optimizer.step() 115 | loss_value = float(l) 116 | tr_pbar.set_description(str(loss_value)) 117 | tracker.insert(l) 118 | tr_running_mean.insert(l) 119 | if tr_running_mean.count % 500 == 499: 120 | wandb.log({ 121 | "training_loss": tr_running_mean.value() 122 | }) 123 | print("av.training loss for epoch: ", tracker) 124 | print(f"train time: {time.time()-st}") 125 | valid_pbar = tqdm(cv_loader, position=0,leave=True) 126 | tracker = Tracker() 127 | with torch.no_grad(): 128 | for data, tokens, mask, fine_difficulty in valid_pbar: 129 | data = data.to('cuda:0') 130 | tokens= tokens.to(torch.long).to('cuda:0') 131 | fine_difficulty = fine_difficulty.to('cuda:0').to(torch.float).unsqueeze(1) 132 | mask = mask.unsqueeze(1).to('cuda:0') 133 | 134 | model_output = model(data, tokens, fine_difficulty) # (32, length, 178) 135 | # 
model_output = model(data, tokens) 136 | goal = torch.cat((tokens[:,1:], tokens[:,-1:]), axis=1) 137 | goal_onehot = F.one_hot(goal,num_classes=178).to(torch.float32) 138 | model_output = model_output.permute((0,2,1))[:,:,-50:] #(B,C,else) 139 | goal_onehot = goal_onehot.permute((0,2,1))[:,:,-50:] #(B,C,else) 140 | 141 | l = loss(model_output,goal_onehot) # (B, C, else) 142 | loss_value = float(l) 143 | tr_pbar.set_description(str(loss_value)) 144 | tracker.insert(l) 145 | cv_running_mean.insert(l) 146 | if tracker.count % 300 == 0: 147 | created_tokens = torch.argmax(model_output.permute((0,2,1)),2).cpu().numpy() 148 | print(f"========================== {tracker.count // 300}") 149 | print("GT") 150 | print(goal[0,:].cpu().numpy()) 151 | print("PRED") 152 | print(created_tokens[0,:]) 153 | end = time.time() 154 | print(f"epoch time: {end-st}") 155 | print("av.valid loss for epoch: ", tracker) 156 | wandb.log({ 157 | "validation_loss": tracker.value() 158 | }) 159 | print("end epoch. saving....") 160 | torch.save(model.state_dict(), f"{ckpt_path}/ckpt_epoch_{epoch}.pth") 161 | print("...saved.") 162 | print(f"workers: {args.experiment.num_workers}") 163 | 164 | 165 | def _main(args): 166 | global __file__ 167 | __file__ = hydra.utils.to_absolute_path(__file__) 168 | logger.info("For logs, checkpoints and samples check %s", os.getcwd()) 169 | logger.info(args) 170 | run(args) 171 | 172 | print(__file__) 173 | this_script_dir = os.path.dirname(__file__) 174 | 175 | @hydra.main(version_base=None, config_path=os.path.join(this_script_dir,'conf')) 176 | def main(args): 177 | try: 178 | _main(args) 179 | except Exception: 180 | logger.exception("some error happened") 181 | os._exit(1) 182 | 183 | if __name__=="__main__": 184 | main() -------------------------------------------------------------------------------- /code/loader/audiocharttwobeatsongdataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the one used 
to supply data to evaluation codes. 3 | """ 4 | 5 | import os 6 | import sys 7 | import json 8 | import time 9 | import numpy as np 10 | import random 11 | import h5py 12 | from tqdm import tqdm 13 | from .beatencoder import NBeatEncoder, TwoTwoEncoder, WholeSongBeatEncoder 14 | 15 | def round_to_nearest_48th(num): 16 | return round(num*48)/48 17 | 18 | class AudioChartTwoBeatSongDataset(): 19 | """ 20 | usage of bs > 1 with this dataset will yield an error. 21 | """ 22 | def __init__(self, split_json_path="train.json", split_hdf5_path="train.h5", 23 | logspec=True, use_notecount_as_finediff=False) -> None : 24 | self.split_json_path = split_json_path 25 | self.split_hdf5_path = split_hdf5_path 26 | with open(split_json_path) as file: 27 | self.split_dict = json.loads(file.read()) 28 | self.hdf5 = h5py.File(split_hdf5_path,'r') 29 | self.dir_of_dirs = os.path.dirname(split_json_path) 30 | 31 | self.total_length_in_songs = len(self.split_dict.values()) 32 | 33 | self.index = [] 34 | for osujson, (bar_len,beat_len) in self.split_dict.items(): 35 | self.index.extend([(osujson,0,beat_len)]) 36 | 37 | assert len(self.index) == self.total_length_in_songs 38 | 39 | self.beat_encoder = TwoTwoEncoder() 40 | self.pad = 177 41 | self.length_in_beats = 4 42 | self.token_length = 200 43 | self.truncate = 92 44 | self.logspec = logspec 45 | self.use_notecount_as_finediff = use_notecount_as_finediff 46 | 47 | def __len__(self) -> int: 48 | return self.total_length_in_songs 49 | 50 | def __getitem__(self,idx,benchmark=False) -> dict: 51 | """ 52 | List of stuff to return: 53 | - the music for this song, made into a melspectrogram with hop size = (beat) / 48 (loaded from a h5) 54 | - the encoded tokens for the whole song, with a beat_token every two beats 55 | - fine difficulty 56 | - the encoded token for the first two beats 57 | """ 58 | if benchmark: tp = time.time() 59 | 60 | if idx > self.total_length_in_songs: 61 | raise ValueError("idx over total length") 62 | osujson_fn, 
beat, length_in_beats = self.index[idx] # beat = 0 63 | beatmapset_path = os.path.dirname(osujson_fn) 64 | with open(osujson_fn) as file: 65 | osujson = json.load(file) 66 | songname = f'{osujson["artist"]} - {osujson["title"]}, {osujson["charts"][0]["difficulty_fine"]}' 67 | 68 | if benchmark: print(f"json read: {time.time() - tp}") 69 | 70 | # fine diff 71 | try: 72 | if self.use_notecount_as_finediff: 73 | fine_difficulty = len(osujson['charts'][0]['notes']) 74 | else: 75 | fine_difficulty = float(osujson['charts'][0]['difficulty_fine']) 76 | except Exception as e: 77 | print(osujson_fn) 78 | raise e 79 | 80 | assert beat == int(beat) # see directory 0_filtration 81 | 82 | # load melspec from hdf5 83 | hdf5_dset_name = osujson_fn.replace(self.dir_of_dirs+'/','') 84 | mel = self.hdf5[hdf5_dset_name] 85 | mel_to_return = mel[:,:].T # (80, LENGTH) 86 | if self.logspec: 87 | mel_to_return = np.log10(mel_to_return+1e-10) 88 | if mel_to_return.shape[1] % 96 != 0: 89 | how_many_to_pad = 96 - mel_to_return.shape[1] % 96 90 | mel_to_return = np.pad(mel_to_return, ((0,0),(0,how_many_to_pad)), 'constant', constant_values=0) 91 | 92 | if benchmark: print(f"reading melspectrogram: {time.time() - tp}") 93 | 94 | #make tokens for all 95 | notes = osujson['charts'][0]['notes'] 96 | # make charts for whole chart for evalutaion. 97 | whole_chart_iterable = [] 98 | for beat in range(0,length_in_beats-2,2): 99 | nextbeat = beat + self.length_in_beats 100 | notes = list(filter(lambda x:(x[1] >= beat and x[1] < nextbeat), osujson['charts'][0]['notes'])) # TODO: make this configurable? 
101 | temp_enc_notes = self.beat_encoder.encode(notes,beat) 102 | where_is_96 = temp_enc_notes[1:].index(96) + 1 103 | chart_for_this_section = np.array(temp_enc_notes[where_is_96+1:]) 104 | right_pad = self.token_length // 2 - len(chart_for_this_section) 105 | chart_for_this_section = np.pad(chart_for_this_section, (0, right_pad), 'constant', constant_values=(self.pad, self.pad)) 106 | whole_chart_iterable.append(chart_for_this_section) 107 | whole_chart_iterable = np.array(whole_chart_iterable) 108 | if benchmark: print(f"tokenmaking: {time.time() - tp}") 109 | 110 | return mel_to_return, whole_chart_iterable, fine_difficulty, songname 111 | 112 | def stress_test(split="train"): 113 | print(f"=============================== split = {split}") 114 | split_json_path = f"OSUFOLDER/{split}.json" 115 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 116 | dset = AudioChartTwoBeatSongDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path) 117 | print(len(dset)) 118 | random_indices = random.sample(list(range(len(dset))),5000) 119 | start_time = time.time() 120 | #for i in tqdm(random_indices): 121 | #dset[i] 122 | print(dset[0][0].shape) 123 | print(dset[30]) # wait, why is there a minus mixed in? 124 | end_time = time.time() 125 | print(end_time-start_time) #five hours. 
wow, that's ridiculous 126 | 127 | def size_test(split="train"): 128 | print(f"=============================== split = {split}") 129 | split_json_path = f"OSUFOLDER/{split}.json" 130 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 131 | dset = AudioChartTwoBeatSongDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path) 132 | print(len(dset)) 133 | sz = len(dset) 134 | for i in tqdm(range(sz)): 135 | assert dset[i][0].shape == (80,192) 136 | 137 | def print_one(split="train",index=0, **params): 138 | print(f"=============================== split = {split}") 139 | split_json_path = f"OSUFOLDER/{split}.json" 140 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 141 | 142 | dset =AudioChartTwoBeatSongDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path, **params) 143 | print(dset[index]) 144 | 145 | if __name__=="__main__": 146 | # for split in ['valid']: 147 | # stress_test(split) 148 | print_one() 149 | 150 | """ XXX. 151 | benchmarks. 152 | 153 | ## mels: (80, T), windows 154 | 155 | random, 5000 accesses: 156 | - train (542966 samples): 148.2 157 | - valid (73915 samples): 102.4 158 | - test (78460 samples): 107.3 159 | 160 | sequential, 20000 accesses: 161 | - train (542966 samples): 343.1 162 | - valid (73915 samples): 334.4 163 | - test (78460 samples): 353.2 164 | 165 | ## mels: (T, 80), windows 166 | 167 | random: 5000 accesses: 168 | - train_transpose (560928 samples): 21.9 169 | - valid_transpose (63999 samples): 16.4 170 | - test_transpose (63595 samples): 16.7 171 | 172 | ## mels: (T, 80), wsl home dir 173 | 174 | random: 5000 accesses: 175 | - train_transpose (560928 samples): too scared to try 176 | - valid_transpose (63999 samples): 14.7 177 | - test_transpose (63595 samples): 15.2 178 | 179 | """ -------------------------------------------------------------------------------- /code/loader/audiotimingtwobeatsongdataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the 
one used to supply data to evaluation codes. 3 | """ 4 | 5 | import os 6 | import sys 7 | import json 8 | import time 9 | import numpy as np 10 | import random 11 | import h5py 12 | from tqdm import tqdm 13 | from .beatencoder import NBeatEncoder, TwoTwoEncoder, TwoTwoEncoderTimingOnly 14 | 15 | def round_to_nearest_48th(num): 16 | return round(num*48)/48 17 | 18 | class AudioTimingTwoBeatSongDataset(): 19 | """ 20 | usage of bs > 1 with this dataset will yield an error. 21 | """ 22 | def __init__(self, split_json_path="train.json", split_hdf5_path="train.h5", 23 | logspec=True, use_notecount_as_finediff=False) -> None : 24 | self.split_json_path = split_json_path 25 | self.split_hdf5_path = split_hdf5_path 26 | with open(split_json_path) as file: 27 | self.split_dict = json.loads(file.read()) 28 | self.hdf5 = h5py.File(split_hdf5_path,'r') 29 | self.dir_of_dirs = os.path.dirname(split_json_path) 30 | 31 | self.total_length_in_songs = len(self.split_dict.values()) 32 | 33 | self.index = [] 34 | for osujson, (bar_len,beat_len) in self.split_dict.items(): 35 | self.index.extend([(osujson,0,beat_len)]) 36 | 37 | assert len(self.index) == self.total_length_in_songs 38 | 39 | self.beat_encoder = TwoTwoEncoderTimingOnly() 40 | self.pad = 177 41 | self.length_in_beats = 4 42 | self.token_length = 100 43 | self.truncate = 46 44 | self.logspec = logspec 45 | self.use_notecount_as_finediff = use_notecount_as_finediff 46 | 47 | def __len__(self) -> int: 48 | return self.total_length_in_songs 49 | 50 | def __getitem__(self,idx,benchmark=False) -> dict: 51 | """ 52 | List of stuff to return: 53 | - the music for this song, made into a melspectrogram with hop size = (beat) / 48 (loaded from a h5) 54 | - the encoded tokens for the whole song, with a beat_token every two beats 55 | - fine difficulty 56 | - the encoded token for the first two beats 57 | """ 58 | if benchmark: tp = time.time() 59 | 60 | if idx > self.total_length_in_songs: 61 | raise ValueError("idx over total 
length") 62 | osujson_fn, beat, length_in_beats = self.index[idx] # beat = 0 63 | beatmapset_path = os.path.dirname(osujson_fn) 64 | with open(osujson_fn) as file: 65 | osujson = json.load(file) 66 | songname = f'{osujson["artist"]} - {osujson["title"]}, {osujson["charts"][0]["difficulty_fine"]}' 67 | 68 | if benchmark: print(f"json read: {time.time() - tp}") 69 | 70 | # fine diff 71 | try: 72 | if self.use_notecount_as_finediff: 73 | fine_difficulty = len(osujson['charts'][0]['notes']) 74 | else: 75 | fine_difficulty = float(osujson['charts'][0]['difficulty_fine']) 76 | except Exception as e: 77 | print(osujson_fn) 78 | raise e 79 | 80 | assert beat == int(beat) # see directory 0_filtration 81 | 82 | # load melspec from hdf5 83 | hdf5_dset_name = osujson_fn.replace(self.dir_of_dirs+'/','') 84 | mel = self.hdf5[hdf5_dset_name] 85 | mel_to_return = mel[:,:].T # (80, LENGTH) 86 | if self.logspec: 87 | mel_to_return = np.log10(mel_to_return+1e-10) 88 | if mel_to_return.shape[1] % 96 != 0: 89 | how_many_to_pad = 96 - mel_to_return.shape[1] % 96 90 | mel_to_return = np.pad(mel_to_return, ((0,0),(0,how_many_to_pad)), 'constant', constant_values=0) 91 | 92 | if benchmark: print(f"reading melspectrogram: {time.time() - tp}") 93 | 94 | #make tokens for all 95 | notes = osujson['charts'][0]['notes'] 96 | # make charts for whole chart for evalutaion. 97 | whole_chart_iterable = [] 98 | for beat in range(0,length_in_beats-2,2): 99 | nextbeat = beat + self.length_in_beats 100 | notes = list(filter(lambda x:(x[1] >= beat and x[1] < nextbeat), osujson['charts'][0]['notes'])) # TODO: make this configurable? 
101 | temp_enc_notes = self.beat_encoder.encode(notes,beat) 102 | where_is_96 = temp_enc_notes[1:].index(96) + 1 103 | chart_for_this_section = np.array(temp_enc_notes[where_is_96+1:]) 104 | right_pad = self.token_length // 2 - len(chart_for_this_section) 105 | chart_for_this_section = np.pad(chart_for_this_section, (0, right_pad), 'constant', constant_values=(self.pad, self.pad)) 106 | whole_chart_iterable.append(chart_for_this_section) 107 | whole_chart_iterable = np.array(whole_chart_iterable) 108 | if benchmark: print(f"tokenmaking: {time.time() - tp}") 109 | 110 | return mel_to_return, whole_chart_iterable, fine_difficulty, songname 111 | 112 | def stress_test(split="train"): 113 | print(f"=============================== split = {split}") 114 | split_json_path = f"OSUFOLDER/{split}.json" 115 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 116 | dset = AudioChartTwoBeatSongDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path) 117 | print(len(dset)) 118 | random_indices = random.sample(list(range(len(dset))),5000) 119 | start_time = time.time() 120 | #for i in tqdm(random_indices): 121 | #dset[i] 122 | print(dset[0][0].shape) 123 | print(dset[30]) # wait, why is there a minus mixed in? 124 | end_time = time.time() 125 | print(end_time-start_time) #five hours. 
wow, that's ridiculous 126 | 127 | def size_test(split="train"): 128 | print(f"=============================== split = {split}") 129 | split_json_path = f"OSUFOLDER/{split}.json" 130 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 131 | dset = AudioChartTwoBeatSongDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path) 132 | print(len(dset)) 133 | sz = len(dset) 134 | for i in tqdm(range(sz)): 135 | assert dset[i][0].shape == (80,192) 136 | 137 | def print_one(split="train",index=0, **params): 138 | print(f"=============================== split = {split}") 139 | split_json_path = f"OSUFOLDER/{split}.json" 140 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 141 | 142 | dset =AudioChartTwoBeatSongDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path, **params) 143 | print(dset[index]) 144 | 145 | if __name__=="__main__": 146 | # for split in ['valid']: 147 | # stress_test(split) 148 | print_one() 149 | 150 | """ XXX. 151 | benchmarks. 152 | 153 | ## mels: (80, T), windows 154 | 155 | random, 5000 accesses: 156 | - train (542966 samples): 148.2 157 | - valid (73915 samples): 102.4 158 | - test (78460 samples): 107.3 159 | 160 | sequential, 20000 accesses: 161 | - train (542966 samples): 343.1 162 | - valid (73915 samples): 334.4 163 | - test (78460 samples): 353.2 164 | 165 | ## mels: (T, 80), windows 166 | 167 | random: 5000 accesses: 168 | - train_transpose (560928 samples): 21.9 169 | - valid_transpose (63999 samples): 16.4 170 | - test_transpose (63595 samples): 16.7 171 | 172 | ## mels: (T, 80), wsl home dir 173 | 174 | random: 5000 accesses: 175 | - train_transpose (560928 samples): too scared to try 176 | - valid_transpose (63999 samples): 14.7 177 | - test_transpose (63595 samples): 15.2 178 | 179 | """ -------------------------------------------------------------------------------- /metrics_timing_AR.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file used to make the lower 
 two rows of Table 2.

python metrics_timing_AR.py +model=finediffdecoderprelnbipasttiming +ts_dataset=beatfine +abridge=False +ckpt_path=ckpts/mel_timingonly/0.0002_and_32/ckpt_epoch_9.pth
precision 0.8889249703349243 recall 0.8062742555457787 f1 0.8455847640214559

python metrics_timing_AR.py +model=finediffdecoderprelnbipasttiming +ts_dataset=beatfine_itg_timingonly +ckpt_path=ckpts/itg_mel_timingonly_finetune/2e-05_and_32/ckpt_epoch_13.pth
precision 0.7074169077116192 recall 0.783818104855778 f1 0.7436603557085484

python metrics_timing_AR.py +model=finediffdecoderprelnbipasttiming +ts_dataset=beatfine_fraxtil_timingonly +ckpt_path=vckpts/fraxtil_mel_timingonly_finetune/2e-05_and_32/ckpt_epoch_13.pth
precision 0.7460073034449382 recall 0.8383683559950557 f1 0.7894957396284398

python metrics_timing_AR.py +model=finediffdecoderprelnbipasttiming +ts_dataset=beatfine_itg_timingonly +ckpt_path=ckpts/itg_mel_timingonly/0.0002_and_32/ckpt_epoch_9.pth
precision 0.6625449640287769 recall 0.7128257846042209 f1 0.6867662908910833

python metrics_timing_AR.py +model=finediffdecoderprelnbipasttiming +ts_dataset=beatfine_fraxtil_timingonly +ckpt_path=ckpts/fraxtil_mel_timingonly/0.0002_and_32/ckpt_epoch_9.pth
precision 0.6798731442754056 recall 0.7313720642768851 f1 0.7046829593635368

I am sorry for not setting the seed properly...
"""
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import time
import hydra
import numpy as np
from omegaconf import OmegaConf
import os
import logging
from pprint import pprint
from code.counter.osu4kcounter import Osu4kTwoBeatOnsetCounter, Osu4kTwoBeatTimingCounter

logger = logging.getLogger(__file__)

# trunc_how_many: number of leading pad positions cut off the decoder input;
# set once in run() before any generation happens.
trunc_how_many = None
# global_counter: accumulates precision/recall/f1 counts over the whole test set.
global_counter = Osu4kTwoBeatTimingCounter()


def generate_2bar_greedy(model, data, tokens, fine_difficulty, max_gen_length=200):
    """
    Greedy autoregressive generation, two beats at a time, over a whole song.

    Updates `local_counter` (per-song) and `global_counter` (whole test set)
    with matches between ground-truth tokens and generated tokens, and
    returns the list of generated token tensors (one per two-beat window).

    ! invariants
    model, data, tokens are already in cuda device
    data: (1, 80, 96T)
    tokens: (1, length_in_beats//2 - 1, 100). includes charts from beat 2. beats [2-4, 4-6, 6-8, ....]
    fine_diff: (1,1)
    data has dim of (80, L)
    """
    global trunc_how_many
    global global_counter


    local_counter = Osu4kTwoBeatTimingCounter()


    # Decoder seed: pad tokens (177) around the two beat markers (96), then
    # truncated on the left by trunc_how_many.
    decoder_macro_input = [177 for _ in range(max_gen_length//2-1)] + [96,96] + [177 for _ in range(max_gen_length//2-1)]
    decoder_macro_input = decoder_macro_input[trunc_how_many:]
    decoder_macro_input = torch.Tensor(decoder_macro_input).to(torch.long).to('cuda:0').unsqueeze(0)
    generated = []
    how_many_twobeats = tokens.shape[1]

    for idx_to_generate_rn in range(0,how_many_twobeats):
        # retrieve_this: write cursor into the decoder input for this window
        retrieve_this = (max_gen_length // 2) - trunc_how_many
        where_is_96 = int(retrieve_this)
        # two-beat audio slice: 96 frames per beat-pair, with one window of lookahead
        audio_this_cycle = data[: , :, 96*(idx_to_generate_rn):96*(idx_to_generate_rn+2)]
        tokens_this_cycle = decoder_macro_input[:,:]
        if audio_this_cycle.shape[2] == 0: break
        while retrieve_this < 199 - trunc_how_many: # nothing's gonna reach 100 anyways, so cutting it short should be fine...right?
            model_output = model(audio_this_cycle, tokens_this_cycle, fine_difficulty)
            created_tokens = torch.argmax(model_output, 2).to(torch.long)
            if int(created_tokens[0,retrieve_this]) == 177:
                break # input_this_cycle is the answer for this cycle.
            else:
                # feed the newly generated token back in and advance the cursor
                tokens_this_cycle[0,retrieve_this+1] = created_tokens[0,retrieve_this]
                retrieve_this += 1
        # split tokens by 96
        created_tokens = created_tokens[:, where_is_96:retrieve_this]
        local_counter.update(tokens[0,idx_to_generate_rn,:],created_tokens[0,:])
        global_counter.update(tokens[0,idx_to_generate_rn,:],created_tokens[0,:])
        generated.append(created_tokens)
        logger.info(f"{idx_to_generate_rn}\t\t{created_tokens}")
        created_tokens_length = created_tokens.shape[1]

        # Rebuild the decoder input for the next window: the tokens just
        # generated become the "previous two beats" context, re-centered
        # between beat markers and re-padded/truncated to the working length.
        decoder_macro_input = torch.cat((
            torch.Tensor([[177 for _ in range(max_gen_length//2-1-created_tokens_length)]]).to(torch.long).to('cuda'),
            torch.Tensor([[96]]).to(torch.long).to('cuda'),
            created_tokens,
            torch.Tensor([[96]]).to(torch.long).to('cuda'),
            torch.Tensor([[177 for _ in range(max_gen_length//2-1)]]).to(torch.long).to('cuda')
        ),axis=1)
        how_many_pad_on_right = max_gen_length - decoder_macro_input.shape[1]
        decoder_macro_input = F.pad(decoder_macro_input, (0, how_many_pad_on_right), 'constant', 177)
        decoder_macro_input = decoder_macro_input[:,trunc_how_many:]

    logger.info(f"precision {local_counter.precision()} \trecall {local_counter.recall()} \tf1 {local_counter.f1()}")
    assert len(generated) == how_many_twobeats or len(generated) == how_many_twobeats-1
    return generated



def run(args):
    """Evaluate a timing model over the test split, logging per-song and
    aggregate precision/recall/f1. With +abridge=True, stops after ~21 songs."""
    from code import DatasetSelector, ModelSelector
    global trunc_how_many
    global global_counter
    ts_loader = DataLoader(
        DatasetSelector(args.ts_dataset)(),
        batch_size=1,
        shuffle=False
    )
    model = ModelSelector(args.model)().to('cuda:0')
    model.load_state_dict(torch.load(args.ckpt_path))

    model.eval()

    # matches the timing dataset's `truncate` setting (token_length=100)
    trunc_how_many = 46

    ts_pbar = tqdm(ts_loader)

    abridge = False
    try:
        if args.abridge is True:
            abridge = True
    except Exception: #omegaconf.errors.ConfigAttributeError:
        abridge = False

    for n, (data, tokens, fine_difficulty,songname) in enumerate(ts_pbar):
        data = data[:].to('cuda:0') #(1, 80, 96T)
        tokens= tokens.to(torch.long).to('cuda:0') # (1, tokenlength)
        fine_difficulty = fine_difficulty.to('cuda:0').to(torch.float).unsqueeze(1) # (1, 1)

        logger.info(songname)
        generated = generate_2bar_greedy(model, data, tokens, fine_difficulty, 100)
        if abridge and n > 20:
            break

    print("====subtotl====")
    logger.info(f"precision {global_counter.precision()} \trecall {global_counter.recall()} \tf1 {global_counter.f1()}")


def _main(args):
    global __file__
    __file__ = hydra.utils.to_absolute_path(__file__)
    logger.info("For logs, checkpoints and samples check %s", os.getcwd())
    logger.info(args)
    run(args)

print(__file__)
this_script_dir = os.path.dirname(__file__)

@hydra.main(version_base=None, config_path=os.path.join(this_script_dir,'conf'))
def main(args):
    try:
        _main(args)
    except Exception:
        logger.exception("some error happened")
        os._exit(1)

if __name__=="__main__":
    main()


--------------------------------------------------------------------------------
/ddc_eval.py:
--------------------------------------------------------------------------------
"""
This file used to make the upper two rows of Table 1.
3 | 4 | python ddc_eval.py +ts_dataset=ddc +model=ddc_cnn +experiment=ddc +ckpt_path=ckpts/ddc_cnn/0.0002_and_64/ckpt_epoch_1.pth 5 | precision 0.8404987330632542 , recall 0.8206751211129037 , f1 0.8304686440887686 6 | 7 | python ddc_eval.py +ts_dataset=ddc +model=ddc_clstm +experiment=ddc +ckpt_path=ckpts/ddc_clstm/0.0002_and_64/ckpt_epoch_0.pth 8 | precision 0.8390783855934946 , recall 0.8380883363992774 , f1 0.8385830682781487 9 | 10 | When you run the eval, you will want to switch out the thresholds as indicated on the comments for function `threshold_from_finediff`. 11 | These thresholds were found using scripts/thres_tune_{clstm,cnn}.sh 12 | """ 13 | import os 14 | import logging 15 | import hydra 16 | import time 17 | import torch 18 | import wandb 19 | import numpy as np 20 | from tqdm import tqdm 21 | from torch.utils.data import DataLoader 22 | import torch.nn.functional as F 23 | from omegaconf import OmegaConf 24 | import copy 25 | 26 | from code.model.cnn import SpectrogramNormalizer, PlacementCNN 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | def convert_outputs(outputs): 31 | """ 32 | outputs: (B, 112) 33 | """ 34 | B = outputs.shape[0] 35 | ret = [[] for _ in range(B)] 36 | for i in range(B): 37 | for j in range(112): 38 | if j==0 and outputs[i][j] > outputs[i][j+1]: 39 | ret[i].append((j,outputs[i][j])) 40 | elif j==111 and outputs[i][j] > outputs[i][j-1]: 41 | ret[i].append((j,outputs[i][j])) 42 | elif 0 outputs[i][j+1] and outputs[i][j] > outputs[i][j-1]: 43 | ret[i].append((j,outputs[i][j])) 44 | return ret 45 | 46 | def convert_targets(targets): 47 | return list(targets) 48 | 49 | def threshold_from_finediff(finediffs): 50 | # cnn. /mnt/c/Users/manym/Desktop/gorst/gorst/ckpts/ddc_cnn/2e4_and_64_try2/ckpt_epoch_1.pth 51 | # easy 0.19, 0.745 52 | # normal 0.17, 0.822 53 | # hard 0.17, 0.837 54 | # insane 0.17, 0.840 55 | # expert 0.17, 0.812 56 | # expert+ 0.13, 0.813 57 | 58 | # clstm. 
/mnt/c/Users/manym/Desktop/gorst/gorst/ckpts/ddc_clstm/0.0002_and_64/ckpt_epoch_0.pth 59 | # easy 0.17, 0.766 60 | # normal 0.16 0.831 61 | # hard 0.18 0.839 62 | # insane 0.19 0.844 63 | # expert 0.17, 0.822 64 | # expert+ 0.13, 0.814 65 | #print(finediffs) 66 | ret = [0.17 for e in finediffs] 67 | for n,finediff in enumerate(finediffs): 68 | if finediff < 2.0: ret[n] = 0.17 69 | elif 2.0<=finediff and finediff < 2.7: ret[n] = 0.16 70 | elif 2.7<=finediff and finediff < 4.0: ret[n] = 0.18 71 | elif 4.0<=finediff and finediff < 5.3: ret[n] = 0.19 72 | elif 5.3<=finediff and finediff < 6.5: ret[n] = 0.17 73 | elif finediff >= 6.5: ret[n] = 0.13 74 | return ret 75 | 76 | def get_fn_fp_tp(one_output, one_target): 77 | """ 78 | output: list of ints from [0,112) 79 | target: true-false array 80 | """ 81 | fn, fp, tp = 0, 0, 0 82 | #print(one_output, one_target) 83 | for i in range(112): 84 | if bool(one_target[i]) is False: continue 85 | for j in [i-2,i-1,i,i+1,i+2]: 86 | if j in one_output: 87 | #print(f"broken on {i}") 88 | break 89 | if j == i+2: 90 | #print("fn") 91 | fn += 1 92 | 93 | for peak in one_output: 94 | flag=False 95 | for j in [peak-2, peak-1, peak, peak+1, peak+2]: 96 | if j < 0: continue 97 | elif j >= 112: 98 | #print(f"broken on {peak}") 99 | break 100 | if bool(one_target[j]) is True: 101 | #print("tp") 102 | tp += 1 103 | flag=True 104 | break 105 | if flag is False: 106 | #print("fp") 107 | fp += 1 108 | #print(fn, fp, tp ) 109 | return fn, fp, tp 110 | 111 | 112 | def get_metric(outputs, targets, thresholds): 113 | """ 114 | targets: (B*1000, 112) 115 | outputs: list, len(outputs) = B*1000, each list in outputs may have variable length. 
116 | """ 117 | B = len(outputs) 118 | fn, fp, tp = 0, 0, 0 119 | for output, target, threshold in zip(outputs, targets, thresholds): 120 | oo = [e[0] for e in output if e[1]>threshold] 121 | fnt,fpt,tpt = get_fn_fp_tp(oo, target) 122 | fn+=fnt 123 | fp+=fpt 124 | tp+=tpt 125 | return fn,fp,tp 126 | 127 | 128 | def run(args): 129 | from code import DatasetSelector, ModelSelector, LossSelector, OptimizerSelector 130 | from tqdm import tqdm 131 | 132 | #torch.manual_seed(10101010) 133 | ts_loader = DataLoader( 134 | DatasetSelector(args.ts_dataset)(), 135 | batch_size=args.experiment.batch_size.ts, 136 | num_workers=args.experiment.num_workers, 137 | shuffle=True 138 | ) 139 | model = ModelSelector(args.model)().to('cuda:0') 140 | model.load_state_dict(torch.load(args.ckpt_path)) 141 | 142 | normalizer = SpectrogramNormalizer().to('cuda:0') 143 | fn, fp, tp = 0, 0, 0 144 | the_step = 0 145 | the_pbar = tqdm(ts_loader, position=0,leave=True) 146 | 147 | with torch.no_grad(): 148 | for data, target, fine_difficulty in the_pbar: 149 | the_step += 1 150 | data = data.to('cuda:0') # (B, 112, 80, 3) 151 | target = target.to(torch.bool) # (B, 112) 152 | fine_difficulty = fine_difficulty.to('cuda:0').to(torch.float) # (B, 1) 153 | 154 | # print(data.shape, target.shape, fine_difficulty.shape) 155 | 156 | data_norm = normalizer(data) 157 | #print(data_norm.shape) 158 | 159 | model_output = model(data_norm, fine_difficulty) 160 | 161 | model_output = model_output.cpu().detach().numpy() # (B, 112) 162 | target = target.cpu().detach().numpy() 163 | fine_difficulty = fine_difficulty.squeeze().cpu().detach().numpy() 164 | 165 | target = convert_targets(target) 166 | pred = convert_outputs(model_output) 167 | thresholds = threshold_from_finediff(fine_difficulty) 168 | 169 | fnt,fpt,tpt = get_metric(pred,target,thresholds) 170 | fn += fnt 171 | fp += fpt 172 | tp += tpt 173 | 174 | # if valid_step >= 2000: 175 | # break 176 | 177 | print(fn, fp, tp) 178 | precision = tp / (tp + fp + 
1e-9) 179 | recall = tp / (tp + fn + 1e-9) 180 | f1 = 2 * precision * recall / (precision+recall + 1e-9) 181 | print(f"precision {precision}\t\t, recall {recall}\t\t, f1 {f1}") 182 | 183 | 184 | def _main(args): 185 | global __file__ 186 | __file__ = hydra.utils.to_absolute_path(__file__) 187 | logger.info("For logs, checkpoints and samples check %s", os.getcwd()) 188 | logger.info(args) 189 | run(args) 190 | 191 | print(__file__) 192 | this_script_dir = os.path.dirname(__file__) 193 | 194 | @hydra.main(version_base=None, config_path=os.path.join(this_script_dir,'conf')) 195 | def main(args): 196 | try: 197 | _main(args) 198 | except Exception: 199 | logger.exception("some error happened") 200 | os._exit(1) 201 | 202 | if __name__=="__main__": 203 | main() -------------------------------------------------------------------------------- /code/loader/audiotimingfinebeatdataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | the argument to feed into torch.util.data.DataLoader 3 | """ 4 | 5 | import os 6 | import sys 7 | import json 8 | import time 9 | import numpy as np 10 | import random 11 | import h5py 12 | from tqdm import tqdm 13 | from .beatencoder import NBeatEncoder, TwoTwoEncoderTimingOnly 14 | 15 | def round_to_nearest_48th(num): 16 | return round(num*48)/48 17 | 18 | class AudioTimingFineBeatDataset(): 19 | def __init__(self, split_json_path="train.json", split_hdf5_path="train.h5", encoder="twotwo", 20 | length_in_beats=4, token_length=100, beat_upper_limit=20000, drop_first_two_rate=0.0, 21 | logspec=False, center=False, truncate=None, use_notecount_as_finediff=False, 22 | debug=False) -> None : 23 | self.split_json_path = split_json_path 24 | self.split_hdf5_path = split_hdf5_path 25 | with open(split_json_path) as file: 26 | self.split_dict = json.loads(file.read()) 27 | self.hdf5 = h5py.File(split_hdf5_path,'r') 28 | self.dir_of_dirs = os.path.dirname(split_json_path) 29 | 30 | 
self.total_length_in_beats = sum([min(beat_len,beat_upper_limit) for bar_len, beat_len in self.split_dict.values()]) 31 | 32 | self.index = [] 33 | for osujson, (bar_len,beat_len) in self.split_dict.items(): 34 | self.index.extend([(osujson,beat) for beat in range(min(beat_len,beat_upper_limit))]) 35 | 36 | assert len(self.index) == self.total_length_in_beats 37 | 38 | self.beat_encoder = TwoTwoEncoderTimingOnly() 39 | self.pad = 177 40 | assert length_in_beats == 4 41 | 42 | self.length_in_beats = length_in_beats 43 | self.token_length = token_length 44 | self.drop_first_two_rate = drop_first_two_rate 45 | self.logspec = logspec 46 | self.center = center 47 | self.truncate = truncate 48 | self.use_notecount_as_finediff = use_notecount_as_finediff 49 | self.debug = debug 50 | 51 | if self.center is False and truncate is not None: 52 | raise ValueError("truncate is only usable if self.center is set.") 53 | if truncate is None: 54 | self.truncate = 0 55 | if self.truncate > self.token_length//2 or self.truncate < 0: 56 | raise ValueError("self.truncate must be a non-negative smaller than half of token length") 57 | 58 | def __len__(self) -> int: 59 | return self.total_length_in_beats 60 | 61 | def __getitem__(self,idx,benchmark=False) -> dict: 62 | """ 63 | List of stuff to return: 64 | - the music for this bar, made into a melspectrogram with hop size = (beat) / 48 (loaded from a h5) 65 | - the encoded tokens for this _ beats 66 | - fine difficulty 67 | """ 68 | if benchmark: tp = time.time() 69 | 70 | dropmode = (random.random() < self.drop_first_two_rate) 71 | 72 | if idx > self.total_length_in_beats: 73 | raise ValueError("idx over total length") 74 | osujson_fn, beat = self.index[idx] 75 | with open(osujson_fn) as file: 76 | osujson = json.load(file) 77 | 78 | if benchmark: print(f"json read: {time.time() - tp}") 79 | if self.debug: print(self.index[idx]) 80 | # fine diff 81 | try: 82 | if self.use_notecount_as_finediff: 83 | fine_difficulty = 
len(osujson['charts'][0]['notes']) 84 | else: 85 | fine_difficulty = float(osujson['charts'][0]['difficulty_fine']) 86 | 87 | except Exception as e: 88 | print(osujson_fn) 89 | raise e 90 | 91 | assert beat == int(beat) # see directory 0_filtration 92 | nextbeat = beat + self.length_in_beats 93 | 94 | # load melspec from hdf5 95 | hdf5_dset_name = osujson_fn.replace(self.dir_of_dirs+'/','') 96 | mel = self.hdf5[hdf5_dset_name] 97 | mel_to_return = mel[48*round(beat):48*round(nextbeat),:].T # (80, 48*beat) 98 | if self.logspec: 99 | mel_to_return = np.log10(mel_to_return+1e-10) 100 | if mel_to_return.shape[1] < 192: 101 | mel_to_return = np.pad(mel_to_return, ((0,0),(0, 192 - mel_to_return.shape[1])), 'constant', constant_values=0) 102 | # load audio and account for offset 103 | 104 | if benchmark: print(f"reading melspectrogram: {time.time() - tp}") 105 | 106 | #make tokens 107 | notes = osujson['charts'][0]['notes'] 108 | if not dropmode: 109 | notes = list(filter(lambda x:(x[1] >= beat and x[1] < nextbeat), notes)) 110 | else: 111 | notes = list(filter(lambda x:(x[1] >= beat+2 and x[1] < nextbeat), notes)) # TODO: make this configurable? 112 | encoded_notes = self.beat_encoder.encode(notes,beat) 113 | where_is_96 = encoded_notes[1:].index(96) + 1 114 | where_should_96_be = self.token_length // 2 115 | encoded_notes = np.array(encoded_notes) 116 | if not self.center: 117 | where_should_96_be = where_is_96 118 | 119 | try: 120 | left_pad = where_should_96_be - where_is_96 121 | right_pad = self.token_length - len(encoded_notes) - left_pad 122 | encoded_notes = np.pad(encoded_notes, (left_pad, right_pad), 'constant', constant_values=(self.pad,self.pad)) 123 | except ValueError as e: 124 | self.bar_encoder.pretty_print(encoded_notes) 125 | print(e) 126 | raise ValueError(e) 127 | if benchmark: print(f"tokenmaking: {time.time() - tp}") 128 | 129 | beat_token_index = int(np.where(encoded_notes == 96)[0][1].item()) 130 | mask = [0. 
for i in range(beat_token_index) ] + [1.for i in range(self.token_length - beat_token_index)] 131 | mask = np.array(mask) 132 | 133 | # truncation for centered 134 | if self.center: 135 | mask = mask[self.truncate:] 136 | encoded_notes = encoded_notes[self.truncate:] 137 | 138 | #print(mel_to_return.shape, encoded_notes.shape, fine_difficulty, mask.shape) 139 | return mel_to_return, encoded_notes, mask, fine_difficulty 140 | 141 | def stress_test(split="train", **params): 142 | print(f"=============================== split = {split}") 143 | split_json_path = f"OSUFOLDER/{split}.json" 144 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 145 | dset = AudioTimingFineBeatDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path, **params) 146 | print(len(dset)) 147 | random_indices = random.sample(list(range(len(dset))),5000) 148 | start_time = time.time() 149 | #for i in tqdm(random_indices): 150 | #dset[i] 151 | print(dset[30]) 152 | end_time = time.time() 153 | print(end_time-start_time) #five hours. 
wow, that's ridiculous 154 | 155 | def size_test(split="train",**params): 156 | print(f"=============================== split = {split}") 157 | split_json_path = f"OSUFOLDER/{split}.json" 158 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 159 | 160 | dset = AudioTimingFineBeatDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path, **params) 161 | sz = len(dset) 162 | print(dset[0][1].shape) 163 | for i in tqdm(range(sz)): 164 | dset[i] 165 | 166 | if __name__=="__main__": 167 | # for split in ['valid']: 168 | # stress_test(split 169 | #size_test(center=True, token_length=200, truncate=94) 170 | #stress_test(center=True, token_length=100, truncate=47) 171 | size_test(center=True, token_length=100, truncate=47) 172 | 173 | -------------------------------------------------------------------------------- /code/loader/audiorandomshift.py: -------------------------------------------------------------------------------- 1 | """ 2 | the argument to feed into torch.util.data.DataLoader 3 | """ 4 | 5 | import os 6 | import sys 7 | import json 8 | import time 9 | import numpy as np 10 | import random 11 | import h5py 12 | from tqdm import tqdm 13 | from .beatencoder import TwoTwoEncoderFrac 14 | 15 | def round_to_nearest_48th(num): 16 | return round(num*48)/48 17 | 18 | class AudioRandomShiftDataset(): 19 | def __init__(self, split_json_path="train.json", split_hdf5_path="train.h5", encoder="twotwo", 20 | length_in_beats=4, token_length=100, beat_upper_limit=20000, drop_first_two_rate=0.0, 21 | logspec=False, center=False, truncate=None) -> None : 22 | self.split_json_path = split_json_path 23 | self.split_hdf5_path = split_hdf5_path 24 | with open(split_json_path) as file: 25 | self.split_dict = json.loads(file.read()) 26 | self.hdf5 = h5py.File(split_hdf5_path,'r') 27 | self.dir_of_dirs = os.path.dirname(split_json_path) 28 | 29 | self.total_length_in_beats = sum([min(beat_len,beat_upper_limit) for bar_len, beat_len in self.split_dict.values()]) 30 | 31 | 
self.index = [] 32 | for osujson, (bar_len,beat_len) in self.split_dict.items(): 33 | self.index.extend([(osujson,beat) for beat in range(min(beat_len,beat_upper_limit))]) 34 | 35 | assert len(self.index) == self.total_length_in_beats 36 | 37 | self.beat_encoder = TwoTwoEncoderFrac() 38 | self.pad = 177 39 | assert length_in_beats == 4 40 | 41 | self.length_in_beats = length_in_beats 42 | self.token_length = token_length 43 | self.drop_first_two_rate = drop_first_two_rate 44 | self.logspec = logspec 45 | self.center = center 46 | self.truncate = truncate 47 | 48 | if self.center is False and truncate is not None: 49 | raise ValueError("truncate is only usable if self.center is set.") 50 | if truncate is None: 51 | self.truncate = 0 52 | if self.truncate > self.token_length//2 or self.truncate < 0: 53 | raise ValueError("self.truncate must be a non-negative smaller than half of token length") 54 | 55 | def __len__(self) -> int: 56 | return self.total_length_in_beats 57 | 58 | def __getitem__(self,idx,benchmark=False) -> dict: 59 | """ 60 | List of stuff to return: 61 | - the music for this bar, made into a melspectrogram with hop size = (beat) / 48 (loaded from a h5) 62 | - the encoded tokens for this _ beats 63 | - fine difficulty 64 | """ 65 | if benchmark: tp = time.time() 66 | 67 | dropmode = (random.random() < self.drop_first_two_rate) 68 | 69 | if idx > self.total_length_in_beats: 70 | raise ValueError("idx over total length") 71 | osujson_fn, beat = self.index[idx] 72 | with open(osujson_fn) as file: 73 | osujson = json.load(file) 74 | 75 | if benchmark: print(f"json read: {time.time() - tp}") 76 | 77 | # fine diff 78 | try: 79 | fine_difficulty = float(osujson['charts'][0]['difficulty_fine']) 80 | except Exception as e: 81 | print(osujson_fn) 82 | raise e 83 | 84 | assert beat == int(beat) # see directory 0_filtration 85 | random_nudge = random.randint(0,48) 86 | beat += random_nudge / 48 87 | nextbeat = beat + self.length_in_beats 88 | 89 | # load melspec 
from hdf5 90 | hdf5_dset_name = osujson_fn.replace(self.dir_of_dirs+'/','') 91 | mel = self.hdf5[hdf5_dset_name] 92 | mel_to_return = mel[48*round(beat):48*round(nextbeat),:].T # (80, 48*beat) 93 | if self.logspec: 94 | mel_to_return = np.log10(mel_to_return+1e-10) 95 | if mel_to_return.shape[1] < 192: 96 | mel_to_return = np.pad(mel_to_return, ((0,0),(0, 192 - mel_to_return.shape[1])), 'constant', constant_values=0) 97 | # load audio and account for offset 98 | 99 | if benchmark: print(f"reading melspectrogram: {time.time() - tp}") 100 | 101 | #make tokens 102 | notes = osujson['charts'][0]['notes'] 103 | if not dropmode: 104 | notes = list(filter(lambda x:(round_to_nearest_48th(x[1]) >= round_to_nearest_48th(beat) 105 | and round_to_nearest_48th(x[1]) < round_to_nearest_48th(nextbeat)), notes)) 106 | else: 107 | notes = list(filter(lambda x:(round_to_nearest_48th(x[1]) >= round_to_nearest_48th(beat+2) 108 | and round_to_nearest_48th(x[1]) < round_to_nearest_48th(nextbeat)), notes)) 109 | encoded_notes = self.beat_encoder.encode(notes,beat) 110 | where_is_96 = encoded_notes[1:].index(96) + 1 111 | where_should_96_be = self.token_length // 2 112 | encoded_notes = np.array(encoded_notes) 113 | if not self.center: 114 | where_should_96_be = where_is_96 115 | 116 | try: 117 | left_pad = where_should_96_be - where_is_96 118 | right_pad = self.token_length - len(encoded_notes) - left_pad 119 | encoded_notes = np.pad(encoded_notes, (left_pad, right_pad), 'constant', constant_values=(self.pad,self.pad)) 120 | except ValueError as e: 121 | self.bar_encoder.pretty_print(encoded_notes) 122 | print(e) 123 | raise ValueError(e) 124 | if benchmark: print(f"tokenmaking: {time.time() - tp}") 125 | 126 | beat_token_index = int(np.where(encoded_notes == 96)[0][1].item()) 127 | mask = [0. 
for i in range(beat_token_index) ] + [1.for i in range(self.token_length - beat_token_index)] 128 | mask = np.array(mask) 129 | 130 | # truncation for centered 131 | if self.center: 132 | mask = mask[self.truncate:] 133 | encoded_notes = encoded_notes[self.truncate:] 134 | 135 | #print(mel_to_return.shape, encoded_notes.shape, fine_difficulty, mask.shape) 136 | return mel_to_return, encoded_notes, mask, fine_difficulty 137 | 138 | def stress_test(split="train", **params): 139 | print(f"=============================== split = {split}") 140 | split_json_path = f"OSUFOLDER/{split}.json" 141 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 142 | dset = AudioRandomShiftDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path, **params) 143 | print(len(dset)) 144 | random_indices = random.sample(list(range(len(dset))),5000) 145 | start_time = time.time() 146 | #for i in tqdm(random_indices): 147 | #dset[i] 148 | print(dset[350]) 149 | end_time = time.time() 150 | print(end_time-start_time) 151 | 152 | def size_test(split="train",**params): 153 | print(f"=============================== split = {split}") 154 | split_json_path = f"OSUFOLDER/{split}.json" 155 | split_hdf5_path = f"OSUFOLDER/{split}.h5" 156 | 157 | dset = AudioRandomShiftDataset(split_json_path=split_json_path, split_hdf5_path=split_hdf5_path, **params) 158 | sz = len(dset) 159 | for i in tqdm(range(sz)): 160 | dset[i] 161 | 162 | if __name__=="__main__": 163 | # for split in ['valid']: 164 | # stress_test(split 165 | stress_test(center=True, token_length=200, truncate=94) 166 | # stress_test(center=True) 167 | 168 | """ XXX. 169 | benchmarks. 
170 | 171 | ## mels: (80, T), windows 172 | 173 | random, 5000 accesses: 174 | - train (542966 samples): 148.2 175 | - valid (73915 samples): 102.4 176 | - test (78460 samples): 107.3 177 | 178 | sequential, 20000 accesses: 179 | - train (542966 samples): 343.1 180 | - valid (73915 samples): 334.4 181 | - test (78460 samples): 353.2 182 | 183 | ## mels: (T, 80), windows 184 | 185 | random: 5000 accesses: 186 | - train_transpose (560928 samples): 21.9 187 | - valid_transpose (63999 samples): 16.4 188 | - test_transpose (63595 samples): 16.7 189 | 190 | ## mels: (T, 80), wsl home dir 191 | 192 | random: 5000 accesses: 193 | - train_transpose (560928 samples): too scared to try 194 | - valid_transpose (63999 samples): 14.7 195 | - test_transpose (63595 samples): 15.2 196 | 197 | """ --------------------------------------------------------------------------------