├── code
├── finetune
│ ├── core
│ │ ├── __init__.py
│ │ ├── datasets
│ │ │ ├── __init__.py
│ │ │ ├── Makefile
│ │ │ ├── utils.py
│ │ │ ├── megatron_dataset.py
│ │ │ ├── blended_megatron_dataset_config.py
│ │ │ ├── blended_dataset.py
│ │ │ ├── blended_megatron_dataset_builder.py
│ │ │ ├── gpt_dataset.py
│ │ │ ├── indexed_dataset.py
│ │ │ └── helpers.cpp
│ │ └── parse_mixture.py
│ ├── scripts
│ │ ├── count_tokens.sh
│ │ ├── preprocess_data.sh
│ │ └── run_finetune.sh
│ ├── tools
│ │ ├── count_mmap_token.py
│ │ └── codecmanipulator.py
│ ├── config
│ │ └── ds_config_zero2.json
│ └── requirements.txt
├── prompt_egs
│ ├── genre.txt
│ └── lyrics.txt
├── inference
│ ├── mm_tokenizer_v0.2_hf
│ │ └── tokenizer.model
│ ├── codecmanipulator.py
│ ├── mmtokenizer.py
│ └── infer.py
├── requirements.txt
├── evals
│ └── pitch_range
│ │ └── main.py
└── top_200_tags.json
├── example.mp3
├── fig
├── model.pdf
└── tokenpair.pdf
├── tokenpair.png
└── README.md
/code/finetune/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/code/finetune/core/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/example.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WtxwNs/BACH/HEAD/example.mp3
--------------------------------------------------------------------------------
/fig/model.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WtxwNs/BACH/HEAD/fig/model.pdf
--------------------------------------------------------------------------------
/tokenpair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WtxwNs/BACH/HEAD/tokenpair.png
--------------------------------------------------------------------------------
/fig/tokenpair.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WtxwNs/BACH/HEAD/fig/tokenpair.pdf
--------------------------------------------------------------------------------
/code/prompt_egs/genre.txt:
--------------------------------------------------------------------------------
1 | inspiring female uplifting pop airy vocal electronic bright vocal vocal
--------------------------------------------------------------------------------
/code/inference/mm_tokenizer_v0.2_hf/tokenizer.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WtxwNs/BACH/HEAD/code/inference/mm_tokenizer_v0.2_hf/tokenizer.model
--------------------------------------------------------------------------------
/code/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | omegaconf
3 | torchaudio
4 | einops
5 | numpy
6 | transformers
7 | sentencepiece
8 | tqdm
9 | tensorboard
10 | descript-audiotools>=0.7.2
11 | descript-audio-codec
12 | scipy
13 | accelerate>=0.26.0
14 |
--------------------------------------------------------------------------------
/code/finetune/core/datasets/Makefile:
--------------------------------------------------------------------------------
# Build the pybind11 C++ dataset helpers (helpers.cpp) into a Python extension
# module named helpers<EXT_SUFFIX>, e.g. helpers.cpython-310-x86_64-linux-gnu.so.
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
# Ask the same python3 that provides pybind11 for its extension suffix;
# python3-config may be missing or belong to a different interpreter (venv/conda).
LIBEXT = $(shell python3 -c "import sysconfig; print(sysconfig.get_config_var('EXT_SUFFIX'))")

# These targets do not produce files with their own names.
.PHONY: default clean

default: $(LIBNAME)$(LIBEXT)

%$(LIBEXT): %.cpp
	$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@

clean:
	rm -f $(LIBNAME)$(LIBEXT)
10 |
--------------------------------------------------------------------------------
/code/finetune/scripts/count_tokens.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Count tokens in every .bin mmap file under a parent directory.
# One background job (tools/count_mmap_token.py) is launched per file; each job
# writes to $LOG_DIR/count.<path relative to PARENT_DIR, '/' -> '_'>.log.
# The tools/ path is relative, so run this from the finetune/ directory.

echo "Please input parent directory, will count all .bin files..."
echo "Example: bash ./count_tokens.sh /workspace/dataset/music"

PARENT_DIR=${1:-/workspace/dataset/music}
# Drop a trailing slash so the prefix strip below matches ("/" stays "/").
PARENT_DIR=${PARENT_DIR%/}
PARENT_DIR=${PARENT_DIR:-/}
LOG_DIR=./count_token_logs/
mkdir -p "$LOG_DIR"

# find all .bin files; NUL-delimited so paths containing spaces stay intact
while IFS= read -r -d '' mmap_path; do
    echo "Checking mmap file: $mmap_path"

    # mmap size in human readable format (e.g. 1.2G)
    mmap_size=$(du -h "$mmap_path" | awk '{print $1}')
    echo "Counting mmap file: $mmap_path, size: $mmap_size"

    # remove PARENT_DIR, replace / with _
    subdir=${mmap_path#"$PARENT_DIR"/}
    subdir=${subdir//\//_}
    log_file="$LOG_DIR/count.$subdir.log"

    echo "nohup python tools/count_mmap_token.py --mmap_path $mmap_path > $log_file 2>&1 &"
    nohup python tools/count_mmap_token.py --mmap_path "$mmap_path" > "$log_file" 2>&1 &
done < <(find "$PARENT_DIR" -name "*.bin" -type f -print0)

# Jobs were only launched; they keep running in the background.
echo "Finished! Counting jobs are running in the background; logs are in $LOG_DIR"
31 |
32 |
33 |
--------------------------------------------------------------------------------
/code/finetune/tools/count_mmap_token.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
4 | os.path.pardir)))
5 | from core.datasets.indexed_dataset import MMapIndexedDataset
6 | import argparse
7 | from tqdm import tqdm
8 |
9 |
def get_args():
    """Parse the command line; ``--mmap_path`` is the ``.bin`` file whose tokens are counted."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mmap_path", type=str, required=True, help="Path to the .bin mmap file")
    return parser.parse_args()


def count_ids(dataset):
    """Return the total number of token ids over all documents in *dataset*.

    Each item yielded by *dataset* must expose ``.shape[0]`` (its length).
    """
    count = 0
    for doc_ids in tqdm(dataset):
        count += doc_ids.shape[0]
    return count


def main():
    """Open the mmap dataset named on the command line and print its token count."""
    args = get_args()

    # The dataset is opened by path prefix, so strip a trailing ".bin" extension.
    slice_path = args.mmap_path
    if slice_path.endswith(".bin"):
        slice_path = slice_path[: -len(".bin")]

    dataset = MMapIndexedDataset(slice_path)

    # core/parse_mixture.py parses these two lines with regexes; keep the wording stable.
    print("Counting tokens in ", args.mmap_path)
    total_cnt = count_ids(dataset)
    print("Total number of tokens: ", total_cnt)


if __name__ == "__main__":
    main()
34 |
--------------------------------------------------------------------------------
/code/prompt_egs/lyrics.txt:
--------------------------------------------------------------------------------
1 | [verse]
2 | Staring at the sunset, colors paint the sky
3 | Thoughts of you keep swirling, can't deny
4 | I know I let you down, I made mistakes
5 | But I'm here to mend the heart I didn't break
6 |
7 | [chorus]
8 | Every road you take, I'll be one step behind
9 | Every dream you chase, I'm reaching for the light
10 | You can't fight this feeling now
11 | I won't back down
12 | You know you can't deny it now
13 | I won't back down
14 |
15 | [verse]
16 | They might say I'm foolish, chasing after you
17 | But they don't feel this love the way we do
18 | My heart beats only for you, can't you see?
19 | I won't let you slip away from me
20 |
21 | [chorus]
22 | Every road you take, I'll be one step behind
23 | Every dream you chase, I'm reaching for the light
24 | You can't fight this feeling now
25 | I won't back down
26 | You know you can't deny it now
27 | I won't back down
28 |
29 | [bridge]
30 | No, I won't back down, won't turn around
31 | Until you're back where you belong
32 | I'll cross the oceans wide, stand by your side
33 | Together we are strong
34 |
35 | [outro]
36 | Every road you take, I'll be one step behind
37 | Every dream you chase, love's the tie that binds
38 | You can't fight this feeling now
39 | I won't back down
--------------------------------------------------------------------------------
/code/finetune/config/ds_config_zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 64,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 |
23 | "scheduler": {
24 | "type": "WarmupCosineLR",
25 | "params": {
26 | "total_num_steps": "auto",
27 | "warmup_min_ratio": 0.03,
28 | "warmup_num_steps": "auto",
29 | "cos_min_ratio": 0.1
30 | }
31 | },
32 |
33 | "zero_optimization": {
34 | "stage": 2,
35 | "offload_optimizer": {
36 | "device": "none",
37 | "pin_memory": true
38 | },
39 | "offload_param": {
40 | "device": "none",
41 | "pin_memory": true
42 | },
43 | "overlap_comm": false,
44 | "contiguous_gradients": true,
45 | "sub_group_size": 1e9,
46 | "reduce_bucket_size": "auto"
47 | },
48 |
49 | "gradient_accumulation_steps": "auto",
50 | "gradient_clipping": 1.0,
51 | "steps_per_print": 100,
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "wall_clock_breakdown": false
55 | }
56 |
--------------------------------------------------------------------------------
/code/finetune/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.6.0
2 | annotated-types==0.7.0
3 | blobfile==3.0.0
4 | certifi==2025.4.26
5 | charset-normalizer==3.4.2
6 | click==8.2.0
7 | deepspeed==0.16.7
8 | docker-pycreds==0.4.0
9 | einops==0.8.1
10 | filelock==3.13.1
11 | fsspec==2024.6.1
12 | gitdb==4.0.12
13 | GitPython==3.1.44
14 | hjson==3.1.0
15 | huggingface-hub==0.31.2
16 | idna==3.10
17 | Jinja2==3.1.4
18 | lxml==5.4.0
19 | MarkupSafe==2.1.5
20 | mpmath==1.3.0
21 | msgpack==1.1.0
22 | networkx==3.3
23 | ninja==1.11.1.4
24 | nltk==3.9.1
25 | numpy==2.1.2
26 | nvidia-cublas-cu12==12.1.3.1
27 | nvidia-cuda-cupti-cu12==12.1.105
28 | nvidia-cuda-nvrtc-cu12==12.1.105
29 | nvidia-cuda-runtime-cu12==12.1.105
30 | nvidia-cudnn-cu12==9.1.0.70
31 | nvidia-cufft-cu12==11.0.2.54
32 | nvidia-curand-cu12==10.3.2.106
33 | nvidia-cusolver-cu12==11.4.5.107
34 | nvidia-cusparse-cu12==12.1.0.106
35 | nvidia-ml-py==12.575.51
36 | nvidia-nccl-cu12==2.20.5
37 | nvidia-nvjitlink-cu12==12.1.105
38 | nvidia-nvtx-cu12==12.1.105
39 | packaging==25.0
40 | peft==0.15.2
41 | pillow==11.0.0
42 | platformdirs==4.3.8
43 | protobuf==6.30.2
44 | psutil==7.0.0
45 | py-cpuinfo==9.0.0
46 | pybind11==2.13.6
47 | pycryptodomex==3.22.0
48 | pydantic==2.11.4
49 | pydantic_core==2.33.2
50 | PyYAML==6.0.2
51 | regex==2024.11.6
52 | requests==2.32.3
53 | safetensors==0.5.3
54 | scipy==1.15.3
55 | sentencepiece==0.2.0
56 | sentry-sdk==2.28.0
57 | setproctitle==1.3.6
58 | six==1.17.0
59 | smmap==5.0.2
60 | sympy==1.13.3
61 | tiktoken==0.9.0
62 | tokenizers==0.21.1
63 | torch==2.4.0
64 | torchaudio==2.4.0
65 | torchvision==0.19.0
66 | tqdm==4.67.1
67 | transformers==4.50.0
68 | triton==3.0.0
69 | typing-inspection==0.4.0
70 | typing_extensions==4.12.2
71 | urllib3==2.4.0
72 | wandb==0.19.11
73 |
--------------------------------------------------------------------------------
/code/finetune/core/datasets/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 |
3 | import logging
4 | from enum import Enum
5 | from typing import List
6 |
7 | import numpy
8 | import torch
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
class Split(Enum):
    """Identifiers for the dataset splits, with integer values 0 (train), 1 (valid), 2 (test)."""

    train = 0  # training split
    valid = 1  # validation split
    test = 2  # held-out test split
17 |
18 |
def compile_helpers():
    """Compile the C++ helper functions at runtime by running ``make`` in this directory.

    Make sure this is invoked on a single process.

    Exits the interpreter with status 1 (after logging an error) if ``make`` fails
    or cannot be found on PATH.
    """
    import os
    import subprocess
    import sys

    command = ["make", "-C", os.path.abspath(os.path.dirname(__file__))]
    try:
        returncode = subprocess.run(command).returncode
    except FileNotFoundError:
        # `make` itself is missing; treat it like a failed build instead of crashing
        # with an unrelated traceback.
        returncode = None
    if returncode != 0:
        log_single_rank(logger, logging.ERROR, "Failed to compile the C++ dataset helper functions")
        sys.exit(1)
31 |
32 |
def log_single_rank(logger: logging.Logger, *args, rank=0, **kwargs):
    """Log through *logger*, but only on *rank* when torch.distributed is initialized.

    If torch.distributed is unavailable in this torch build, or has not been
    initialized, the message is logged unconditionally.

    Args:
        logger (logging.Logger): The logger to write the logs
        *args: Forwarded to ``logger.log`` (level, msg, *msg_args)
        rank (int, optional): The rank to write on. Defaults to 0.
        **kwargs: Forwarded to ``logger.log``
    """
    # is_initialized() is not usable on builds without distributed support,
    # so check availability first.
    distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not distributed or torch.distributed.get_rank() == rank:
        logger.log(*args, **kwargs)
46 |
47 |
def normalize(weights: List[float]) -> List[float]:
    """Do non-exponentiated normalization: scale *weights* so they sum to 1.

    Args:
        weights (List[float]): The weights

    Returns:
        List[float]: The normalized weights (an empty list for empty input)

    Raises:
        ValueError: If the weights are non-empty but sum to zero; dividing would
            otherwise silently yield NaN/inf weights.
    """
    w = numpy.array(weights, dtype=numpy.float64)
    if w.size == 0:
        return []
    w_sum = numpy.sum(w)
    if w_sum == 0:
        raise ValueError(f"Cannot normalize weights that sum to zero: {weights}")
    return (w / w_sum).tolist()
61 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Watch how BACH turns raw tokens into structured music—step by step.
5 |
6 |
7 | # BACH: Bar-level AI Composing Helper
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | > *"Via Score to Performance: Efficient Human-Controllable Long Song Generation with Bar-Level Symbolic Notation"*
21 | > ICASSP 2026 Submission – **Pending Review**
22 |
23 | ---
24 |
25 | ## 🎼 One-sentence Summary
26 | BACH is the first **human-editable**, **bar-level** symbolic song generator:
27 | LLM writes lyrics → Transformer emits ABC score → off-the-shelf renderers give **minutes-long, Suno-level** music.
**1B parameters**, **minute-level** inference, **SOTA among open-source models**.
29 |
30 | ---
31 |
32 | ## 📦 What is inside this repo (preview release)
33 | | Path | Description |
34 | |------|-------------|
35 | | `README.md` | This file |
| `code/` | Inference, fine-tuning, and evaluation code |
37 | | `example.mp3` | an example song |
38 | | `fig/` | Architecture figure |
39 |
40 | ---
41 |
42 | ## 🏗️ Model Architecture (one glance)
43 |
```text
User prompt
  → Qwen3 — lyrics & style tags
  → BACH-1B Decoder-Only Transformer
  → ABC score (Dual-NTP + Chain-of-Score)
  → ABC → MIDI → FluidSynth + VOCALOID
  → Stereo mix
```
50 |
51 |
52 | | Component | Key idea |
53 | |-----------|----------|
54 | | **Dual-NTP** | Predict `{vocal_patch, accomp_patch}` jointly every step |
55 | | **Chain-of-Score** | Section tags `[START:Chorus] ... [END:Chorus]` for long coherence |
56 | | **Bar-stream patch** | 16-char non-overlapping patches per bar |
57 |
58 | ---
59 |
60 | ## 🧪 Quick start (CPU friendly)
61 | ```bash
62 | # 1. Clone
git clone https://github.com/WtxwNs/BACH.git
64 | cd BACH
65 |
66 | # 2. Install
pip install -r code/requirements.txt  # rendering additionally needs mido, abcpy and fluidsynth (not listed there)
68 |
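# 3. Generate ABC (bach/generate.py is not part of this preview release; the included inference entry point is code/inference/infer.py)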
70 | python bach/generate.py \
71 | --prompt "A rainy-day lo-fi hip-hop song about missing the last train" \
72 | --out_abc demo/rainy_lofi.abc
73 |
# 4. Render audio (rendering scripts are not included in this preview release yet)
75 | ```
76 |
77 | ## 🎧 Listen now
`example.mp3` is a complete example song — feel free to compare it with Suno 🙂
79 |
80 | ## Full release upon related paper acceptance
81 | - Complete training set (ABC + lyrics + structure labels)
82 | - BACH-1B weights (Transformers format)
83 | - Training scripts (multiphase + multitask + ICL)
84 | - Complete Code
85 |
86 | ## 📎 Citation
The paper is available on arXiv:
88 | ```bibtex
89 | @misc{wang2025scoreperformanceefficienthumancontrollable,
90 | title={Via Score to Performance: Efficient Human-Controllable Long Song Generation with Bar-Level Symbolic Notation},
91 | author={Tongxi Wang and Yang Yu and Qing Wang and Junlang Qian},
92 | year={2025},
93 | eprint={2508.01394},
94 | archivePrefix={arXiv},
95 | primaryClass={cs.SD},
96 | url={https://arxiv.org/abs/2508.01394},
}
```
98 |
--------------------------------------------------------------------------------
/code/evals/pitch_range/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import librosa
3 | import time
4 | import argparse
5 | import torch
6 | from extract_pitch_values_from_audio.src import RMVPE
7 | import os
8 | from pathlib import Path
9 | from tqdm import tqdm
10 |
def process_audio(rmvpe, audio_path, output_path, device, hop_length, threshold):
    """Estimate the pitch (F0) track of an audio file in 10-second chunks and save it.

    The audio is loaded at its native sample rate and zero-padded up to a whole
    number of 10 s chunks. Each chunk goes through ``rmvpe.infer_from_audio`` with
    Viterbi decoding. Zero F0 values are dropped, and the remaining values are
    written to *output_path*, one per line with two decimals.

    Args:
        rmvpe: Estimator exposing ``infer_from_audio(audio, sr, device=, thred=, use_viterbi=)``.
        audio_path: Audio file to read.
        output_path (pathlib.Path): Text file to write; parent directories are created.
        device: Forwarded to ``rmvpe.infer_from_audio``.
        hop_length: Not used by this function; the hop length is set when the
            RMVPE model is constructed.
        threshold: Forwarded to ``rmvpe.infer_from_audio`` as ``thred``.

    Returns:
        Tuple[float, float]: (total inference time in seconds, duration of the
        original, unpadded audio in seconds).
    """
    audio, sr = librosa.load(str(audio_path), sr=None)
    num_samples = len(audio)
    chunk_size = 10 * sr  # samples in a 10-second chunk

    # Zero-pad to a whole number of chunks. (-n) % k is 0 when n is already a
    # multiple of k, so no extra all-silent chunk is appended in that case.
    audio = np.pad(audio, (0, (-num_samples) % chunk_size), mode='constant')
    total_chunks = len(audio) // chunk_size

    all_f0 = []
    total_infer_time = 0
    for i in tqdm(range(total_chunks)):
        chunk = audio[i * chunk_size:(i + 1) * chunk_size]

        t = time.time()
        f0_chunk = rmvpe.infer_from_audio(chunk, sr, device=device, thred=threshold, use_viterbi=True)
        total_infer_time += time.time() - t

        all_f0.extend(f0_chunk)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Drop zero F0 values (presumably unvoiced frames) before saving.
    f0_values = np.array(all_f0)
    f0_values = f0_values[f0_values != 0]

    with open(output_path, 'w') as f:
        f.writelines(f'{f0:.2f}\n' for f0 in f0_values.tolist())

    # Report the real duration; padding would otherwise deflate the RTF.
    return total_infer_time, num_samples / sr
57 |
def main(argv=None):
    """Extract pitch values with RMVPE for every ``*.Vocals.mp3`` under the input directory.

    Output ``.txt`` files mirror the input directory structure. The defaults
    reproduce the original hard-coded settings; each can be overridden on the
    command line.

    Args:
        argv (list[str] | None): Command-line arguments; ``None`` reads ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Extract vocal pitch (F0) values with RMVPE.")
    parser.add_argument("--input-dir", type=Path, default=Path("/root/yue_pitch_evals/yue_vs_others_sep"))
    parser.add_argument("--output-dir", type=Path, default=Path("/root/yue_pitch_evals/yue_vs_others_sep_pitch"))
    parser.add_argument("--model", default="model.pt", help="Path to the RMVPE checkpoint")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--hop-length", type=int, default=160)
    parser.add_argument("--threshold", type=float, default=0.03)
    args = parser.parse_args(argv)

    input_dir = args.input_dir
    output_dir = args.output_dir
    device = args.device

    print(f'Using device: {device}')
    print('Loading model...')
    rmvpe = RMVPE(args.model, hop_length=args.hop_length)

    # Find all separated vocal tracks in input directory and subdirectories
    wav_files = list(input_dir.rglob('*.Vocals.mp3'))
    print(f'Found {len(wav_files)} WAV files to process')

    total_time = 0
    total_audio_duration = 0

    for wav_path in tqdm(wav_files, desc="Processing files"):
        # Mirror the input directory structure, swapping the suffix for .txt
        rel_path = wav_path.relative_to(input_dir)
        output_path = output_dir / str(rel_path).replace('.Vocals.mp3', '.txt')

        try:
            infer_time, audio_duration = process_audio(
                rmvpe, wav_path, output_path, device,
                args.hop_length, args.threshold
            )
            total_time += infer_time
            total_audio_duration += audio_duration

            tqdm.write(f'Processed {wav_path.name}')
            tqdm.write(f'Time: {infer_time:.2f}s, RTF: {infer_time/audio_duration:.2f}')

        except Exception as e:
            # Best effort: report the failing file and keep going with the rest.
            tqdm.write(f'Error processing {wav_path}: {str(e)}')
            continue

    print('\nProcessing complete!')
    print(f'Total processing time: {total_time:.2f}s')
    # Guard against division by zero when nothing (or only empty audio) was processed.
    if total_audio_duration > 0:
        print(f'Average RTF: {total_time/total_audio_duration:.2f}')
    else:
        print('Average RTF: n/a (no audio processed)')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/code/finetune/scripts/preprocess_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Convert a JSONL music dataset into mmap (.bin/.idx) training files by running
# core/preprocess_data_conditional_xcodec_segment.py.
#
# Usage: preprocess_data.sh <data_setting> <mode_type> <tokenizer_model> [audio_prompt_modes]
#   <data_setting>       dataset preset; currently only "dummy"
#   <mode_type>          "cot" or "icl_cot"
#   <tokenizer_model>    passed to --tokenizer-model
#   [audio_prompt_modes] icl_cot only: space-separated list of audio prompt modes
#                        (default: "dual inst vocal mixture")

DATA_SETTING=$1
MODE_TYPE=$2
TOKENIZER_MODEL=$3
if [ -z "$4" ]; then
    AUDIO_PROMPT_MODES=('dual' 'inst' 'vocal' 'mixture')
else
    read -r -a AUDIO_PROMPT_MODES <<< "$4"
fi

if [ -z "$DATA_SETTING" ] || [ -z "$MODE_TYPE" ] || [ -z "$TOKENIZER_MODEL" ]; then
    echo "Usage: $0 <data_setting> <mode_type> <tokenizer_model> [audio_prompt_modes]"
    echo "  <data_setting>       : e.g., dummy"
    echo "  <mode_type>          : cot or icl_cot"
    echo "  <tokenizer_model>    : path to the tokenizer model"
    echo "  [audio_prompt_modes] : icl_cot only, e.g. \"dual vocal\" (default: dual inst vocal mixture)"
    exit 1
fi

# Common settings based on DATA_SETTING
if [ "$DATA_SETTING" == "dummy" ]; then
    DATA_ROOT=example
    NAME_PREFIX=dummy.msa.xcodec_16k
    CODEC_TYPE=xcodec
    INSTRUCTION="Generate music from the given lyrics segment by segment."
    ORDER=textfirst
    DROPOUT=0.0
    KEEP_SEQUENTIAL_SAMPLES=true  # NOTE(review): not read below; --keep-sequential-samples is always passed
    QUANTIZER_BEGIN_IDX=0
    NUM_QUANTIZERS=1
else
    echo "Invalid setting: $DATA_SETTING"
    exit 1
fi

JSONL_NAME=jsonl/$NAME_PREFIX.jsonl

# Run the preprocessing script with the shared flags and output prefix $1; any
# further arguments are appended as mode-specific flags. The command is an array,
# so values with spaces (e.g. $INSTRUCTION) stay single arguments without eval.
run_preprocess() {
    local output_prefix=$1
    shift
    local cmd=(
        python core/preprocess_data_conditional_xcodec_segment.py
        --input "$DATA_ROOT/$JSONL_NAME"
        --output-prefix "$output_prefix"
        --tokenizer-model "$TOKENIZER_MODEL"
        --tokenizer-type MMSentencePieceTokenizer
        --codec-type "$CODEC_TYPE"
        --workers 8
        --partitions 1
        --instruction "$INSTRUCTION"
        --instruction-dropout-rate "$DROPOUT"
        --order "$ORDER"
        --append-eod
        --quantizer-begin "$QUANTIZER_BEGIN_IDX"
        --n-quantizer "$NUM_QUANTIZERS"
        "$@"
    )
    echo "${cmd[*]}"
    sleep 5
    "${cmd[@]}"
}

# Delete the split JSONL files and the *_text_document.{bin,idx} files next to
# output prefix $1 (-f: no error if nothing matches).
cleanup_intermediates() {
    rm -f "$DATA_ROOT"/jsonl/"${NAME_PREFIX}"_*.jsonl
    rm -f "$1"_*_text_document.bin
    rm -f "$1"_*_text_document.idx
}

# Mode-specific settings and execution
if [ "$MODE_TYPE" == "cot" ]; then
    echo "Running in 'cot' mode..."
    NAME_SUFFIX=stage_1_token_level_interleave_cot_xcodec
    MMAP_NAME=mmap/${NAME_PREFIX}_${NAME_SUFFIX}_$ORDER

    rm -f "$DATA_ROOT"/jsonl/"${NAME_PREFIX}"_*.jsonl
    mkdir -p "$DATA_ROOT/$MMAP_NAME"

    run_preprocess "$DATA_ROOT/$MMAP_NAME" \
        --use-token-level-interleave \
        --keep-sequential-samples \
        --cot

    cleanup_intermediates "$DATA_ROOT/$MMAP_NAME"

elif [ "$MODE_TYPE" == "icl_cot" ]; then
    echo "Running in 'icl_cot' mode..."
    NAME_SUFFIX=stage_1_token_level_interleave_long_prompt_msa
    MMAP_NAME=mmap/${NAME_PREFIX}_${NAME_SUFFIX}_$ORDER  # base name; one variant per audio prompt mode
    PROMPT_LEN=30

    rm -f "$DATA_ROOT"/jsonl/"${NAME_PREFIX}"_*.jsonl
    mkdir -p "$DATA_ROOT/$MMAP_NAME"

    for mode in "${AUDIO_PROMPT_MODES[@]}"; do
        echo "Processing mode: $mode"
        MODE_MMAP_NAME=${MMAP_NAME}_${mode}
        mkdir -p "$DATA_ROOT/$MODE_MMAP_NAME"

        run_preprocess "$DATA_ROOT/$MODE_MMAP_NAME" \
            --cot \
            --use-token-level-interleave \
            --use-audio-icl \
            --audio-prompt-mode "$mode" \
            --audio-prompt-len "$PROMPT_LEN" \
            --keep-sequential-samples

        cleanup_intermediates "$DATA_ROOT/$MODE_MMAP_NAME"
    done

else
    echo "Invalid mode_type: $MODE_TYPE. Use 'cot' or 'icl_cot'."
    exit 1
fi

echo "Preprocessing finished for setting '$DATA_SETTING' and mode_type '$MODE_TYPE'."
--------------------------------------------------------------------------------
/code/finetune/core/datasets/megatron_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 |
3 | import hashlib
4 | import json
5 | from abc import ABC, abstractmethod, abstractstaticmethod
6 | from collections import OrderedDict
7 | from typing import Dict, List
8 |
9 | import numpy
10 | import torch
11 |
12 | from core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
13 | from core.datasets.indexed_dataset import MMapIndexedDataset
14 | from core.datasets.utils import Split
15 |
16 |
class MegatronDataset(ABC, torch.utils.data.Dataset):
    """The wrapper class from which dataset classes should inherit e.g. GPTDataset

    Args:
        indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
        MegatronDataset

        indexed_indices (numpy.ndarray): The set of the documents indices to expose

        num_samples (int): The number of samples to draw from the indexed dataset

        index_split (Split): The indexed_indices Split

        config (BlendedMegatronDatasetConfig): The container for all config sourced parameters
    """

    def __init__(
        self,
        indexed_dataset: MMapIndexedDataset,
        indexed_indices: numpy.ndarray,
        num_samples: int,
        index_split: Split,
        config: BlendedMegatronDatasetConfig,
    ) -> None:
        # Sanity checks: a non-empty document selection, a positive sample count,
        # matching modality, and exactly one split strategy. The last check can only
        # fail if a subclass overrides is_split_by_document.
        assert indexed_indices.size > 0
        assert num_samples > 0
        assert self.is_multimodal() == indexed_dataset.multimodal
        assert self.is_split_by_sequence() != self.is_split_by_document()

        self.indexed_dataset = indexed_dataset
        self.indexed_indices = indexed_indices
        self.num_samples = num_samples
        self.index_split = index_split
        self.config = config

        # Collect the values that identify this dataset; their JSON description and
        # its MD5 hash serve as a cache key (see _key_config_attributes).
        self.unique_identifiers = OrderedDict()
        self.unique_identifiers["class"] = type(self).__name__
        self.unique_identifiers["path_prefix"] = self.indexed_dataset.path_prefix
        self.unique_identifiers["num_samples"] = self.num_samples
        self.unique_identifiers["index_split"] = self.index_split.name
        for attr in self._key_config_attributes():
            self.unique_identifiers[attr] = getattr(self.config, attr)
        self.unique_identifiers["add_bos"] = getattr(self.config, "add_bos", False)

        self.unique_description = json.dumps(self.unique_identifiers, indent=4)
        self.unique_description_hash = hashlib.md5(
            self.unique_description.encode("utf-8")
        ).hexdigest()

        self._finalize()

    @abstractmethod
    def _finalize(self) -> None:
        """Build the dataset and assert any subclass-specific conditions
        """
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Return the length of the dataset

        Returns:
            int: See abstract implementation
        """
        pass

    @abstractmethod
    def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
        """Return from the dataset

        Args:
            idx (int): The index into the dataset

        Returns:
            Dict[str, numpy.ndarray]: See abstract implementation
        """
        pass

    # `abc.abstractstaticmethod` is deprecated; stacking staticmethod over
    # abstractmethod is the supported equivalent.
    @staticmethod
    @abstractmethod
    def is_multimodal() -> bool:
        """Return True if the inheritor class and its internal MMapIndexedDataset are multimodal

        Returns:
            bool: See abstract implementation
        """
        pass

    @staticmethod
    @abstractmethod
    def is_split_by_sequence() -> bool:
        """Return whether the dataset is split by sequence

        For example, the GPT train/valid/test split is document agnostic

        Returns:
            bool: See abstract implementation
        """
        pass

    @classmethod
    def is_split_by_document(cls) -> bool:
        """Return whether the dataset is split by document

        For example, the BERT train/valid/test split is document aware

        Returns:
            bool: The negation of cls.is_split_by_sequence
        """
        return not cls.is_split_by_sequence()

    @staticmethod
    def _key_config_attributes() -> List[str]:
        """Return all config attributes which contribute to uniquely identifying the dataset.

        These attributes will be used to build a uniquely identifying string and MD5 hash which
        will be used to cache/load the dataset from run to run.

        Returns:
            List[str]: The key config attributes
        """
        return ["split", "random_seed", "sequence_length"]
--------------------------------------------------------------------------------
/code/finetune/core/parse_mixture.py:
--------------------------------------------------------------------------------
1 |
2 | """
# Tip: run the following command to reformat this file (e.g. to keep DB2TOKCNT readable)
autopep8 --in-place --aggressive --aggressive finetune/core/parse_mixture.py
5 |
6 | This script is used to parse the mixture of the pretraining data
7 | input: path to the yaml file
8 | output: a megatron style data mixture string
9 | """
10 |
11 | import os
12 | import sys
13 | import argparse
14 | import yaml
15 | import re
16 |
17 |
# Sample output of tools/count_mmap_token.py: the log format that
# get_tokcnt_from_log parses ("Counting tokens in <path>" / "Total number of tokens: <n>").
EXAMPLE_LOG_STRING = """Zarr-based strategies will not be registered because of missing packages
Counting tokens in ./mmap/example.bin

0%| | 0/597667 [00:00, ?it/s]
14%|█▍ | 83737/597667 [00:00<00:00, 837344.85it/s]
30%|██▉ | 178908/597667 [00:00<00:00, 904601.65it/s]
45%|████▌ | 269369/597667 [00:00<00:00, 883202.85it/s]
60%|█████▉ | 357748/597667 [00:00<00:00, 841687.26it/s]
74%|███████▍ | 442183/597667 [00:00<00:00, 837274.61it/s]
88%|████████▊ | 528884/597667 [00:00<00:00, 847062.60it/s]
100%|██████████| 597667/597667 [00:00<00:00, 850432.85it/s]
Total number of tokens: 806001459
"""

# Maps an mmap path (as printed in a count log) to its token count: an int, or a
# "x.xxxB" string (billions) when the logs are read with by_billions=True.
# Filled in by get_tokcnts_from_logs(). (A module-level `global` statement is a
# no-op, so none is needed here.)
DB2TOKCNT = {}
34 |
def get_count_logs_paths(logs_dir, pattern=r'count\..*\.log$'):
    """Return the paths of token-count log files in *logs_dir*.

    Args:
        logs_dir (str): Directory holding the count logs
            (scripts/count_tokens.sh writes ``count.<name>.log`` files).
        pattern (str): Regular expression applied with ``re.match`` (anchored at the
            start of the file name). The default accepts exactly ``count.<anything>.log``;
            the dots are escaped and the end is anchored so e.g. ``count.a.log.swp``
            is not picked up.

    Returns:
        list[str]: Matching paths, sorted by file name for a deterministic order.
    """
    return [
        os.path.join(logs_dir, f)
        for f in sorted(os.listdir(logs_dir))
        if re.match(pattern, f)
    ]
47 | output: Tuple of (path, token_count)
48 | """
49 | print(f"[INFO] Checking token count log from {log_path}")
50 | match_path_pattern = r'Counting tokens in\s+(.*)'
51 | match_tokcnt_pattern = r'Total number of tokens:\s+(\d+)'
52 |
53 | with open(log_path, 'r') as f:
54 | log = f.read()
55 | path = re.search(match_path_pattern, log).group(1)
56 | tokcnt = int(re.search(match_tokcnt_pattern, log).group(1))
57 | if by_billions:
58 | tokcnt = tokcnt / 1e9
59 | # into string x.xxxB
60 | tokcnt = f"{tokcnt:.3f}B"
61 | return (path, tokcnt)
62 |
63 |
64 | def get_tokcnts_from_logs(logs_dir, by_billions=True):
65 |
66 | logs = get_count_logs_paths(logs_dir)
67 | for log in logs:
68 | db, tokcnt = get_tokcnt_from_log(log, by_billions)
69 | DB2TOKCNT[db] = tokcnt
70 |
71 |
72 | def parse_args():
73 | parser = argparse.ArgumentParser(
74 | description="parse the mixture of the pretraining data")
75 | parser.add_argument(
76 | "--cfg",
77 | "-c",
78 | type=str,
79 | required=True,
80 | help="path to the yaml file")
81 | parser.add_argument(
82 | "--reload-db2tokcnt",
83 | "-r",
84 | action="store_true",
85 | help="DB2TOKCNT is currently hardcoded, reload it from the TOKEN_COUNT_LOG_DIR"
86 | )
87 | parser.add_argument(
88 | "--by-billions",
89 | "-b",
90 | action="store_true",
91 | help="output the tokcnt by billions")
92 | return parser.parse_args()
93 |
94 |
95 | def load_yaml(cfg_path):
96 | with open(cfg_path, "r") as f:
97 | cfg = yaml.load(f, Loader=yaml.FullLoader)
98 | return cfg
99 |
100 |
101 | def parse_mixture_from_cfg_deprecated(cfg):
102 | keys = list(cfg.keys())
103 | # find keys ends with _ROUND
104 | rounds = [k for k in keys if k.endswith("_ROUND")]
105 |
106 | def repeat_str(s, n):
107 | return "".join([s for _ in range(n)])
108 |
109 | total_tokcnt = 0
110 | mixture_str = ""
111 | for r in rounds:
112 | repeat_times = float(r.replace("_ROUND", ""))
113 | mmap_paths = sorted(set(cfg[r]))
114 | for mmap_path in mmap_paths:
115 | mmap_path_without_ext = os.path.splitext(mmap_path)[0]
116 | if repeat_times >= 1:
117 | mixture_str += repeat_str(
118 | f"1 {mmap_path_without_ext} ", int(repeat_times))
119 | else:
120 | # weight is less than 1
121 | mixture_str += f"{repeat_times} {mmap_path_without_ext} "
122 | tokcnt = DB2TOKCNT[mmap_path]
123 | if isinstance(tokcnt, str):
124 | assert tokcnt.endswith("B"), f"invalid tokcnt: {tokcnt}"
125 | tokcnt = float(tokcnt.replace("B", "")) * 10**9
126 | total_tokcnt += tokcnt * repeat_times
127 | else:
128 | assert isinstance(tokcnt, int), f"invalid tokcnt: {tokcnt}"
129 | total_tokcnt += tokcnt * repeat_times
130 |
131 | # total iter count
132 | total_iter = total_tokcnt / (cfg["GLOBAL_BATCH_SIZE"] * cfg["SEQ_LEN"])
133 |
134 | # into string x.xxxB
135 | total_tokcnt /= 1e9
136 | total_tokcnt = f"{total_tokcnt:.3f}B"
137 |
138 | return mixture_str, total_tokcnt, total_iter
139 |
140 |
def parse_mixture_from_cfg(cfg):
    """Build a Megatron DATA_PATH string weighted by (token count x repeats).

    For every key of ``cfg`` ending in ``_ROUND`` (e.g. ``"2_ROUND"`` -> 2.0 repeats), each
    unique path listed under it contributes one ``"<weight> <path-without-ext> "`` entry.
    ``weight`` is ``int(token_count * repeat_times)``.

    Token counts come from the module-level ``DB2TOKCNT`` dict. Values are either strings
    ending in ``"B"`` (billions, e.g. ``"1.234B"``) or ints.

    Returns:
        (mixture_str, total_tokcnt, total_iter):
        * mixture_str: the DATA_PATH string.
        * total_tokcnt: the weighted total formatted as ``"x.xxxB"``.
        * total_iter: the weighted total divided by
          ``cfg["GLOBAL_BATCH_SIZE"] * cfg["SEQ_LEN"]`` (a float).
    """
    def as_token_count(raw):
        # "x.xxxB" strings are expanded to an absolute (float) count; ints pass through.
        if isinstance(raw, str):
            assert raw.endswith("B"), f"invalid tokcnt: {raw}"
            return float(raw.replace("B", "")) * 10**9
        assert isinstance(raw, int), f"invalid tokcnt: {raw}"
        return raw

    round_keys = [key for key in cfg.keys() if key.endswith("_ROUND")]

    entries = []
    weighted_total = 0
    for round_key in round_keys:
        repeats = float(round_key.replace("_ROUND", ""))
        for mmap_path in sorted(set(cfg[round_key])):
            stem = os.path.splitext(mmap_path)[0]
            weighted = as_token_count(DB2TOKCNT[mmap_path]) * repeats
            weighted_total += weighted
            entries.append(f"{int(weighted)} {stem} ")

    tokens_per_iter = cfg["GLOBAL_BATCH_SIZE"] * cfg["SEQ_LEN"]
    total_iter = weighted_total / tokens_per_iter
    return "".join(entries), f"{weighted_total / 1e9:.3f}B", total_iter
175 |
176 |
if __name__ == "__main__":
    # Script entry point:
    #   1. Read the YAML mixture config.
    #   2. Rebuild DB2TOKCNT from the token-count logs.
    #   3. Print the DATA_PATH and TRAIN_ITERS values to paste into the training script.

    args = parse_args()

    cfg = load_yaml(args.cfg)
    print(f"[INFO] Loaded cfg from {args.cfg}")

    # The log directory is taken from the YAML config, not from the command line.
    TOKEN_COUNT_LOG_DIR = cfg["TOKEN_COUNT_LOG_DIR"]
    print(f"[INFO] TOKEN_COUNT_LOG_DIR: {TOKEN_COUNT_LOG_DIR}")

    # NOTE(review): DB2TOKCNT is reloaded unconditionally here; the --reload-db2tokcnt / -r
    # flag defined in parse_args() is never consulted.
    get_tokcnts_from_logs(TOKEN_COUNT_LOG_DIR,
                          by_billions=args.by_billions)
    print(f"[INFO] DB2TOKCNT reloaded from the logs in {TOKEN_COUNT_LOG_DIR}\n")

    mixture_str, total_tokcnt, total_iter = parse_mixture_from_cfg(cfg)
    print(f"[CRITICAL] DATA_PATH **(copy to the training script)**:\n{mixture_str}\n")
    # NOTE(review): total_iter comes from a true division, so it prints as a float (e.g. "150.0").
    # Round it to an integer before pasting if the trainer parses --train-iters as int.
    print(f"[CRITICAL] TRAIN_ITERS **(copy to the training script)**:\n{total_iter}\n")
    print(f"[INFO] Total token count: {total_tokcnt}")
195 |
--------------------------------------------------------------------------------
/code/finetune/core/datasets/blended_megatron_dataset_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2 |
3 | import functools
4 | import logging
5 | import re
6 | from dataclasses import dataclass, field
7 | from typing import Callable, List, Optional, Tuple
8 |
9 | import torch
10 |
11 | from core.datasets.utils import Split, log_single_rank, normalize
12 | # from parallel_state import get_virtual_pipeline_model_parallel_rank
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
@dataclass
class BlendedMegatronDatasetConfig:
    """Configuration object for megatron-core blended and megatron datasets

    Attributes:
        is_built_on_rank (Callable): Returns True if the dataset should be built on the current
            rank. When torch.distributed is initialized, ``__post_init__`` asserts that it
            returns True on global rank 0.

        random_seed (int): The seed for all RNG during dataset creation.

        sequence_length (int): The sequence length.

        blend (Optional[List[str]]): The blend, either a single dataset or a flattened sequence
            of weight-dataset pairs. For example, ["dataset-path1"] and
            ["50", "dataset-path1", "50", "dataset-path2"] are both valid. Not to be used with
            'blend_per_split'. Defaults to None.

        blend_per_split (Optional[List[Optional[List[str]]]]): One blend per split, so it must
            have len(Split) entries. Not to be used with 'blend'. Defaults to None.

        split (Optional[str]): Comma separated weighting of the splits (e.g. "99,1,0"), used
            when drawing samples from a single distribution. Required together with 'blend';
            reset to None (with a warning) when 'blend_per_split' is given. Defaults to None.

        split_vector (Optional[List[float]]): 'split' parsed and normalized after
            initialization. Not to be passed to the constructor. When it is computed,
            ``split_matrix`` is derived from it as well.

        path_to_cache (str): Where all re-useable dataset indices are to be cached.
    """

    is_built_on_rank: Callable

    random_seed: int

    sequence_length: int

    blend: Optional[List[str]] = None

    blend_per_split: Optional[List[Optional[List[str]]]] = None

    split: Optional[str] = None

    split_vector: Optional[List[float]] = field(init=False, default=None)

    path_to_cache: str = None

    def __post_init__(self):
        """Validate the blend/split options and derive ``split_vector`` / ``split_matrix``.

        See https://docs.python.org/3/library/dataclasses.html#post-init-processing
        """
        if torch.distributed.is_initialized():
            global_rank = torch.distributed.get_rank()
            # get_virtual_pipeline_model_parallel_rank() is not wired up here; use 0.
            virtual_rank = 0
            if global_rank == 0 and (virtual_rank == 0 or virtual_rank is None):
                assert (
                    self.is_built_on_rank()
                ), "is_built_on_rank must return True when global rank = 0 and vp rank = 0"

        per_split_requested = self.blend_per_split is not None and any(self.blend_per_split)

        if not per_split_requested:
            # Single blend: a split string is mandatory and is parsed into ratios/bookends.
            assert self.blend is not None, "one of either blend or blend_per_split must be provided"
            assert self.split is not None, "both blend and split must be provided"
            self.split_vector = _parse_and_normalize_split(self.split)
            self.split_matrix = convert_split_vector_to_split_matrix(self.split_vector)
            log_single_rank(logger, logging.INFO, f"Let split_vector = {self.split_vector}")
            return

        # One blend per split: 'blend' must be absent and any 'split' is discarded.
        assert self.blend is None, "blend and blend_per_split are incompatible"
        assert len(self.blend_per_split) == len(
            Split
        ), f"blend_per_split must contain {len(Split)} blends"
        if self.split is not None:
            self.split = None
            log_single_rank(logger, logging.WARNING, f"Let split = {self.split}")
92 |
93 |
@dataclass
class GPTDatasetConfig(BlendedMegatronDatasetConfig):
    """Configuration object for megatron-core blended and megatron GPT datasets

    Inherits all fields (and the ``__post_init__`` validation) of BlendedMegatronDatasetConfig
    and adds three GPT-specific flags, all defaulting to False.

    Attributes:
        return_document_ids (bool): Whether to return the document ids when querying the dataset.

        add_bos (bool): NOTE(review): read outside this file (presumably by the GPT dataset);
            the name suggests it controls prepending a BOS token. Confirm against the consumer.

        enable_shuffle (bool): NOTE(review): read outside this file; the name suggests it
            toggles sample shuffling. Confirm against the consumer.
    """

    return_document_ids: bool = False

    add_bos: bool = False

    enable_shuffle: bool = False
107 |
108 |
def _parse_and_normalize_split(split: str) -> List[float]:
    """Turn a train/valid/test split string into normalized ratios.

    Every run of digits and dots in ``split`` is read as a float (e.g. "99,1,0" ->
    [99.0, 1.0, 0.0]). Missing trailing entries are padded with 0.0 up to len(Split), and the
    result is passed through ``normalize``.

    Args:
        split (str): The train valid test split string e.g. "99,1,0"

    Returns:
        List[float]: The normalized train valid test split ratios.

    Raises:
        AssertionError: If more than len(Split) numbers are given.
    """
    ratios = [float(token) for token in re.findall(r"[.0-9]+", split)]
    # A non-positive shortfall pads nothing, so the length assertion catches excess entries.
    ratios.extend([0.0] * (len(Split) - len(ratios)))

    assert len(ratios) == len(Split)
    assert all(ratio >= 0.0 for ratio in ratios)

    return normalize(ratios)
127 |
128 |
def convert_split_vector_to_split_matrix(
    vector_a: List[float], vector_b: Optional[List[float]] = None
) -> List[Optional[Tuple[float, float]]]:
    """Build the split matrix from one or optionally two contributing split vectors.

    Each vector is turned into consecutive (start, end) bookends of its cumulative sum. The
    matrix entry for split i is the overlap of the i-th bookends of both vectors, or None
    when that overlap is empty.

    Ex. a standard conversion:

        [0.99, 0.01, 0.0] -> [(0, 0.99), (0.99, 1.0), None]

    Ex. a conversion for Retro when Retro pretraining uses a [0.99, 0.01, 0.0] split and Retro
    preprocessing used a [0.98, 0.02, 0.0] split:

        [0.99, 0.01, 0.0], [0.98, 0.02, 0.0] -> [(0, 0.98), (0.99, 1.0), None]

    Args:
        vector_a (List[float]): The primary split vector

        vector_b (Optional[List[float]]): An optional secondary split vector which constrains
            the primary split vector. Defaults to None (use vector_a).

    Returns:
        List[Optional[Tuple[float, float]]]: One (start, end) bookend or None per split, in order
    """
    def cumulative_edges(vector):
        # [.900, .090, .010] -> [0, .900, .990, 1.0]
        edges = [0]
        for weight in vector:
            edges.append(edges[-1] + weight)
        return edges

    def bookends(vector):
        # [0, .900, .990, 1.0] -> [(0, .900), (.900, .990), (.990, 1.0)]
        edges = cumulative_edges(vector)
        return list(zip(edges, edges[1:]))

    secondary = vector_a if vector_b is None else vector_b

    matrix = []
    for (a_start, a_end), (b_start, b_end) in zip(bookends(vector_a), bookends(secondary)):
        start = max(a_start, b_start)
        end = min(a_end, b_end)
        matrix.append(None if end <= start else (start, end))
    return matrix
173 |
--------------------------------------------------------------------------------
/code/finetune/scripts/run_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # ==============================
4 | # YuE Fine-tuning Script
5 | # ==============================
6 |
7 | # Help information
# Print usage information and exit.
# $1: optional exit status, default 0. Help that the user explicitly asked for
#     (-h/--help) is not an error, so it exits 0.
print_help() {
    local status="${1:-0}"
    echo "========================================================"
    echo "YuE Fine-tuning Script Help"
    echo "========================================================"
    echo "Before running this script, please update the following variables:"
    echo ""
    echo "1. Data paths:"
    echo " DATA_PATH - Replace with actual weights and data paths"
    echo " DATA_CACHE_PATH - Replace with actual cache directory"
    echo ""
    echo "2. Model configuration:"
    echo " TOKENIZER_MODEL_PATH - Replace with actual tokenizer path"
    echo " MODEL_CACHE_DIR - Replace with actual cache directory"
    echo " OUTPUT_DIR - Replace with actual output directory"
    echo ""
    echo "3. If using WandB:"
    echo " WANDB_API_KEY - Replace with your actual API key"
    echo ""
    echo "Example usage:"
    echo " DATA_PATH=\"data1-weight /path/to/data1 data2-weight /path/to/data2\""
    echo " DATA_CACHE_PATH=\"/path/to/cache\""
    echo " TOKENIZER_MODEL_PATH=\"/path/to/tokenizer\""
    echo " MODEL_CACHE_DIR=\"/path/to/model/cache\""
    echo " OUTPUT_DIR=\"/path/to/output\""
    echo " WANDB_API_KEY=\"your-actual-wandb-key\""
    echo "========================================================"
    exit "$status"
}

# Check if help is requested
if [[ "$1" == "--help" || "$1" == "-h" ]]; then
    print_help 0
fi
41 |
# Return success (0) when the value in $1 is still unconfigured, meaning it is
# empty (the script ships with DATA_PATH="" etc.) or it still contains an
# angle-bracket template placeholder such as <YOUR_PATH>.
is_placeholder() {
    [[ -z "$1" || "$1" == *"<"*">"* ]]
}

# Check for placeholder values.
# Reports every unconfigured setting, then exits 1 if any were found.
# The previous tests (== *""*) matched every string and so always aborted.
check_placeholders() {
    local has_placeholders=false

    if is_placeholder "$DATA_PATH"; then
        echo "Error: Please set actual data paths in DATA_PATH variable."
        has_placeholders=true
    fi

    if is_placeholder "$DATA_CACHE_PATH"; then
        echo "Error: Please set actual data cache path in DATA_CACHE_PATH variable."
        has_placeholders=true
    fi

    if is_placeholder "$TOKENIZER_MODEL_PATH"; then
        echo "Error: Please set actual tokenizer model path in TOKENIZER_MODEL_PATH variable."
        has_placeholders=true
    fi

    if is_placeholder "$MODEL_CACHE_DIR"; then
        echo "Error: Please set actual model cache directory in MODEL_CACHE_DIR variable."
        has_placeholders=true
    fi

    if is_placeholder "$OUTPUT_DIR"; then
        echo "Error: Please set actual output directory in OUTPUT_DIR variable."
        has_placeholders=true
    fi

    # The API key only matters when WandB reporting is enabled.
    if [[ "$USE_WANDB" == "true" ]] && is_placeholder "$WANDB_API_KEY"; then
        echo "Error: Please set actual WandB API key in WANDB_API_KEY variable or disable WandB."
        has_placeholders=true
    fi

    if [[ "$has_placeholders" == "true" ]]; then
        echo ""
        echo "Please update the script with your actual paths and values."
        echo "Run './scripts/run_finetune.sh --help' for more information."
        exit 1
    fi
}
83 |
# Exit on error
set -e

# Check if we're in the finetune directory
CURRENT_DIR=$(basename "$PWD")
if [ "$CURRENT_DIR" != "finetune" ]; then
    echo "Error: This script must be run from the finetune/ directory"
    echo "Current directory: $PWD"
    echo "Please change to the finetune directory and try again"
    exit 1
fi

# ==============================
# Configuration Parameters
# ==============================

# Hardware configuration
NUM_GPUS=8
MASTER_PORT=9999
# Uncomment and modify if you need specific GPUs
# export CUDA_VISIBLE_DEVICES=4,5,6,7

# Training hyperparameters
PER_DEVICE_TRAIN_BATCH_SIZE=1
PER_DEVICE_EVAL_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=$((NUM_GPUS*PER_DEVICE_TRAIN_BATCH_SIZE))
USE_BF16=true
SEQ_LENGTH=8192
TRAIN_ITERS=150
NUM_TRAIN_EPOCHS=10

# Data paths (replace with your actual paths)
# DATA_PATH is "weight path weight path ..." and is split into separate arguments below.
DATA_PATH=""
DATA_CACHE_PATH=""

# Set comma-separated list of proportions for training, validation, and test split
DATA_SPLIT="900,50,50"

# Model configuration
TOKENIZER_MODEL_PATH=""
MODEL_NAME="m-a-p/YuE-s1-7B-anneal-en-cot"
MODEL_CACHE_DIR=""
OUTPUT_DIR=""
DEEPSPEED_CONFIG=config/ds_config_zero2.json

# LoRA configuration
LORA_R=64
LORA_ALPHA=32
LORA_DROPOUT=0.1
LORA_TARGET_MODULES="q_proj k_proj v_proj o_proj"
# Logging configuration
LOGGING_STEPS=5
SAVE_STEPS=5
USE_WANDB=true
WANDB_API_KEY=""
RUN_NAME="YuE-ft-lora"

# ==============================
# Environment Setup
# ==============================

# Check for placeholder values
check_placeholders

# Export environment variables
export WANDB_API_KEY=$WANDB_API_KEY
export PYTHONPATH=$PWD:$PYTHONPATH

# Print configuration
echo "==============================================="
echo "YuE Fine-tuning Configuration:"
echo "==============================================="
echo "Number of GPUs: $NUM_GPUS"
echo "Global batch size: $GLOBAL_BATCH_SIZE"
echo "Model: $MODEL_NAME"
echo "Output directory: $OUTPUT_DIR"
echo "Training epochs: $NUM_TRAIN_EPOCHS"
echo "==============================================="

# ==============================
# Build and Execute Command
# ==============================

# The command is built as an array so each argument reaches torchrun verbatim,
# with no eval re-parsing. DATA_PATH and LORA_TARGET_MODULES are intentionally
# unquoted: they must split into one argument per weight/path or per module.
# shellcheck disable=SC2206
CMD=(torchrun --nproc_per_node="$NUM_GPUS" --master_port="$MASTER_PORT" scripts/train_lora.py
    --seq-length "$SEQ_LENGTH"
    --data-path $DATA_PATH
    --data-cache-path "$DATA_CACHE_PATH"
    --split "$DATA_SPLIT"
    --tokenizer-model "$TOKENIZER_MODEL_PATH"
    --global-batch-size "$GLOBAL_BATCH_SIZE"
    --per-device-train-batch-size "$PER_DEVICE_TRAIN_BATCH_SIZE"
    --per-device-eval-batch-size "$PER_DEVICE_EVAL_BATCH_SIZE"
    --train-iters "$TRAIN_ITERS"
    --num-train-epochs "$NUM_TRAIN_EPOCHS"
    --logging-steps "$LOGGING_STEPS"
    --save-steps "$SAVE_STEPS"
    --deepspeed "$DEEPSPEED_CONFIG"
)

# Add conditional arguments
if [ "$USE_WANDB" = true ]; then
    CMD+=(--report-to wandb --run-name "$RUN_NAME")
elif [ "$USE_WANDB" = false ]; then
    CMD+=(--report-to none)
fi

# shellcheck disable=SC2206
CMD+=(
    --model-name-or-path "$MODEL_NAME"
    --cache-dir "$MODEL_CACHE_DIR"
    --output-dir "$OUTPUT_DIR"
    --lora-r "$LORA_R"
    --lora-alpha "$LORA_ALPHA"
    --lora-dropout "$LORA_DROPOUT"
    --lora-target-modules $LORA_TARGET_MODULES
)

if [ "$USE_BF16" = true ]; then
    CMD+=(--bf16)
fi

# Execute the command
echo "Running command: ${CMD[*]}"
echo "==============================================="

# Run the trainer as the `if` condition. Under `set -e` a bare failing command
# would abort the script before any status check, and a later `$?` would report
# the status of the test rather than of the trainer.
if "${CMD[@]}"; then
    echo "==============================================="
    echo "Fine-tuning completed successfully!"
    echo "Output saved to: $OUTPUT_DIR"
    echo "==============================================="
else
    status=$?
    echo "==============================================="
    echo "Error: Fine-tuning failed with exit code $status"
    echo "==============================================="
    exit "$status"
fi
--------------------------------------------------------------------------------
/code/finetune/core/datasets/blended_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 |
3 | import hashlib
4 | import json
5 | import logging
6 | import os
7 | import time
8 | from collections import OrderedDict
9 | from typing import Dict, List, Tuple, Union
10 |
11 | import numpy
12 | import torch
13 |
14 | from core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
15 | from core.datasets.megatron_dataset import MegatronDataset
16 | from core.datasets.utils import log_single_rank, normalize
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 | _VERBOSE = False
21 |
22 |
class BlendedDataset(torch.utils.data.Dataset):
    """Conjugating class for a set of MegatronDataset instances

    Sample ``i`` of the blend is sample ``dataset_sample_index[i]`` of
    ``datasets[dataset_index[i]]``. Both index arrays are filled by the compiled
    ``core.datasets.helpers.build_blending_indices`` routine. When
    ``config.path_to_cache`` is set, they are saved to that directory or memory-mapped from it.

    Args:
        datasets (List[MegatronDataset]): The datasets to blend. All must have exactly the same
            type, and there must be fewer than 32767 of them (the dataset index is int16).

        weights (List[float]): The blend ratios, one per dataset. They must sum to 1 and are
            re-normalized anyway.

        size (int): The number of samples to draw from the blend

        config (BlendedMegatronDatasetConfig): The config object which informs dataset creation

    Raises:
        RuntimeError: When index ``size`` is still addressable after the indices are built. An
            IndexError propagates instead if index ``size - 1`` is not addressable.
    """

    def __init__(
        self,
        datasets: List[MegatronDataset],
        weights: List[float],
        size: int,
        config: BlendedMegatronDatasetConfig,
    ) -> None:
        assert len(datasets) < 32767
        assert len(datasets) == len(weights)
        assert numpy.isclose(sum(weights), 1.0)
        assert all(type(ds) == type(datasets[0]) for ds in datasets)

        # Alert user to unnecessary blending
        if len(datasets) == 1:
            log_single_rank(
                logger, logging.WARNING, "Building a BlendedDataset for a single MegatronDataset"
            )

        self.datasets = datasets
        # Redundant normalization for bitwise identical comparison with Megatron-LM
        self.weights = normalize(weights)
        self.size = size
        self.config = config

        # Key order is significant: it fixes the JSON text and therefore the cache-file hash.
        unique_identifiers = OrderedDict(
            [
                ("class", type(self).__name__),
                ("datasets", [ds.unique_identifiers for ds in self.datasets]),
                ("weights", self.weights),
                ("size", self.size),
            ]
        )
        self.unique_description = json.dumps(unique_identifiers, indent=4)
        self.unique_description_hash = hashlib.md5(
            self.unique_description.encode("utf-8")
        ).hexdigest()

        self.dataset_index, self.dataset_sample_index = self._build_indices()

        # Bounds check: the last valid index must resolve and index ``size`` must not.
        self[self.size - 1]
        try:
            self[self.size]
        except IndexError:
            log_single_rank(logger, logging.INFO, f"> {type(self).__name__} length: {len(self)}")
        else:
            raise RuntimeError(f"{type(self).__name__} size is improperly bounded")

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
        which = self.dataset_index[idx]
        sample_id = self.dataset_sample_index[idx]
        sample = self.datasets[which][sample_id]
        return {"dataset_id": which, **sample}

    def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
        """Build and optionally cache the dataset index and the dataset sample index

        The dataset index is a 1-D mapping which determines the dataset to query. The dataset
        sample index is a 1-D mapping which determines the sample to request from the queried
        dataset.

        * Without ``path_to_cache``, the indices are always built in memory.
        * On a cache miss, only global rank 0 builds and writes them.
        * Otherwise (cache hit, or a non-zero rank), they are memory-mapped from disk.

        NOTE(review): a non-zero rank on a cache miss goes straight to loading; nothing here
        waits for rank 0 to finish writing, so the caller must synchronize ranks.

        Returns:
            Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index
        """
        cache_dir = self.config.path_to_cache
        cls_name = type(self).__name__

        if cache_dir:
            def cache_file(suffix):
                return os.path.join(cache_dir, f"{self.unique_description_hash}-{cls_name}-{suffix}")

            description_file = cache_file("description.txt")
            index_file = cache_file("dataset_index.npy")
            sample_index_file = cache_file("dataset_sample_index.npy")
            cache_hit = all(
                os.path.isfile(p) for p in (description_file, index_file, sample_index_file)
            )
        else:
            cache_hit = False

        must_build = not cache_dir or (not cache_hit and torch.distributed.get_rank() == 0)
        if must_build:
            log_single_rank(logger, logging.INFO, f"Build and save the {cls_name} indices")
            log_single_rank(
                logger, logging.INFO, "\tBuild and save the dataset and dataset sample indexes"
            )
            t_beg = time.time()
            from core.datasets import helpers

            dataset_index = numpy.zeros(self.size, dtype=numpy.int16)
            dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64)
            helpers.build_blending_indices(
                dataset_index,
                dataset_sample_index,
                self.weights,
                len(self.datasets),
                self.size,
                _VERBOSE,
            )

            if not cache_dir:
                log_single_rank(
                    logger,
                    logging.WARNING,
                    "Unable to save the indexes because path_to_cache is None",
                )
            else:
                os.makedirs(cache_dir, exist_ok=True)
                with open(description_file, "wt") as writer:
                    writer.write(self.unique_description)
                numpy.save(index_file, dataset_index, allow_pickle=True)
                numpy.save(sample_index_file, dataset_sample_index, allow_pickle=True)

            t_end = time.time()
            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
            return dataset_index, dataset_sample_index

        def timed_load(path):
            # Memory-map (read-only) one cached index array and log how long it took.
            started = time.time()
            array = numpy.load(path, allow_pickle=True, mmap_mode='r')
            finished = time.time()
            log_single_rank(
                logger, logging.DEBUG, f"\t> time elapsed: {finished - started:4f} seconds"
            )
            return array

        log_single_rank(logger, logging.INFO, f"Load the {cls_name} indices")
        log_single_rank(logger, logging.INFO, f"\tLoad the dataset index from {index_file}")
        dataset_index = timed_load(index_file)
        log_single_rank(
            logger, logging.INFO, f"\tLoad the dataset sample index from {sample_index_file}"
        )
        dataset_sample_index = timed_load(sample_index_file)
        return dataset_index, dataset_sample_index
191 |
--------------------------------------------------------------------------------
/code/finetune/tools/codecmanipulator.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import einops
4 |
5 |
6 | class CodecManipulator(object):
7 | r"""
8 | **mm tokenizer v0.1**
9 | see codeclm/hf/mm_tokenizer_v0.1_hf/id2vocab.json
10 |
11 | text tokens:
12 | llama tokenizer 0~31999
13 |
14 | special tokens: "32000": "", "32001": "", "32002": "", "32003": "", "32004": "", "32005": "", "32006": "", "32007": "", "32008": "", "32009": "", "32010": "", "32011": "", "32012": "", "32013": "", "32014": "", "32015": "", "32016": "", "32017": "", "32018": "", "32019": "", "32020": "", "32021": ""
15 |
16 | mm tokens:
17 | dac_16k: 4 codebook, 1024 vocab, 32022 - 36117
18 | dac_44k: 9 codebook, 1024 vocab, 36118 - 45333
19 | xcodec: 12 codebook, 1024 vocab, 45334 - 57621
20 | semantic mert: 1024, 57622 - 58645
21 | semantic hubert: 512, 58646 - 59157
22 | visual: 64000, not included in v0.1
23 | semanticodec 100tps 16384: semantic=16384, 59158 - 75541, acoustic=8192, 75542 - 83733
24 | """
25 | def __init__(self, codec_type, quantizer_begin=None, n_quantizer=None, teacher_forcing=False, data_feature="codec"):
26 | self.codec_type = codec_type
27 | self.mm_v0_2_cfg = {
28 | "dac16k": {"codebook_size": 1024, "num_codebooks": 4, "global_offset": 32022, "sep": [""], "fps": 50},
29 | "dac44k": {"codebook_size": 1024, "num_codebooks": 9, "global_offset": 36118, "sep": [""]},
30 | "xcodec": {"codebook_size": 1024, "num_codebooks": 12, "global_offset": 45334, "sep": [""], "fps": 50},
31 | "mert": {"codebook_size": 1024, "global_offset": 57622, "sep": [""]},
32 | "hubert": {"codebook_size": 512, "global_offset": 58646, "sep": [""]},
33 | "semantic/s": {"codebook_size": 16384, "num_codebooks": 1, "global_offset": 59158, "sep": ["", ""]},
34 | "semantic/a": {"codebook_size": 8192, "num_codebooks": 1, "global_offset": 75542, "sep": ["", ""]},
35 | "semanticodec": {"codebook_size": [16384, 8192], "num_codebooks": 2, "global_offset": 59158, "sep": [""], "fps": 50},
36 | "special_tokens": {
37 | '': 32000, '': 32001, '': 32002, '': 32003, '': 32004, '': 32005, '': 32006, '': 32007, '': 32008, '': 32009, '': 32010, '': 32011, '': 32012, '': 32013, '': 32014, '': 32015, '': 32016, '': 32017, '': 32018, '': 32019, '': 32020, '': 32021
38 | },
39 | "metadata": {
40 | "len": 83734,
41 | "text_range": [0, 31999],
42 | "special_range": [32000, 32021],
43 | "mm_range": [32022, 83733]
44 | },
45 | "codec_range": {
46 | "dac16k": [32022, 36117],
47 | "dac44k": [36118, 45333],
48 | "xcodec": [45334, 57621],
49 | # "hifi16k": [53526, 57621],
50 | "mert": [57622, 58645],
51 | "hubert": [58646, 59157],
52 | "semantic/s": [59158, 75541],
53 | "semantic/a": [75542, 83733],
54 | "semanticodec": [59158, 83733]
55 | }
56 | }
57 | self.sep = self.mm_v0_2_cfg[self.codec_type]["sep"]
58 | self.sep_ids = [self.mm_v0_2_cfg["special_tokens"][s] for s in self.sep]
59 | self.codebook_size = self.mm_v0_2_cfg[self.codec_type]["codebook_size"]
60 | self.num_codebooks = self.mm_v0_2_cfg[self.codec_type]["num_codebooks"]
61 | self.global_offset = self.mm_v0_2_cfg[self.codec_type]["global_offset"]
62 | self.fps = self.mm_v0_2_cfg[self.codec_type]["fps"] if "fps" in self.mm_v0_2_cfg[self.codec_type] else None
63 |
64 | self.quantizer_begin = quantizer_begin if quantizer_begin is not None else 0
65 | self.n_quantizer = n_quantizer if n_quantizer is not None else self.num_codebooks
66 | self.teacher_forcing = teacher_forcing
67 | self.data_feature = data_feature
68 |
69 |
70 | def offset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
71 | """
72 | x: (K, T)
73 | """
74 | if isinstance(codebook_size, int):
75 | assert x.max() < codebook_size, f"max(x)={x.max()}, codebook_size={codebook_size}"
76 | elif isinstance(codebook_size, list):
77 | for i, cs in enumerate(codebook_size):
78 | assert x[i].max() < cs, f"max(x)={x[i].max()}, codebook_size={cs}, layer_id={i}"
79 | else:
80 | raise ValueError(f"codebook_size={codebook_size}")
81 | assert x.min() >= 0, f"min(x)={x.min()}"
82 | assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
83 | f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"
84 |
85 | _x = x.copy()
86 | _x = _x.astype(np.uint32)
87 | cum_offset = 0
88 | quantizer_begin = self.quantizer_begin
89 | quantizer_end = quantizer_begin+self.n_quantizer
90 | for k in range(self.quantizer_begin, quantizer_end): # k: quantizer_begin to quantizer_end - 1
91 | if isinstance(codebook_size, int):
92 | _x[k] += global_offset + k * codebook_size
93 | elif isinstance(codebook_size, list):
94 | _x[k] += global_offset + cum_offset
95 | cum_offset += codebook_size[k]
96 | else:
97 | raise ValueError(f"codebook_size={codebook_size}")
98 | return _x[quantizer_begin:quantizer_end]
99 |
100 | def unoffset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
101 | """
102 | x: (K, T)
103 | """
104 | if isinstance(codebook_size, int):
105 | assert x.max() < global_offset + codebook_size * num_codebooks, f"max(x)={x.max()}, codebook_size={codebook_size}"
106 | elif isinstance(codebook_size, list):
107 | assert x.max() < global_offset + sum(codebook_size), f"max(x)={x.max()}, codebook_size={codebook_size}"
108 | assert x.min() >= global_offset, f"min(x)={x.min()}, global_offset={global_offset}"
109 | assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
110 | f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"
111 |
112 | _x = x.copy()
113 | _x = _x.astype(np.uint32)
114 | cum_offset = 0
115 | quantizer_begin = self.quantizer_begin
116 | quantizer_end = quantizer_begin+self.n_quantizer
117 | for k in range(quantizer_begin, quantizer_end):
118 | if isinstance(codebook_size, int):
119 | _x[k-quantizer_begin] -= global_offset + k * codebook_size
120 | elif isinstance(codebook_size, list):
121 | _x[k-quantizer_begin] -= global_offset + cum_offset
122 | cum_offset += codebook_size[k]
123 | else:
124 | raise ValueError(f"codebook_size={codebook_size}")
125 | return _x
126 |
127 | def flatten(self, x):
128 | if len(x.shape) > 2:
129 | x = x.squeeze()
130 | assert x.shape[0] == self.num_codebooks or x.shape[0] == self.n_quantizer, \
131 | f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
132 | return einops.rearrange(x, 'K T -> (T K)')
133 |
134 | def unflatten(self, x, n_quantizer=None):
135 | if x.ndim > 1 and x.shape[0] == 1:
136 | x = x.squeeze(0)
137 | assert len(x.shape) == 1
138 | assert x.shape[0] % self.num_codebooks == 0 or x.shape[0] % self.n_quantizer == 0, \
139 | f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
140 | if n_quantizer!=self.num_codebooks:
141 | return einops.rearrange(x, '(T K) -> K T', K=n_quantizer)
142 | return einops.rearrange(x, '(T K) -> K T', K=self.num_codebooks)
143 |
144 | # def check_codec_type_from_path(self, path):
145 | # if self.codec_type == "hifi16k":
146 | # assert "academicodec_hifi_16k_320d_large_uni" in path
147 |
148 | def get_codec_type_from_range(self, ids):
149 | ids_range = [ids.min(), ids.max()]
150 | codec_range = self.mm_v0_2_cfg["codec_range"]
151 | for codec_type, r in codec_range.items():
152 | if ids_range[0] >= r[0] and ids_range[1] <= r[1]:
153 | return codec_type
154 | raise ValueError(f"ids_range={ids_range}, codec_range={codec_range}")
155 |
156 | def npy2ids(self, npy):
157 | if isinstance(npy, str):
158 | data = np.load(npy)
159 | elif isinstance(npy, np.ndarray):
160 | data = npy
161 | else:
162 | raise ValueError(f"not supported type: {type(npy)}")
163 |
164 | assert len(data.shape)==2, f'data shape: {data.shape} is not (n_codebook, seq_len)'
165 | data = self.offset_tok_ids(
166 | data,
167 | global_offset=self.global_offset,
168 | codebook_size=self.codebook_size,
169 | num_codebooks=self.num_codebooks,
170 | )
171 | data = self.flatten(data)
172 | codec_range = self.get_codec_type_from_range(data)
173 | assert codec_range == self.codec_type, f"get_codec_type_from_range(data)={codec_range}, self.codec_type={self.codec_type}"
174 | data = data.tolist()
175 | return data
176 |
def ids2npy(self, token_ids):
    """Convert a flat sequence of global token ids back into per-codebook codes.

    This is the inverse of ``npy2ids``. It first checks that the first id lies
    in the range of the first selected codebook. It then unflattens the ids
    into ``self.n_quantizer`` rows and removes the per-codebook offsets via
    ``self.unoffset_tok_ids``.

    Args:
        token_ids: non-empty sequence of global token ids.

    Returns:
        The array returned by ``self.unoffset_tok_ids`` for the unflattened ids.

    Raises:
        ValueError: if ``self.codebook_size`` is neither an ``int`` nor a ``list``.
        AssertionError: if ``token_ids[0]`` is outside the first codebook's range.
    """
    # make sure token_ids starts with the first selected codebook
    if isinstance(self.codebook_size, int):
        codebook_0_range = (
            self.global_offset + self.quantizer_begin * self.codebook_size,
            self.global_offset + (self.quantizer_begin + 1) * self.codebook_size,
        )
    elif isinstance(self.codebook_size, list):
        codebook_0_range = (self.global_offset, self.global_offset + self.codebook_size[0])
    else:
        # Without this branch codebook_0_range would be unbound below.
        raise ValueError(f"codebook_size={self.codebook_size}")
    first_id = token_ids[0]
    # Report the id that is actually checked. Indexing
    # token_ids[self.quantizer_begin] could report the wrong id or raise IndexError.
    assert codebook_0_range[0] <= first_id < codebook_0_range[1], \
        f"token_ids[0]={first_id}, codebook_0_range={codebook_0_range}"
    data = np.array(token_ids)
    data = self.unflatten(data, n_quantizer=self.n_quantizer)
    data = self.unoffset_tok_ids(
        data,
        global_offset=self.global_offset,
        codebook_size=self.codebook_size,
        num_codebooks=self.num_codebooks,
    )
    return data
194 |
def npy_to_json_str(self, npy_path):
    """Return a JSON string describing the token ids stored in ``npy_path``.

    The JSON object has three keys: ``"text"`` (the ids from
    ``self.npy2ids``), ``"src"`` (``npy_path`` itself) and ``"codec"``
    (``self.codec_type``).
    """
    record = {
        "text": self.npy2ids(npy_path),
        "src": npy_path,
        "codec": self.codec_type,
    }
    return json.dumps(record)
198 |
def sep(self):
    """Return the separator token strings in ``self.sep`` concatenated into one string.

    NOTE(review): if ``__init__`` assigns an instance attribute named ``sep``
    (as the inference-side copy of this class does), that attribute shadows
    this method on instances. Confirm against this copy's ``__init__``.
    """
    pieces = self.sep
    return "".join(pieces)
201 |
def sep_ids(self):
    """Return ``self.sep_ids`` as-is (the same object, not a copy)."""
    ids = self.sep_ids
    return ids
204 |
--------------------------------------------------------------------------------
/code/inference/codecmanipulator.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import einops
4 |
5 |
class CodecManipulator(object):
    r"""Map codec codes to and from the shared multimodal token-id space.

    **mm tokenizer v0.1**
    see codeclm/hf/mm_tokenizer_v0.1_hf/id2vocab.json

    text tokens:
        llama tokenizer 0~31999

    special tokens: 22 ids, 32000 - 32021 (see ``mm_v0_2_cfg["special_tokens"]``)

    mm tokens:
        dac_16k: 4 codebook, 1024 vocab, 32022 - 36117
        dac_44k: 9 codebook, 1024 vocab, 36118 - 45333
        xcodec: 12 codebook, 1024 vocab, 45334 - 57621
        semantic mert: 1024, 57622 - 58645
        semantic hubert: 512, 58646 - 59157
        visual: 64000, not included in v0.1
        semanticodec 100tps 16384: semantic=16384, 59158 - 75541, acoustic=8192, 75542 - 83733

    Codebook ``k`` of a codec occupies ids starting at ``global_offset`` plus
    the total size of codebooks ``0 .. k-1``.

    NOTE(review): in this copy, every special-token string is the empty
    string. This covers the keys of ``mm_v0_2_cfg["special_tokens"]`` and
    every ``"sep"`` entry. As a result the dict literal collapses to
    ``{'': 32021}`` and every ``sep_ids`` entry resolves to 32021. The strings
    look like they were stripped during extraction. TODO: restore them from
    the original source.

    Args:
        codec_type: a codec key of ``mm_v0_2_cfg`` (e.g. ``"xcodec"``).
        quantizer_begin: index of the first codebook to use (default 0).
        n_quantizer: number of codebooks to use (default: all of the codec's).
        teacher_forcing, data_feature: stored as attributes; not used by this class.
    """
    def __init__(self, codec_type, quantizer_begin=None, n_quantizer=None, teacher_forcing=False, data_feature="codec"):
        self.codec_type = codec_type
        self.mm_v0_2_cfg = {
            "dac16k": {"codebook_size": 1024, "num_codebooks": 4, "global_offset": 32022, "sep": [""], "fps": 50},
            "dac44k": {"codebook_size": 1024, "num_codebooks": 9, "global_offset": 36118, "sep": [""]},
            "xcodec": {"codebook_size": 1024, "num_codebooks": 12, "global_offset": 45334, "sep": [""], "fps": 50},
            "mert": {"codebook_size": 1024, "global_offset": 57622, "sep": [""]},
            "hubert": {"codebook_size": 512, "global_offset": 58646, "sep": [""]},
            "semantic/s": {"codebook_size": 16384, "num_codebooks": 1, "global_offset": 59158, "sep": ["", ""]},
            "semantic/a": {"codebook_size": 8192, "num_codebooks": 1, "global_offset": 75542, "sep": ["", ""]},
            "semanticodec": {"codebook_size": [16384, 8192], "num_codebooks": 2, "global_offset": 59158, "sep": [""], "fps": 50},
            "special_tokens": {
                '': 32000, '': 32001, '': 32002, '': 32003, '': 32004, '': 32005, '': 32006, '': 32007, '': 32008, '': 32009, '': 32010, '': 32011, '': 32012, '': 32013, '': 32014, '': 32015, '': 32016, '': 32017, '': 32018, '': 32019, '': 32020, '': 32021
            },
            "metadata": {
                "len": 83734,
                "text_range": [0, 31999],
                "special_range": [32000, 32021],
                "mm_range": [32022, 83733]
            },
            "codec_range": {
                "dac16k": [32022, 36117],
                "dac44k": [36118, 45333],
                "xcodec": [45334, 57621],
                # "hifi16k": [53526, 57621],
                "mert": [57622, 58645],
                "hubert": [58646, 59157],
                "semantic/s": [59158, 75541],
                "semantic/a": [75542, 83733],
                "semanticodec": [59158, 83733]
            }
        }
        codec_cfg = self.mm_v0_2_cfg[self.codec_type]
        # These instance attributes shadow the sep()/sep_ids() methods below.
        self.sep = codec_cfg["sep"]
        self.sep_ids = [self.mm_v0_2_cfg["special_tokens"][s] for s in self.sep]
        self.codebook_size = codec_cfg["codebook_size"]
        # "mert" and "hubert" have no "num_codebooks" entry. Their codec_range
        # spans exactly one codebook (1024 and 512 ids), so default to 1
        # instead of raising KeyError.
        self.num_codebooks = codec_cfg.get("num_codebooks", 1)
        self.global_offset = codec_cfg["global_offset"]
        self.fps = codec_cfg.get("fps")

        self.quantizer_begin = quantizer_begin if quantizer_begin is not None else 0
        self.n_quantizer = n_quantizer if n_quantizer is not None else self.num_codebooks
        self.teacher_forcing = teacher_forcing
        self.data_feature = data_feature

    @staticmethod
    def _codebook_start(k, codebook_size):
        """Return the start of codebook ``k`` relative to the codec's global offset.

        With an int ``codebook_size`` every codebook has the same width. With
        a list, codebook ``k`` starts after the sum of the widths of all
        earlier codebooks.
        """
        if isinstance(codebook_size, int):
            return k * codebook_size
        return sum(codebook_size[:k])

    def offset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
        """Shift raw codes into the global token-id space.

        Args:
            x: integer array of shape ``(K, T)``, with ``K == num_codebooks``
                or ``K == self.n_quantizer``. Rows are indexed by absolute
                codebook number. ``x`` therefore needs at least
                ``self.quantizer_begin + self.n_quantizer`` rows.
            global_offset: id of this codec's first token.
            codebook_size: an int (equal-sized codebooks) or a per-codebook list.
            num_codebooks: the codec's number of codebooks.

        Returns:
            np.ndarray: a ``uint32`` array containing only the selected rows
            ``[quantizer_begin, quantizer_begin + n_quantizer)``. ``x`` itself
            is not modified.

        Raises:
            ValueError: if ``codebook_size`` is neither an int nor a list.
            AssertionError: if codes are out of range or ``x`` has the wrong row count.
        """
        if isinstance(codebook_size, int):
            assert x.max() < codebook_size, f"max(x)={x.max()}, codebook_size={codebook_size}"
        elif isinstance(codebook_size, list):
            for i, cs in enumerate(codebook_size):
                assert x[i].max() < cs, f"max(x)={x[i].max()}, codebook_size={cs}, layer_id={i}"
        else:
            raise ValueError(f"codebook_size={codebook_size}")
        assert x.min() >= 0, f"min(x)={x.min()}"
        assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"

        _x = x.astype(np.uint32)  # astype returns a new array, so x is untouched
        quantizer_begin = self.quantizer_begin
        quantizer_end = quantizer_begin + self.n_quantizer
        for k in range(quantizer_begin, quantizer_end):
            # Absolute layout for both int and list codebook_size. Previously a
            # list started counting at 0 regardless of quantizer_begin.
            _x[k] += global_offset + self._codebook_start(k, codebook_size)
        return _x[quantizer_begin:quantizer_end]

    def unoffset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
        """Inverse of ``offset_tok_ids``: map global ids back to per-codebook codes.

        Args:
            x: integer array of shape ``(K, T)``, whose rows are the selected
                codebooks in order (row 0 is codebook ``self.quantizer_begin``).
            global_offset, codebook_size, num_codebooks: as in ``offset_tok_ids``.

        Returns:
            np.ndarray: a ``uint32`` copy with the offsets removed.

        Raises:
            ValueError: if ``codebook_size`` is neither an int nor a list.
            AssertionError: if ids are out of this codec's range or ``x`` has the wrong row count.
        """
        if isinstance(codebook_size, int):
            assert x.max() < global_offset + codebook_size * num_codebooks, f"max(x)={x.max()}, codebook_size={codebook_size}"
        elif isinstance(codebook_size, list):
            assert x.max() < global_offset + sum(codebook_size), f"max(x)={x.max()}, codebook_size={codebook_size}"
        else:
            raise ValueError(f"codebook_size={codebook_size}")
        assert x.min() >= global_offset, f"min(x)={x.min()}, global_offset={global_offset}"
        assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"

        _x = x.astype(np.uint32)
        quantizer_begin = self.quantizer_begin
        quantizer_end = quantizer_begin + self.n_quantizer
        for k in range(quantizer_begin, quantizer_end):
            # Rows are relative to quantizer_begin, hence k - quantizer_begin.
            _x[k - quantizer_begin] -= global_offset + self._codebook_start(k, codebook_size)
        return _x

    def flatten(self, x):
        """Interleave a ``(K, T)`` array into a 1-D sequence ordered time-major (``(T K)``)."""
        if len(x.shape) > 2:
            x = x.squeeze()
        assert x.shape[0] == self.num_codebooks or x.shape[0] == self.n_quantizer, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
        return einops.rearrange(x, 'K T -> (T K)')

    def unflatten(self, x, n_quantizer=None):
        """Inverse of ``flatten``: reshape a 1-D ``(T K)`` sequence into ``(K, T)``.

        Args:
            x: 1-D array, or a 2-D array with a leading axis of size 1.
            n_quantizer: number of rows ``K``. It defaults to
                ``self.n_quantizer``, which is the row count ``offset_tok_ids``
                produces. Previously the default ``None`` was passed straight
                to einops and failed.
        """
        if x.ndim > 1 and x.shape[0] == 1:
            x = x.squeeze(0)
        assert len(x.shape) == 1
        assert x.shape[0] % self.num_codebooks == 0 or x.shape[0] % self.n_quantizer == 0, \
            f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
        if n_quantizer is None:
            n_quantizer = self.n_quantizer
        return einops.rearrange(x, '(T K) -> K T', K=n_quantizer)

    def get_codec_type_from_range(self, ids):
        """Return the first codec in ``codec_range`` whose inclusive bounds contain all of ``ids``.

        Raises:
            ValueError: if no configured range covers ``[ids.min(), ids.max()]``.
        """
        ids_range = [ids.min(), ids.max()]
        codec_range = self.mm_v0_2_cfg["codec_range"]
        for codec_type, r in codec_range.items():
            if ids_range[0] >= r[0] and ids_range[1] <= r[1]:
                return codec_type
        raise ValueError(f"ids_range={ids_range}, codec_range={codec_range}")

    def npy2ids(self, npy):
        """Convert a ``(n_codebook, seq_len)`` array (or ``.npy`` path) into a flat list of global ids.

        Raises:
            ValueError: if ``npy`` is neither a ``str`` nor an ``np.ndarray``.
            AssertionError: if the data is not 2-D, or if the ids fall outside
                ``self.codec_type``'s range.
        """
        if isinstance(npy, str):
            data = np.load(npy)
        elif isinstance(npy, np.ndarray):
            data = npy
        else:
            raise ValueError(f"not supported type: {type(npy)}")

        assert len(data.shape)==2, f'data shape: {data.shape} is not (n_codebook, seq_len)'
        data = self.offset_tok_ids(
            data,
            global_offset=self.global_offset,
            codebook_size=self.codebook_size,
            num_codebooks=self.num_codebooks,
        )
        data = self.flatten(data)
        codec_range = self.get_codec_type_from_range(data)
        assert codec_range == self.codec_type, f"get_codec_type_from_range(data)={codec_range}, self.codec_type={self.codec_type}"
        data = data.tolist()
        return data

    def ids2npy(self, token_ids):
        """Inverse of ``npy2ids``: turn a flat id sequence back into de-offset codes.

        ``token_ids[0]`` must belong to the first selected codebook
        (``self.quantizer_begin``).

        Raises:
            ValueError: if ``self.codebook_size`` is neither an int nor a list.
            AssertionError: if ``token_ids[0]`` is outside that codebook's range.
        """
        if not isinstance(self.codebook_size, (int, list)):
            # Previously this fell through to an unbound codebook_0_range.
            raise ValueError(f"codebook_size={self.codebook_size}")
        first_begin = self.global_offset + self._codebook_start(self.quantizer_begin, self.codebook_size)
        first_width = (
            self.codebook_size
            if isinstance(self.codebook_size, int)
            else self.codebook_size[self.quantizer_begin]
        )
        codebook_0_range = (first_begin, first_begin + first_width)
        # Report the id that is actually checked (token_ids[0]).
        assert codebook_0_range[0] <= token_ids[0] < codebook_0_range[1], \
            f"token_ids[0]={token_ids[0]}, codebook_0_range={codebook_0_range}"
        data = np.array(token_ids)
        data = self.unflatten(data, n_quantizer=self.n_quantizer)
        data = self.unoffset_tok_ids(
            data,
            global_offset=self.global_offset,
            codebook_size=self.codebook_size,
            num_codebooks=self.num_codebooks,
        )
        return data

    def npy_to_json_str(self, npy_path):
        """Return ``{"text": ids, "src": npy_path, "codec": codec_type}`` as a JSON string."""
        data = self.npy2ids(npy_path)
        return json.dumps({"text": data, "src": npy_path, "codec": self.codec_type})

    # NOTE: __init__ assigns instance attributes named ``sep`` and ``sep_ids``.
    # Those attributes shadow the two methods below on instances, so the
    # methods are reachable only via the class, e.g. ``CodecManipulator.sep(obj)``.
    def sep(self):
        return ''.join(self.sep)

    def sep_ids(self):
        return self.sep_ids
205 |
--------------------------------------------------------------------------------
/code/inference/mmtokenizer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from abc import abstractmethod
3 |
4 |
class AbstractTokenizer(ABC):
    """Base interface for tokenizers.

    Concrete subclasses must implement ``vocab_size``, ``vocab``,
    ``inv_vocab`` and ``tokenize``. ``detokenize`` and the special-token
    accessors ``cls``, ``sep``, ``pad``, ``eod`` and ``mask`` raise
    ``NotImplementedError`` unless a subclass overrides them.

    Args:
        name: tokenizer name. It is stored as ``self.name`` and used in the
            ``NotImplementedError`` messages.
    """

    def __init__(self, name):
        self.name = name
        super().__init__()

    def _not_provided(self, token_kind):
        """Build the NotImplementedError raised by a missing special-token accessor."""
        return NotImplementedError(f'{token_kind} is not provided for {self.name} tokenizer')

    @property
    @abstractmethod
    def vocab_size(self):
        """Size of the vocabulary."""

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""

    @abstractmethod
    def tokenize(self, text):
        """Turn ``text`` into token ids."""

    def detokenize(self, token_ids):
        raise NotImplementedError(f'detokenizer is not implemented for {self.name} tokenizer')

    @property
    def cls(self):
        raise self._not_provided('CLS')

    @property
    def sep(self):
        raise self._not_provided('SEP')

    @property
    def pad(self):
        raise self._not_provided('PAD')

    @property
    def eod(self):
        raise self._not_provided('EOD')

    @property
    def mask(self):
        raise self._not_provided('MASK')
61 |
62 |
class _SentencePieceTokenizer(AbstractTokenizer):
    """SentencePieceTokenizer-Megatron wrapper.

    Wraps a ``sentencepiece.SentencePieceProcessor`` and registers extra
    special tokens. A token the model does not already contain gets the next
    free id after the model's own pieces.

    NOTE(review): in this copy of the file every special-token literal is the
    empty string. This covers the cls/sep/eod/mask tokens, the pad/bos/eos
    fallbacks and ``"".format(i)`` for the extra ids. Consequences:
    - All special tokens alias a single id.
    - ``tokenize`` never terminates: ``text[idx:].index('')`` is always 0,
      so ``idx`` never advances.
    These presumably were angle-bracket tokens lost during extraction. TODO:
    restore them from the original source.
    """

    def __init__(self, model_file, vocab_extra_ids=0):
        name = 'SentencePieceTokenizer'
        super().__init__(name)

        # Imported here rather than at module top, so importing this module
        # does not itself require sentencepiece.
        import sentencepiece
        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
        # Dispatches to a subclass override if one exists.
        self._initalize(vocab_extra_ids)

    def _populate_vocab(self):
        """Build piece->id (``_vocab``) and id->piece (``_inv_vocab``) maps from the model."""
        self._vocab = {}
        self._inv_vocab = {}

        for i in range(len(self.tokenizer)):
            t = self.tokenizer.id_to_piece(i)
            self._inv_vocab[i] = t
            self._vocab[t] = i

    # The name is misspelled ("initalize"), but subclasses override it under
    # this name, so it cannot be renamed here alone.
    def _initalize(self, vocab_extra_ids):
        """Populate the vocabulary and register special tokens and ``vocab_extra_ids`` extra ids."""
        self._populate_vocab()
        self._special_tokens = {}
        self._inv_special_tokens = {}

        self._t5_tokens = []

        def _add_special_token(t):
            # Append t to the vocab only if the model lacks it; either way,
            # record it as special under its (possibly pre-existing) id.
            if t not in self._vocab:
                next_id = len(self._vocab)
                self._vocab[t] = next_id
                self._inv_vocab[next_id] = t
            self._special_tokens[t] = self._vocab[t]
            self._inv_special_tokens[self._vocab[t]] = t

        _add_special_token('')
        self._cls_id = self._vocab['']
        _add_special_token('')
        self._sep_id = self._vocab['']
        _add_special_token('')
        self._eod_id = self._vocab['']
        _add_special_token('')
        self._mask_id = self._vocab['']

        # Prefer the model's own pad/bos/eos pieces. An IndexError from
        # id_to_piece (presumably when the model lacks that piece) falls back
        # to a literal.
        pad_id = self.tokenizer.pad_id()
        try:
            pad_token = self.tokenizer.id_to_piece(pad_id)
        except IndexError:
            pad_token = ''
        _add_special_token(pad_token)
        self._pad_id = self._vocab[pad_token]

        bos_id = self.tokenizer.bos_id()
        try:
            bos_token = self.tokenizer.id_to_piece(bos_id)
        except IndexError:
            bos_token = ''
        _add_special_token(bos_token)
        self._bos_id = self._vocab[bos_token]

        eos_id = self.tokenizer.eos_id()
        try:
            eos_token = self.tokenizer.id_to_piece(eos_id)
        except IndexError:
            eos_token = ''
        _add_special_token(eos_token)
        self._eos_id = self._vocab[eos_token]

        for i in range(vocab_extra_ids):
            t = "".format(i)
            _add_special_token(t)
            self._t5_tokens += [t]

    @property
    def vocab_size(self):
        return len(self._vocab)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    @property
    def decoder(self):
        return self._inv_vocab

    @property
    def encoder(self):
        return self._vocab

    # From:
    # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
    def tokenize(self, text):
        """Encode ``text`` to ids, emitting registered special tokens as single ids.

        The text is repeatedly split at the earliest occurrence of any special
        token. The text before it is encoded with sentencepiece, and the
        special token's id is appended directly.
        """
        ids = []
        idx = 0

        while 1:
            # Position of each special token in the remaining text.
            indices = {}
            for token in self._special_tokens:
                try:
                    indices[token] = text[idx:].index(token)
                except ValueError:
                    continue
            if len(indices) == 0:
                break

            next_token = min(indices, key=indices.get)
            next_idx = idx + indices[next_token]

            ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
            ids.append(self._special_tokens[next_token])
            idx = next_idx + len(next_token)

        ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
        return ids

    # From:
    # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
    def detokenize(self, ids):
        """Decode ``ids`` to text, writing special tokens verbatim.

        Each decoded run that precedes a special token is followed by a space,
        and so is each special token.
        """
        text = ""
        last_i = 0

        for i, id in enumerate(ids):
            if id in self._inv_special_tokens:
                text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
                text += self._inv_special_tokens[id] + " "
                last_i = i + 1

        text += self.tokenizer.decode_ids(ids[last_i:])
        return text

    @property
    def cls(self):
        return self._cls_id

    @property
    def sep(self):
        return self._sep_id

    @property
    def pad(self):
        return self._pad_id

    @property
    def bos_token_id(self):
        return self._bos_id

    @property
    def bos(self):
        return self._bos_id

    @property
    def eod(self):
        return self._eod_id

    @property
    def eos_token_id(self):
        return self._eos_id

    @property
    def eos(self):
        return self._eos_id

    @property
    def mask(self):
        return self._mask_id

    @property
    def additional_special_tokens_ids(self):
        """Ids of the ``vocab_extra_ids`` extra tokens, in registration order."""
        return [self.vocab[k] for k in self._t5_tokens]
236 |
class _MMSentencePieceTokenizer(_SentencePieceTokenizer):
    """SentencePieceTokenizer-Megatron wrapper with extra multimodal special tokens.

    Identical to ``_SentencePieceTokenizer`` except that ``_initalize`` also
    registers twelve more special tokens. They are exposed as the ``soa``,
    ``eoa``, ``sov``, ``eov``, ``soi``, ``eoi``, ``s_local``, ``e_local``,
    ``s_global``, ``e_global``, ``stage_1`` and ``stage_2`` properties.

    NOTE(review): every special-token literal in this copy is the empty
    string. All of these properties, plus cls/sep/eod/mask, therefore return
    the same id. TODO: restore the token strings from the original source.
    """

    def __init__(self, model_file, vocab_extra_ids=0):
        # The base __init__ calls self._initalize, which resolves to the
        # override below.
        super().__init__(model_file, vocab_extra_ids)


    def _initalize(self, vocab_extra_ids):
        """Populate the vocabulary and register the base and multimodal special tokens."""
        self._populate_vocab()
        self._special_tokens = {}
        self._inv_special_tokens = {}

        self._t5_tokens = []

        def _add_special_token(t):
            # Append t only if the model lacks it; always record it as special.
            if t not in self._vocab:
                next_id = len(self._vocab)
                self._vocab[t] = next_id
                self._inv_vocab[next_id] = t
            self._special_tokens[t] = self._vocab[t]
            self._inv_special_tokens[self._vocab[t]] = t

        _add_special_token('')
        self._cls_id = self._vocab['']
        _add_special_token('')
        self._sep_id = self._vocab['']
        _add_special_token('')
        self._eod_id = self._vocab['']
        _add_special_token('')
        self._mask_id = self._vocab['']

        # Multimodal markers (attribute names suggest start/end of
        # audio/visual/image, local/global segments and stages; confirm).
        _add_special_token('')
        self._soa_id = self._vocab['']
        _add_special_token('')
        self._eoa_id = self._vocab['']
        _add_special_token('')
        self._sov_id = self._vocab['']
        _add_special_token('')
        self._eov_id = self._vocab['']
        _add_special_token('')
        self._soi_id = self._vocab['']
        _add_special_token('')
        self._eoi_id = self._vocab['']
        _add_special_token('')
        self._s_local_id = self._vocab['']
        _add_special_token('')
        self._e_local_id = self._vocab['']
        _add_special_token('')
        self._s_global_id = self._vocab['']
        _add_special_token('')
        self._e_global_id = self._vocab['']
        _add_special_token('')
        self._stage_1_id = self._vocab['']
        _add_special_token('')
        self._stage_2_id = self._vocab['']
        # Same pad/bos/eos handling as the base class: use the model's own
        # piece, falling back to a literal on IndexError.
        pad_id = self.tokenizer.pad_id()
        try:
            pad_token = self.tokenizer.id_to_piece(pad_id)
        except IndexError:
            pad_token = ''
        _add_special_token(pad_token)
        self._pad_id = self._vocab[pad_token]

        bos_id = self.tokenizer.bos_id()
        try:
            bos_token = self.tokenizer.id_to_piece(bos_id)
        except IndexError:
            bos_token = ''
        _add_special_token(bos_token)
        self._bos_id = self._vocab[bos_token]

        eos_id = self.tokenizer.eos_id()
        try:
            eos_token = self.tokenizer.id_to_piece(eos_id)
        except IndexError:
            eos_token = ''
        _add_special_token(eos_token)
        self._eos_id = self._vocab[eos_token]

        for i in range(vocab_extra_ids):
            t = "".format(i)
            _add_special_token(t)
            self._t5_tokens += [t]

    @property
    def soa(self):
        return self._soa_id

    @property
    def eoa(self):
        return self._eoa_id

    @property
    def sov(self):
        return self._sov_id

    @property
    def eov(self):
        return self._eov_id

    @property
    def soi(self):
        return self._soi_id

    @property
    def eoi(self):
        return self._eoi_id

    @property
    def s_local(self):
        return self._s_local_id

    @property
    def e_local(self):
        return self._e_local_id

    @property
    def s_global(self):
        return self._s_global_id

    @property
    def e_global(self):
        return self._e_global_id

    @property
    def stage_1(self):
        return self._stage_1_id

    @property
    def stage_2(self):
        return self._stage_2_id
368 |
--------------------------------------------------------------------------------
/code/finetune/core/datasets/blended_megatron_dataset_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 |
3 | import logging
4 | import math
5 | from typing import Any, Callable, List, Optional, Tuple, Type, Union
6 |
7 | import numpy
8 | import torch
9 |
10 | from core.datasets.blended_dataset import BlendedDataset
11 | from core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
12 | from core.datasets.indexed_dataset import MMapIndexedDataset
13 | from core.datasets.megatron_dataset import MegatronDataset
14 | from core.datasets.utils import Split, normalize
15 |
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# Any dataset type that BlendedMegatronDatasetBuilder.build_generic_dataset may construct.
DistributedDataset = Union[BlendedDataset, MegatronDataset, MMapIndexedDataset, torch.utils.data.Dataset]
21 |
22 |
class BlendedMegatronDatasetBuilder(object):
    """Builder class for the BlendedDataset and MegatronDataset classes

    Args:
        cls (Type[MegatronDataset]): The class to instantiate, must inherit from MegatronDataset

        sizes (List[int]): The minimum number of total samples to draw from each split, varies
        with blend

        config (BlendedMegatronDatasetConfig): The config object which informs dataset creation
    """

    def __init__(
        self, cls: Type[MegatronDataset], sizes: List[int], config: BlendedMegatronDatasetConfig,
    ):
        self.cls = cls
        self.sizes = sizes
        self.config = config

    def build(self) -> List[Optional[Union[BlendedDataset, MegatronDataset]]]:
        """Build all dataset splits according to the provided blend(s)

        This method is distributed-aware and must be called on all ranks.

        The dataset splits returned can vary according to the config. Supply config.blend and
        config.split to build BlendedDataset and/or MegatronDataset splits from the same
        distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset
        splits from separate distributions.

        Returns:
            List[Optional[Union[BlendedDataset, MegatronDataset]]]: A list of either
            MegatronDataset or BlendedDataset (or None) per split
        """
        return self._build_blended_dataset_splits()

    def _build_blended_dataset_splits(
        self,
    ) -> List[Optional[Union[BlendedDataset, MegatronDataset]]]:
        """Build all dataset splits according to the provided blend(s)

        See the BlendedMegatronDatasetBuilder.build alias for more information.

        Returns:
            List[Optional[Union[BlendedDataset, MegatronDataset]]]: A list of either
            MegatronDataset or BlendedDataset (or None) per split
        """

        if self.config.blend:
            blend = self.config.blend
            split = self.config.split_matrix

            # Blend consists of a single prefix
            if len(blend) == 1:
                return self._build_megatron_dataset_splits(blend[0], split, self.sizes)

            # Blend consists of multiple weights and prefixes
            (
                prefix_per_dataset,
                weight_per_dataset,
                sizes_per_dataset,
            ) = _get_prefixes_weights_and_sizes_for_blend(blend, self.sizes)

            # megatron_datasets[j] collects split j of every contributing dataset.
            megatron_datasets = [[] for _ in range(len(Split))]
            for prefix, dataset_sizes in zip(prefix_per_dataset, sizes_per_dataset):
                megatron_datasets_split = self._build_megatron_dataset_splits(
                    prefix, split, dataset_sizes
                )
                for j, dataset in enumerate(megatron_datasets_split):
                    megatron_datasets[j].append(dataset)

            # Sum over all contributing datasets, per split
            size_per_split = list(map(sum, zip(*sizes_per_dataset)))

            blended_datasets = []
            for i, datasets in enumerate(megatron_datasets):
                # Materialise the None-mask as a list. A map() iterator is
                # consumed by the first all(), which made the follow-up any()
                # see only leftover elements and accept partially-None splits.
                is_none = [dataset is None for dataset in datasets]

                if split[i] is None:
                    assert all(is_none)
                    blended_datasets.append(None)
                else:
                    # Either every dataset was skipped on this rank or none was.
                    assert all(is_none) or not any(is_none)
                    blended_datasets.append(
                        self.build_generic_dataset(
                            BlendedDataset,
                            self.config.is_built_on_rank,
                            datasets,
                            weight_per_dataset,
                            size_per_split[i],
                            self.config,
                        )
                    )

            return blended_datasets

        else:
            blended_datasets = []
            for i in range(len(Split)):
                blend = self.config.blend_per_split[i]

                # Blend is not provided
                if not blend:
                    blended_datasets.append(None)
                    continue

                # Route the whole dataset into split i only.
                split_spoof = [None] * len(Split)
                split_spoof[i] = (0.0, 1.0)
                sizes_spoof = [0] * len(Split)
                sizes_spoof[i] = self.sizes[i]

                # Blend consists of a single prefix
                if len(blend) == 1:
                    blended_datasets.append(
                        self._build_megatron_dataset_splits(blend[0], split_spoof, sizes_spoof)[i]
                    )

                # Blend consists of multiple weights and prefixes
                else:
                    (
                        prefix_per_dataset,
                        weight_per_dataset,
                        sizes_per_dataset,
                    ) = _get_prefixes_weights_and_sizes_for_blend(blend, sizes_spoof)

                    megatron_datasets = [
                        self._build_megatron_dataset_splits(prefix, split_spoof, dataset_sizes)[i]
                        for prefix, dataset_sizes in zip(prefix_per_dataset, sizes_per_dataset)
                    ]

                    size_per_split = list(map(sum, zip(*sizes_per_dataset)))

                    blended_datasets.append(
                        self.build_generic_dataset(
                            BlendedDataset,
                            self.config.is_built_on_rank,
                            megatron_datasets,
                            weight_per_dataset,
                            size_per_split[i],
                            self.config,
                        )
                    )

            return blended_datasets

    def _build_megatron_dataset_splits(
        self, path_prefix: str, split: List[Optional[Tuple[float, float]]], sizes: List[int],
    ) -> List[Optional[MegatronDataset]]:
        """Build each MegatronDataset split from a single MMapIndexedDataset

        Args:
            path_prefix (str): The MMapIndexedDataset .bin and .idx file prefix

            split (List[Optional[Tuple[float, float]]]): The dataset split matrix: per split,
            a (begin, end) fraction pair, or None if the split is not built

            sizes (List[int]): The number of total samples to draw from each split

        Returns:
            List[Optional[MegatronDataset]]: The MegatronDatset (or None) per split
        """
        indexed_dataset = self.build_generic_dataset(
            MMapIndexedDataset, self.config.is_built_on_rank, path_prefix, self.cls.is_multimodal(),
        )

        if indexed_dataset is not None:
            # The unit being split is either sequences or documents, per dataset class.
            if self.cls.is_split_by_sequence():
                num_elements = indexed_dataset.sequence_lengths.shape[0]
            else:
                num_elements = indexed_dataset.document_indices.shape[0] - 1

            split_indices = []
            for i, _ in enumerate(Split):
                if split[i] is not None:
                    beg = int(round(split[i][0] * float(num_elements)))
                    end = int(round(split[i][1] * float(num_elements)))
                    split_indices.append(
                        numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32)
                    )
                else:
                    split_indices.append(None)
        else:
            # Not built on this rank: no indices are available.
            split_indices = [None for _ in Split]

        megatron_datasets = []
        for i, _split in enumerate(Split):
            if split[i] is None:
                megatron_datasets.append(None)
            else:
                megatron_datasets.append(
                    self.build_generic_dataset(
                        self.cls,
                        self.config.is_built_on_rank,
                        indexed_dataset,
                        split_indices[i],
                        sizes[i],
                        _split,
                        self.config,
                    )
                )

        return megatron_datasets

    @staticmethod
    def build_generic_dataset(
        cls: Type[DistributedDataset], is_built_on_rank: Callable, *args: Any
    ) -> Optional[DistributedDataset]:
        """Build the DistributedDataset

        Return None if and only if the underlying MegatronDataset class is not built on the current
        rank and torch.distributed is initialized.

        When torch.distributed is initialized, rank 0 builds first. All ranks
        then meet at a barrier, and afterwards the other ranks build, so they
        can reuse whatever rank 0 wrote to the cache.

        Args:
            cls (Type[DistributedDataset]): The DistributedDataset class to be built

            args (Tuple[Any]): The positional arguments used to build the provided
            DistributedDataset class

        Raises:
            Exception: When the dataset constructor raises an OSError

        Returns:
            Optional[DistributedDataset]: The DistributedDataset instantion or None
        """
        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()

            dataset = None

            # First, build on rank 0
            if rank == 0 and is_built_on_rank():
                try:
                    dataset = cls(*args)
                except OSError as err:
                    log = (
                        "Failed to write dataset materials to the data cache directory. "
                        "Please supply a directory to which you have write access via "
                        "the path_to_cache attribute in BlendedMegatronDatasetConfig and "
                        "retry. Refer to the preserved traceback above for more information."
                    )
                    raise Exception(log) from err

            torch.distributed.barrier()

            # After, build on other ranks
            if rank != 0 and is_built_on_rank():
                dataset = cls(*args)

            return dataset

        return cls(*args)
277 |
278 |
def _get_prefixes_weights_and_sizes_for_blend(
    blend: List[str], target_num_samples_per_split: List[int]
) -> Tuple[List[str], List[float], List[List[int]]]:
    """Determine the contribution of the MegatronDataset splits to the BlendedDataset splits

    Args:
        blend (List[str]): e.g. ["30", "path/to/dataset_1_prefix", "70",
        "path/to/dataset_2_prefix"]

        target_num_samples_per_split (List[int]): The number of samples to target for each
        BlendedDataset split

    Returns:
        Tuple[List[str], List[float], List[List[int]]]: The prefix strings e.g.
        ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], the normalized weights e.g.
        [0.3, 0.7], and the number of samples to request per MegatronDataset per split

    Raises:
        ValueError: if ``blend`` is empty or does not alternate weight/prefix pairs
    """
    if not blend or len(blend) % 2 != 0:
        # Previously an odd length surfaced as IndexError and an empty blend as
        # an unpacking error from zip().
        raise ValueError(
            "blend must be a non-empty list of alternating weights and prefixes, "
            f"got {len(blend)} element(s)"
        )

    weights = [float(weight) for weight in blend[0::2]]
    prefixes = [prefix.strip() for prefix in blend[1::2]]

    weights = normalize(weights)

    # Use 0.5% target margin to ensure we satiate the network
    sizes_per_dataset = [
        [
            int(math.ceil(target_num_samples * weight * 1.005))
            for target_num_samples in target_num_samples_per_split
        ]
        for weight in weights
    ]

    return prefixes, weights, sizes_per_dataset
312 |
--------------------------------------------------------------------------------
/code/top_200_tags.json:
--------------------------------------------------------------------------------
1 | {
2 | "genre": [
3 | "Pop",
4 | "rock",
5 | "pop",
6 | "electronic",
7 | "Classical",
8 | "R&B",
9 | "Electronic",
10 | "Rock",
11 | "Folk",
12 | "rap",
13 | "classical",
14 | "soundtrack",
15 | "country",
16 | "indie-rock",
17 | "punk",
18 | "hiphop",
19 | "folk",
20 | "jazz",
21 | "Country",
22 | "hip-hop",
23 | "Hip-hop",
24 | "experimental",
25 | "Hip Hop",
26 | "Funk",
27 | "blues",
28 | "ambient",
29 | "Rap",
30 | "Jazz",
31 | "Ambient",
32 | "New Age",
33 | "Blues",
34 | "experimental pop",
35 | "classic rock",
36 | "indie rock",
37 | "alternative rock",
38 | "Reggae",
39 | "Electro pop",
40 | "K-pop",
41 | "Dance",
42 | "Soundtrack",
43 | "Hip hop",
44 | "80s",
45 | "Dancehall",
46 | "Disco",
47 | "House",
48 | "Death Metal",
49 | "Thrash Metal",
50 | "international",
51 | "progressive rock",
52 | "hard rock",
53 | "instrumental",
54 | "Lounge",
55 | "house",
56 | "Latin",
57 | "hardcore",
58 | "Metalcore",
59 | "Soul",
60 | "grunge",
61 | "Easy listening",
62 | "easylistening",
63 | "Indian",
64 | "ethno",
65 | "Hard rock",
66 | "hip hop",
67 | "Indie Pop",
68 | "Electro",
69 | "industrial",
70 | "grindcore",
71 | "post-rock",
72 | "Soul-R&B",
73 | "Reggaeton",
74 | "World",
75 | "latin pop",
76 | "Classic Rock",
77 | "Latin pop",
78 | "Deathcore",
79 | "soul",
80 | "improvisation",
81 | "Chinese",
82 | "techno",
83 | "Salsa",
84 | "indie pop",
85 | "Hardcore",
86 | "拉丁",
87 | "Black metal",
88 | " Americana",
89 | "dance",
90 | "rock nacional",
91 | "tejano",
92 | "indie",
93 | "ambient electronic",
94 | "world",
95 | "Death metal",
96 | "Trap",
97 | "avant-garde",
98 | "Chillout",
99 | "Americana",
100 | "new wave",
101 | "rnb",
102 | "pop rock",
103 | "post-hardcore",
104 | "singer-songwriter",
105 | "pop punk",
106 | "Power metal",
107 | "indie folk",
108 | "opera",
109 | "Metal",
110 | "African",
111 | "instrumental rock",
112 | "Gospel",
113 | "downtempo",
114 | "New Wave",
115 | "Electro-pop",
116 | "rockabilly",
117 | "MPB",
118 | "goth rock",
119 | "soul-R&B",
120 | "Black Metal",
121 | "Dubstep",
122 | "Eurovision",
123 | "Bossa Nova",
124 | "bossanova",
125 | "民谣",
126 | "big band",
127 | "Synthpop",
128 | "死亡金属",
129 | "中国传统音乐",
130 | "glam rock",
131 | "国际音乐",
132 | "latin",
133 | "operatic",
134 | "Melodic Death Metal",
135 | "lounge",
136 | " Regional Mexican",
137 | "instrumental pop",
138 | "emo",
139 | "旋律死亡金属",
140 | "Pop Rock",
141 | "popfolk",
142 | " Latin",
143 | "poprock",
144 | "eurovision",
145 | "Ska",
146 | "Techno",
147 | "disco",
148 | "基督教音乐",
149 | "Indie rock",
150 | "Goregrind",
151 | "8-bit",
152 | "Pop rock",
153 | "screamo",
154 | "Dance pop",
155 | "Guitar",
156 | "chillout",
157 | "beats",
158 | "Big band",
159 | "mpb",
160 | "Bluegrass",
161 | "流行",
162 | "Thrash metal",
163 | "easy listening",
164 | "Samba",
165 | "Heavy metal",
166 | "Symphonic metal",
167 | "Chanson",
168 | "Oriental",
169 | "synthpop",
170 | "Girl group",
171 | "Epic",
172 | "Celtic",
173 | "Screamo",
174 | "Espanol",
175 | "Middle Eastern",
176 | "electro",
177 | " Soul-R&B",
178 | " Classic Rock",
179 | "Heavy Metal",
180 | "dubstep",
181 | "民乐",
182 | "country rock",
183 | "funk",
184 | "ska",
185 | "Indie Rock",
186 | "Choral",
187 | "J-rock",
188 | "shoegaze",
189 | "Rockabilly",
190 | "grime",
191 | "Italian pop",
192 | "摇滚",
193 | " latin",
194 | "Bolero",
195 | " orchestral",
196 | "experimental hip-hop",
197 | "eurodance",
198 | "noise rock",
199 | "electro pop",
200 | "noise",
201 | "Crossover Country",
202 | "Glitch"
203 | ],
204 | "instrument": [
205 | "Piano",
206 | "drums",
207 | "guitar",
208 | "electric guitar",
209 | "Guitar",
210 | "synthesizer",
211 | "Synthesizer",
212 | "Keyboard",
213 | "piano",
214 | "Drums",
215 | "Violin",
216 | "bass",
217 | "acoustic guitar",
218 | "Bass",
219 | "violin",
220 | "voice",
221 | "vocal",
222 | "acousticguitar",
223 | "Electric guitar",
224 | "Acoustic guitar",
225 | "electricguitar",
226 | "Voice",
227 | "keyboard",
228 | "saxophone",
229 | "beat",
230 | "Drum machine",
231 | "Cello",
232 | "harmonica",
233 | "fiddle",
234 | "Percussion",
235 | "beatboxing",
236 | "Vocal",
237 | "鼓",
238 | "Saxophone",
239 | "keys",
240 | "harp",
241 | "Keyboards",
242 | "keyboards",
243 | " harmonica",
244 | "singing",
245 | "吉他",
246 | "贝斯",
247 | "钢琴",
248 | "beats",
249 | "flute",
250 | "bass guitar",
251 | "drum",
252 | "brass",
253 | "Flute",
254 | "Fiddle",
255 | "charango",
256 | "Sitar",
257 | "strings",
258 | "trumpet",
259 | "Brass",
260 | "Vocals",
261 | "Trumpet",
262 | "string",
263 | "Singing",
264 | " banjo",
265 | "drum machine",
266 | "cello",
267 | "Acoustic Guitar",
268 | "glockenspiel",
269 | "computer",
270 | "电吉他",
271 | "合成器",
272 | "键盘",
273 | "mallets",
274 | "原声吉他",
275 | "Drum",
276 | "Bass guitar",
277 | "Dholak",
278 | "congas",
279 | "Electric Guitar",
280 | "二胡",
281 | "鼓机",
282 | "synth",
283 | "Strings",
284 | "小提琴",
285 | "Trombone",
286 | "percussion",
287 | "弦乐",
288 | "electricpiano",
289 | "风琴",
290 | "oboe",
291 | "horns",
292 | "Erhu",
293 | " synthesizer",
294 | "acoustic drums",
295 | " pedal steel guitar",
296 | " Voice",
297 | "Tambourine",
298 | "singer-songwriter",
299 | "Oud",
300 | "Qanun",
301 | "electronic",
302 | " pedal steel",
303 | "rapping",
304 | "Funky bass",
305 | "guitars",
306 | "木吉他",
307 | "Alto saxophone",
308 | "Ukulele",
309 | "扬琴",
310 | "oud",
311 | "sitar",
312 | "打击乐器",
313 | "Synth",
314 | "organ",
315 | "Kanun",
316 | "人声",
317 | "古筝",
318 | " accordion",
319 | "bandura",
320 | "banjo",
321 | "长笛",
322 | "pandeira",
323 | "turntables",
324 | "Alto Saxophone",
325 | " slideguitar",
326 | " electricguitar",
327 | "rap",
328 | "harpsichord",
329 | "萨克斯管",
330 | "maracas",
331 | "口琴",
332 | "Guitars",
333 | "Dobro guitar",
334 | "vocals",
335 | "choir",
336 | "Ableton",
337 | "Horns",
338 | "AcousticGuitar",
339 | "笛子",
340 | "synth drums",
341 | "Glockenspiel",
342 | "Harp",
343 | "zither",
344 | "Dobro",
345 | "Musical instrument",
346 | "electric piano",
347 | "竖琴",
348 | "Horn",
349 | "手风琴",
350 | "None",
351 | "Choir",
352 | "铜管乐器",
353 | "String",
354 | "vocal samples",
355 | "trombone",
356 | "班卓琴",
357 | "hu lu si",
358 | "Pandeira",
359 | "采样器",
360 | " Banjo",
361 | "Synth bass",
362 | "synth bass",
363 | "mallet",
364 | " tabla",
365 | "dulcimer",
366 | "声乐",
367 | "Cavaquinho",
368 | "大提琴",
369 | "toms",
370 | "ney",
371 | " trumpet",
372 | " voice",
373 | "低音",
374 | "Zither",
375 | "shakuhachi",
376 | "主唱",
377 | " electric guitar",
378 | "tambourine",
379 | "Turntables",
380 | "lyrics",
381 | " concertina",
382 | " piano",
383 | " steel guitar",
384 | "Bongos",
385 | "Koto",
386 | "808 bass",
387 | "Marimba",
388 | " drums",
389 | "Dance",
390 | "萨克斯风",
391 | "木琴",
392 | " bass",
393 | "ukulele",
394 | "Steel pan",
395 | "女声",
396 | "键盘乐器",
397 | "whistle",
398 | "soprano saxophone",
399 | "Nylon string guitar",
400 | "synth_lead",
401 | "电脑",
402 | "Shakuhachi",
403 | "oboes",
404 | "Rap"
405 | ],
406 | "mood": [
407 | "Uplifting",
408 | "emotional",
409 | "uplifting",
410 | "happy",
411 | "Inspiring",
412 | "romantic",
413 | "sad",
414 | "Love",
415 | "melancholic",
416 | "dark",
417 | "Upbeat",
418 | "Energetic",
419 | "Romantic",
420 | "Melancholic",
421 | "Nostalgic",
422 | "Calm",
423 | "Hopeful",
424 | "melodic",
425 | "relaxing",
426 | "Romance",
427 | "Emotional",
428 | "Dreamy",
429 | "energetic",
430 | "rebellious",
431 | "Dance",
432 | "inspiring",
433 | " introspective",
434 | "Confident",
435 | "aggressive",
436 | "Positive",
437 | "calm",
438 | "cool",
439 | "Happy",
440 | "hopeful",
441 | "beautiful",
442 | "advertising",
443 | "angry",
444 | "Sad",
445 | "relaxed",
446 | "Celebratory",
447 | "Angry",
448 | "Bold",
449 | "Introspective",
450 | "Optimistic",
451 | "sentimental",
452 | "optimistic",
453 | "Tough",
454 | "motivational",
455 | "Heartfelt",
456 | "Funky",
457 | "communication",
458 | "Danceable",
459 | "vivacious",
460 | "love",
461 | "commercial",
462 | "Vivacious",
463 | "heavy",
464 | "ballad",
465 | "thoughtful",
466 | "fast-paced",
467 | "Futuristic",
468 | "Joyful",
469 | "emotion",
470 | "Soulful",
471 | "attitude",
472 | "positive",
473 | "epic",
474 | "Festive",
475 | "Melodic",
476 | "Dancy",
477 | "Aggressive",
478 | "soft",
479 | "Calming",
480 | "exciting",
481 | "dreamy",
482 | "Epic",
483 | "nostalgic",
484 | "powerful",
485 | "adventure",
486 | "passionate",
487 | "Determined",
488 | "沟通",
489 | "Sensual",
490 | "Playful",
491 | "street",
492 | "heartfelt",
493 | "Rebellious",
494 | "intense",
495 | "Sentimental",
496 | "inspirational",
497 | "travel",
498 | "Adventurous",
499 | "atmospheric",
500 | "summer",
501 | "easygoing",
502 | "Cheerful",
503 | "Cool",
504 | "Dark",
505 | "rock",
506 | "Inspiration",
507 | "Chill",
508 | "Intense",
509 | "confident",
510 | "empowering",
511 | "Violent",
512 | "Intimate",
513 | "longing",
514 | " meditative",
515 | "Attitude",
516 | "romance",
517 | "experimental",
518 | "at sea",
519 | "放松",
520 | "chill",
521 | "Exciting",
522 | "Soothing",
523 | "Empowering",
524 | "暴力",
525 | "Brawny",
526 | "cheerful",
527 | "Motivational",
528 | "Vibraphone",
529 | "tough",
530 | "determined",
531 | "hardcore",
532 | "Reflective",
533 | "funny",
534 | "Peaceful",
535 | "loud",
536 | "Pensive",
537 | "向上",
538 | "playful",
539 | "Furious",
540 | "时尚",
541 | "希望",
542 | "rough",
543 | "Intimacy",
544 | "dance",
545 | "Vibrant",
546 | "Relaxed",
547 | "soundscape",
548 | "Brutal",
549 | "thought-provoking",
550 | "success",
551 | "sleepy",
552 | "Elegant",
553 | "children",
554 | "intimate",
555 | "残酷",
556 | "怀旧",
557 | "improvisational",
558 | "浪漫",
559 | "Ambient",
560 | "Affectionate",
561 | "Gory",
562 | "Dramatic",
563 | "enthusiastic",
564 | "感性",
565 | "ambient",
566 | "Gentle",
567 | "愤怒",
568 | "快乐",
569 | "黑暗",
570 | "brawny",
571 | "Seductive",
572 | "Dancing",
573 | "introspective",
574 | "instrumental",
575 | "Satisfied",
576 | "hard",
577 | "史诗",
578 | " documentary",
579 | " dreamy",
580 | "Lively",
581 | "child",
582 | "sassy",
583 | "dissonant",
584 | "Emotive",
585 | "electronic",
586 | "抒情",
587 | "meditative",
588 | "Gloomy",
589 | "groovy",
590 | " film",
591 | "adventure, emotion",
592 | "ambitious",
593 | "Spiritual",
594 | "christmas",
595 | "reminiscent",
596 | "saloon",
597 | "vintage",
598 | "梦幻",
599 | "爱",
600 | "fast_decay",
601 | "Comedy",
602 | "Asian",
603 | "侵略性",
604 | "Admirative",
605 | " communication",
606 | "忧郁"
607 | ],
608 | "gender": [
609 | "male",
610 | "female",
611 | "singing",
612 | "soprano",
613 | "child",
614 | "human",
615 | "human female voice",
616 | "unspecified",
617 | "screamo",
618 | "mezzo-soprano",
619 | "human voice",
620 | "not specified",
621 | "tenor",
622 | "rapping",
623 | "singing voice",
624 | "squeaky",
625 | "童声",
626 | "children"
627 | ],
628 | "timbre": [
629 | "bright vocal",
630 | "full vocal",
631 | "airy vocal",
632 | "clear vocal",
633 | "mellow vocal",
634 | "dark vocal",
635 | "rich vocal",
636 | "reverb vocal",
637 | "light vocal",
638 | "crisp vocal",
639 | "broad vocal",
640 | "powerful vocal",
641 | "piercing vocal",
642 | "high-pitched vocal",
643 | "bass vocal",
644 | "deep vocal",
645 | "not applicable vocal",
646 | "baritone vocal",
647 | "not specified vocal",
648 | "vibrant vocal",
649 | "boomy vocal",
650 | "varied vocal",
651 | "bouncy vocal",
652 | "range vocal",
653 | "harsh vocal",
654 | " airy vocal",
655 | "round vocal",
656 | "uplifting vocal",
657 | "soft vocal",
658 | "husky vocal",
659 | "tenor vocal",
660 | "pontificate vocal",
661 | "aggressive vocal",
662 | "neat vocal",
663 | "high vocal",
664 | "exuberant vocal",
665 | "open vocal",
666 | "full bodied vocal",
667 | "strong vocal",
668 | "grainy vocal",
669 | "vocal fry vocal",
670 | "gravelly vocal",
671 | "low vocal",
672 | "long_release vocal",
673 | "polished vocal",
674 | "velvet vocal",
675 | "placid vocal",
676 | "plastic vocal",
677 | "sharp vocal",
678 | "robust vocal",
679 | "muffled vocal",
680 | "distortion vocal",
681 | "crunchy vocal",
682 | "resonant vocal",
683 | "pure vocal",
684 | "年轻 vocal",
685 | "preenched vocal",
686 | "gruff vocal",
687 | "raspy vocal",
688 | "passionate vocal",
689 | "nonlinear_env vocal",
690 | "high pitched vocal",
691 | "athletic vocal",
692 | "reedy vocal",
693 | "shimmering vocal",
694 | "charismatic vocal",
695 | "gliding vocal",
696 | "raw vocal",
697 | "plucky vocal",
698 | "loud vocal",
699 | "youthful vocal",
700 | "thin vocal",
701 | "soulful vocal",
702 | "smooth vocal",
703 | "flat vocal",
704 | "tempo-synced vocal",
705 | "opulent vocal",
706 | "variable vocal",
707 | "happy vocal",
708 | "prettily vocal",
709 | "percussive vocal",
710 | "singing voice vocal",
711 | "barrel vocal",
712 | "breezy vocal",
713 | "vocal vocal",
714 | "honeyed vocal",
715 | "vivacious vocal",
716 | "full-bodied vocal",
717 | "persuasive vocal",
718 | "tender vocal",
719 | "potent vocal",
720 | "preppy vocal",
721 | " raspy vocal",
722 | "narrow vocal",
723 | "fruity vocal",
724 | "whiny vocal",
725 | "hollow vocal",
726 | "singing vocal",
727 | "rapping vocal",
728 | "flexible vocal",
729 | " alto vocal",
730 | "sweet vocal",
731 | "agitated vocal",
732 | "shaky vocal",
733 | "dainty vocal",
734 | "明亮 vocal",
735 | "soprano vocal",
736 | "vocal range vocal",
737 | "rough vocal",
738 | "有力 vocal",
739 | "成熟 vocal",
740 | "sultry vocal",
741 | "barren vocal",
742 | "bulky vocal",
743 | "prevalent vocal",
744 | "bellowing vocal",
745 | "dusty vocal",
746 | "elevated vocal",
747 | "wide vocal",
748 | "rumbly vocal",
749 | "shrill vocal",
750 | "prettily produced vocal",
751 | "projected vocal",
752 | "low pitched vocal",
753 | "bold vocal",
754 | "grassy vocal",
755 | "plush vocal",
756 | "glorious vocal",
757 | "elevated pitch vocal",
758 | "whispery vocal",
759 | "long vocal",
760 | "nasal vocal",
761 | "preened vocal",
762 | "squeaky vocal",
763 | "hellosing vocal",
764 | "commanding vocal",
765 | "textural vocal",
766 | "noble vocal",
767 | "frustrated vocal",
768 | "warm vocal",
769 | "punchy vocal",
770 | "pretty vocal",
771 | "changeable vocal",
772 | "mushy vocal",
773 | "vocalist vocal",
774 | "gritty vocal",
775 | "barking vocal",
776 | "human vocal",
777 | "bass heavy vocal",
778 | "dulcet vocal",
779 | " smooth vocal",
780 | "young vocal",
781 | "rhythmic vocal",
782 | "vocals vocal",
783 | "helmet vocal",
784 | "screamy vocal",
785 | "hoarse vocal",
786 | "rebellious vocal",
787 | "soothing vocal",
788 | "童声 vocal",
789 | "bitter vocal",
790 | "为了让声乐更加生动,使用了混响效果。 vocal",
791 | "barrel-shaped vocal",
792 | "reed vocal",
793 | "强有力 vocal",
794 | "低沉 vocal",
795 | "whimsical vocal",
796 | "exaggerated vocal",
797 | "温暖 vocal",
798 | "low-pitched vocal",
799 | "emotional vocal",
800 | "graceful vocal",
801 | "breakable vocal",
802 | "screechy vocal",
803 | "muddy vocal",
804 | "breathy vocal",
805 | "柔和 vocal",
806 | "weathered vocal",
807 | "roaring vocal",
808 | "青春 vocal",
809 | "pensive vocal",
810 | "textured vocal",
811 | "清脆 vocal",
812 | "melodic vocal",
813 | "helmeted vocal",
814 | " velvety vocal",
815 | "充满活力 vocal",
816 | "圆润 vocal",
817 | "preteen vocal",
818 | "rhythm vocal",
819 | "treble vocal",
820 | "shouty vocal",
821 | " husky vocal",
822 | "medium vocal",
823 | "blue vocal",
824 | "screeching vocal",
825 | "multiphonic vocal",
826 | "quaint vocal",
827 | "rhytmic vocal",
828 | "轻盈 vocal"
829 | ]
830 | }
--------------------------------------------------------------------------------
/code/finetune/core/datasets/gpt_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 |
3 | import logging
4 | import os
5 | import time
6 | from dataclasses import dataclass
7 | from typing import Dict, Tuple, Union
8 |
9 | import numpy
10 | import torch
11 |
12 | # from megatron import get_tokenizer
13 | from core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
14 | from core.datasets.indexed_dataset import MMapIndexedDataset
15 | from core.datasets.megatron_dataset import MegatronDataset
16 | from core.datasets.utils import Split, log_single_rank
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
@dataclass
class GPTDatasetConfig(BlendedMegatronDatasetConfig):
    """Configuration object for Megatron Core GPT datasets

    Extends BlendedMegatronDatasetConfig with the GPT-specific options below. Because this is a
    dataclass, the field order here determines the positional order of the generated __init__
    arguments (after the inherited fields).

    Attributes:
        return_document_ids (bool): Whether to return the document ids when querying the dataset.

        reset_position_ids (bool): Option to reset the position IDs in the dataset at an interval

        reset_attention_mask (bool): Option to reset the attention mask from the dataset

        eod_mask_loss (bool): Option to enable the EOD mask loss, i.e. zero the loss mask at
            positions holding the ``eod_id`` token

        eod_id (int): The token id treated as end of document by the loss mask and the
            position-id / attention-mask reset logic

        add_bos (bool): Whether to prepend one token to every sample: the BOS token id, or the
            EOD token id when the sample already starts with BOS

        enable_shuffle (bool): Whether to shuffle the document order when building the document
            index. The default (False) keeps the documents in their original order.
    """

    # NOTE(review): GPTDataset.__getitem__ in this file discards the document ids, so this flag
    # appears to have no effect there — confirm whether other callers read it.
    return_document_ids: bool = False
    reset_position_ids: bool = False
    reset_attention_mask: bool = False
    eod_mask_loss: bool = False
    # Token id compared against the input tokens when building the masks / position ids.
    eod_id: int = 0
    # Prepend a BOS (or EOD) token to every sample.
    add_bos: bool = False
    # Forwarded to the document-index builder; controls shuffling of the document order.
    enable_shuffle: bool = False
45 |
46 |
class GPTDataset(MegatronDataset):
    """The base GPT dataset

    Args:
        indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
            MegatronDataset

        indexed_indices (numpy.ndarray): The set of the documents indices to expose

        num_samples (int): The number of samples to draw from the indexed dataset

        index_split (Split): The indexed_indices Split

        config (GPTDatasetConfig): The GPT-specific container for all config sourced parameters

    Note:
        The tokenizer lookup that would set ``bos_id``/``eod_id`` is disabled in this file. When
        ``config.add_bos`` is enabled, assign a ``bos_id`` attribute on the dataset (or on its
        config) before querying samples; the EOD id falls back to ``config.eod_id``.
    """

    def __init__(
        self,
        indexed_dataset: MMapIndexedDataset,
        indexed_indices: numpy.ndarray,
        num_samples: int,
        index_split: Split,
        config: GPTDatasetConfig,
    ) -> None:
        super().__init__(indexed_dataset, indexed_indices, num_samples, index_split, config)
        # BOS/EOD ids are intentionally not read from a global tokenizer here; see
        # _get_bos_eod_ids for how they are resolved when config.add_bos is enabled.

    def _finalize(self) -> None:
        """Abstract method implementation

        Load or build/cache the document, sample, and shuffle indices
        """
        assert isinstance(self.config, GPTDatasetConfig)

        (
            self.document_index,
            self.sample_index,
            self.shuffle_index,
        ) = self._build_document_sample_shuffle_indices()

    def __len__(self) -> int:
        """Abstract method implementation

        Returns:
            int: The length of the dataset, i.e. the number of sample boundaries minus one
        """
        return self.sample_index.shape[0] - 1

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Abstract method implementation

        Args:
            idx (int): The index into the dataset

        Returns:
            Dict[str, torch.Tensor]: ``input_ids`` (all tokens but the last), ``labels`` (all
            tokens but the first), plus ``attention_mask``, ``loss_mask`` and ``position_ids``
            computed from ``input_ids``
        """
        text, _ = self._query_document_sample_shuffle_indices(idx)

        text = torch.from_numpy(text)

        tokens_ = text.long()
        # Next-token prediction: labels are the inputs shifted left by one position.
        labels = tokens_[1:].contiguous()
        tokens = tokens_[:-1].contiguous()

        attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
            tokens,
            self.config.eod_id,
            self.config.reset_position_ids,
            self.config.reset_attention_mask,
            self.config.eod_mask_loss,
        )

        return {
            "input_ids": tokens,
            "labels": labels,
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
            "position_ids": position_ids,
        }

    @staticmethod
    def is_multimodal() -> bool:
        """Abstract method implementation

        Returns:
            bool: False
        """
        return False

    @staticmethod
    def is_split_by_sequence() -> bool:
        """Abstract method implementation

        Returns:
            bool: True
        """
        return True

    def _get_bos_eod_ids(self) -> Tuple[int, int]:
        """Resolve the BOS and EOD token ids used when ``config.add_bos`` is enabled.

        The BOS id is taken from a ``bos_id`` attribute on the dataset, then on the config. The
        EOD id is taken from an ``eod_id`` attribute on the dataset, falling back to
        ``config.eod_id``.

        Returns:
            Tuple[int, int]: The BOS id and the EOD id

        Raises:
            ValueError: If no BOS token id is available
        """
        bos_id = getattr(self, "bos_id", None)
        if bos_id is None:
            bos_id = getattr(self.config, "bos_id", None)
        if bos_id is None:
            raise ValueError(
                "config.add_bos is enabled but no BOS token id is available; "
                "set `bos_id` on the dataset (or its config) before querying samples"
            )
        eod_id = getattr(self, "eod_id", None)
        if eod_id is None:
            eod_id = self.config.eod_id
        return bos_id, eod_id

    def _query_document_sample_shuffle_indices(
        self, idx: int
    ) -> Tuple[numpy.ndarray, numpy.ndarray]:
        """Get the text (token ids) and document ids for a given index

        Args:
            idx (int): The index into the dataset

        Returns:
            Tuple[numpy.ndarray, numpy.ndarray]: The text ids and document ids, both int64

        Raises:
            ValueError: If ``config.add_bos`` is enabled but no BOS token id is available
        """
        # Do the shuffle mapping
        idx = self.shuffle_index[idx]

        # Get the beginning and end documents and offsets
        doc_index_beg, doc_index_beg_offset = self.sample_index[idx]
        doc_index_end, doc_index_end_offset = self.sample_index[idx + 1]

        document_ids = []
        sample_parts = []

        # Sample spans a single document
        if doc_index_beg == doc_index_end:
            # Add the document id
            document_ids.append(self.document_index[doc_index_beg])

            # Add the entire sample; the end offset is inclusive, hence the + 1
            sample_parts.append(
                self.indexed_dataset.get(
                    self.document_index[doc_index_beg],
                    offset=doc_index_beg_offset,
                    length=doc_index_end_offset - doc_index_beg_offset + 1,
                )
            )

        # Sample spans multiple documents
        else:
            for i in range(doc_index_beg, doc_index_end + 1):
                # Add the document id
                document_ids.append(self.document_index[i])

                # Only the first document starts at a non-zero offset and only the last one is
                # cut short; length=None presumably reads to the end of the document.
                offset = 0 if i > doc_index_beg else doc_index_beg_offset
                length = None if i < doc_index_end else doc_index_end_offset + 1
                sample_parts.append(
                    self.indexed_dataset.get(self.document_index[i], offset=offset, length=length)
                )

        if getattr(self.config, "add_bos", False):
            bos_id, eod_id = self._get_bos_eod_ids()
            sample = sample_parts[0]
            # Prepend BOS, or EOD when the sample already begins with BOS.
            starts_with_bos = len(sample) > 0 and sample[0] == bos_id
            add_token = eod_id if starts_with_bos else bos_id
            sample_parts.insert(0, numpy.array([add_token], dtype=sample.dtype))

        return (
            numpy.array(numpy.concatenate(sample_parts), dtype=numpy.int64),
            numpy.array(document_ids, dtype=numpy.int64),
        )

    def _build_document_sample_shuffle_indices(
        self,
    ) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
        """Build the document index, the sample index, and the shuffle index

        The document index:
            -- 1-D
            -- The exposed document ids repeated once per epoch (shuffled when
               ``config.enable_shuffle`` is set)

        The sample index:
            -- 2-D
            -- The document indices and offsets which mark the start of every sample

        The shuffle index:
            -- 1-D
            -- A random permutation of index range of the sample index

        The indices are cached as .npy files under ``config.path_to_cache`` (or
        ``<path_prefix>/cache/<ClassName>_indices``), keyed by ``unique_description_hash``. On a
        cache miss, the builder process (rank 0, or the only process when torch.distributed is
        not initialized) builds and saves them. Otherwise they are loaded memory-mapped.

        Separating the final epoch (the 0.80 threshold): when more than one epoch is needed, the
        final epoch usually contributes only part of its samples. If that part is smaller than
        80% of one epoch's samples, the final epoch is indexed separately. Its documents are
        indexed on their own, and its samples are only permuted among themselves and placed
        after the samples of all earlier epochs in the shuffle index.

        Returns:
            Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample
            index, and the shuffle index
        """
        path_to_cache = self.config.path_to_cache
        if path_to_cache is None:
            path_to_cache = os.path.join(
                self.indexed_dataset.path_prefix, "cache", f"{type(self).__name__}_indices"
            )

        def get_path_to(suffix: str) -> str:
            return os.path.join(
                path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
            )

        path_to_description = get_path_to("description.txt")
        path_to_document_index = get_path_to("document_index.npy")
        path_to_sample_index = get_path_to("sample_index.npy")
        path_to_shuffle_index = get_path_to("shuffle_index.npy")
        cache_hit = all(
            map(
                os.path.isfile,
                [
                    path_to_description,
                    path_to_document_index,
                    path_to_sample_index,
                    path_to_shuffle_index,
                ],
            )
        )

        num_tokens_per_epoch = self._get_num_tokens_per_epoch()
        num_epochs = self._get_num_epochs(num_tokens_per_epoch)

        # A single (non-distributed) process must also be able to build the cache;
        # get_rank() raises when no process group has been initialized.
        if not cache_hit and (
            not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        ):
            log_single_rank(
                logger,
                logging.INFO,
                f"Build and save the {type(self).__name__} {self.index_split.name} indices",
            )

            sequence_length = self.config.sequence_length

            if num_epochs == 1:
                separate_final_epoch = False
            else:
                # Get the number of samples for the last epoch
                num_samples_sans_final_epoch = (
                    (num_epochs - 1) * num_tokens_per_epoch - 1
                ) // sequence_length
                num_samples_from_final_epoch = self.num_samples - num_samples_sans_final_epoch
                num_samples_per_epoch = (num_tokens_per_epoch - 1) // sequence_length

                # num_samples_from_final_epoch should be non-negative
                assert num_samples_from_final_epoch >= 0

                # num_samples_from_final_epoch should not exceed max value
                assert num_samples_from_final_epoch <= num_samples_per_epoch + 1

                # Separate the final epoch if it falls below the threshold
                threshold = 0.80
                separate_final_epoch = num_samples_from_final_epoch < int(
                    threshold * num_samples_per_epoch
                )

                log_single_rank(
                    logger,
                    logging.DEBUG,
                    f"> num_samples_from_final_epoch: {num_samples_from_final_epoch}",
                )
                log_single_rank(logger, logging.DEBUG, f"> threshold: {threshold}")
                log_single_rank(
                    logger, logging.DEBUG, f"> num_samples_per_epoch: {num_samples_per_epoch}"
                )

            log_single_rank(
                logger, logging.DEBUG, f"> separate_final_epoch: {separate_final_epoch}"
            )

            numpy_random_state = numpy.random.RandomState(self.config.random_seed)

            os.makedirs(path_to_cache, exist_ok=True)

            # Write the description
            with open(path_to_description, "wt") as writer:
                writer.write(self.unique_description)

            # Build the document index
            log_single_rank(
                logger,
                logging.INFO,
                f"\tBuild and save the document index to {os.path.basename(path_to_document_index)}",
            )
            t_beg = time.time()
            document_index = _build_document_index(
                self.indexed_indices,
                num_epochs,
                numpy_random_state,
                separate_final_epoch,
                self.config.enable_shuffle,
            )
            numpy.save(path_to_document_index, document_index, allow_pickle=True)
            t_end = time.time()
            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

            # Build the sample index
            log_single_rank(
                logger,
                logging.INFO,
                f"\tBuild and save the sample index to {os.path.basename(path_to_sample_index)}",
            )
            t_beg = time.time()
            from core.datasets import helpers

            # The compiled helper is fed int32 arrays; these asserts guard that contract.
            assert document_index.dtype == numpy.int32
            assert self.indexed_dataset.sequence_lengths.dtype == numpy.int32
            sample_index = helpers.build_sample_idx(
                self.indexed_dataset.sequence_lengths,
                document_index,
                sequence_length,
                num_epochs,
                num_tokens_per_epoch,
            )
            numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
            t_end = time.time()
            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

            # Build the shuffle index
            log_single_rank(
                logger,
                logging.INFO,
                f"\tBuild and save the shuffle index to {os.path.basename(path_to_shuffle_index)}",
            )
            t_beg = time.time()
            # NOTE(review): unlike the document index, the shuffle index is always shuffled
            # (enable_shuffle=True is hard-coded here) — confirm this is intended.
            if separate_final_epoch:
                shuffle_index = _build_shuffle_index(
                    num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state, True
                )
            else:
                shuffle_index = _build_shuffle_index(
                    sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state, True
                )
            numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True)
            t_end = time.time()
            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

            log_single_rank(
                logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
            )
            log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")

            return document_index, sample_index, shuffle_index

        # Cache hit, or a non-builder rank: load the cached indices. On a cache miss, non-builder
        # ranks presumably rely on synchronization elsewhere for these files to exist — not
        # visible here.
        log_single_rank(
            logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"
        )

        log_single_rank(
            logger,
            logging.INFO,
            f"\tLoad the document index from {os.path.basename(path_to_document_index)}",
        )
        t_beg = time.time()
        document_index = numpy.load(path_to_document_index, allow_pickle=True, mmap_mode='r')
        t_end = time.time()
        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

        log_single_rank(
            logger,
            logging.INFO,
            f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}",
        )
        t_beg = time.time()
        sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r')
        t_end = time.time()
        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

        log_single_rank(
            logger,
            logging.INFO,
            f"\tLoad the shuffle index from {os.path.basename(path_to_shuffle_index)}",
        )
        t_beg = time.time()
        shuffle_index = numpy.load(path_to_shuffle_index, allow_pickle=True, mmap_mode='r')
        t_end = time.time()
        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

        log_single_rank(
            logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
        )
        log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")

        return document_index, sample_index, shuffle_index

    def _get_num_tokens_per_epoch(self) -> int:
        """Calculate the number of tokens in a single epoch

        Returns:
            int: The summed sequence length of the exposed documents
        """
        return int(numpy.sum(self.indexed_dataset.sequence_lengths[self.indexed_indices]))

    def _get_num_epochs(self, num_tokens_per_epoch: int) -> int:
        """Calculate the number of epochs

        Returns the smallest number of passes (at least one) over the exposed documents whose
        tokens cover ``num_samples * sequence_length + 1`` tokens. The extra token is needed
        because labels are the inputs shifted by one.

        Args:
            num_tokens_per_epoch (int): The number of tokens in a single epoch

        Returns:
            int: The number of epochs

        Raises:
            ValueError: If ``num_tokens_per_epoch`` is not positive, since no number of epochs
                could then satisfy the request
        """
        if num_tokens_per_epoch <= 0:
            raise ValueError(
                f"Cannot build {type(self).__name__}: the exposed documents contain "
                f"{num_tokens_per_epoch} tokens per epoch"
            )
        num_tokens_requested = (self.num_samples * self.config.sequence_length) + 1
        # Ceiling division: smallest k with k * num_tokens_per_epoch >= num_tokens_requested.
        return max(1, -(-num_tokens_requested // num_tokens_per_epoch))
438 |
439 |
def _build_document_index(
    documents: numpy.ndarray,
    num_epochs: int,
    numpy_random_state: numpy.random.RandomState,
    separate_final_epoch: bool,
    enable_shuffle: bool = False,
) -> numpy.ndarray:
    """Lay out the exposed document ids once per epoch as a flat int32 array.

    Args:
        documents (numpy.ndarray): the subset of exposed document indices

        num_epochs (int): How many times the documents are repeated

        numpy_random_state (numpy.random.RandomState): The NumPy random state used for shuffling

        separate_final_epoch (bool): When True (and num_epochs > 1), the first num_epochs - 1
            epochs and the final epoch are laid out by independent calls and concatenated, so
            any shuffling never mixes documents of the final epoch with earlier epochs

        enable_shuffle (bool): Whether to shuffle each laid-out range in place. The default is
            False, which keeps the original document order

    Returns:
        numpy.ndarray: The document index, dtype int32
    """
    if separate_final_epoch and num_epochs != 1:
        # Leading epochs first, then the final one: this call order fixes how the random state
        # is consumed.
        leading = _build_document_index(
            documents, num_epochs - 1, numpy_random_state, False, enable_shuffle
        )
        trailing = _build_document_index(documents, 1, numpy_random_state, False, enable_shuffle)
        return numpy.concatenate((leading, trailing))

    # Repeat the whole document list num_epochs times back to back.
    laid_out = numpy.tile(documents, num_epochs).astype(numpy.int32)
    if not enable_shuffle:
        print("INFO: document_index shuffle is disabled...")
        return laid_out
    print("INFO: document_index shuffle is enabled...")
    numpy_random_state.shuffle(laid_out)
    return laid_out
480 |
481 |
def _build_shuffle_index(
    num_samples: int, total_size: int, numpy_random_state: numpy.random.RandomState,
    enable_shuffle: bool = False,
) -> numpy.ndarray:
    """Build the sample permutation as up to two independently shuffled ranges.

    Args:
        num_samples (int): End of the first range [0, num_samples)

        total_size (int): The size of the entire index. When larger than ``num_samples``, the
            second range [num_samples, total_size) is built and shuffled on its own and appended
            after the first, so its entries never move into the first range

        numpy_random_state (numpy.random.RandomState): The NumPy random state

        enable_shuffle (bool): Whether to shuffle each range in place

    Returns:
        numpy.ndarray: The shuffle index. Its dtype is uint32, or int64 when ``total_size``
        reaches the uint32 limit
    """
    index_dtype = (
        numpy.int64 if total_size >= (numpy.iinfo(numpy.uint32).max - 1) else numpy.uint32
    )

    def _span(start: int, stop: int) -> numpy.ndarray:
        # One contiguous range, shuffled in place when requested.
        values = numpy.arange(start=start, stop=stop, step=1, dtype=index_dtype)
        if enable_shuffle:
            print("INFO: shuffle_index shuffle is enabled...")
            numpy_random_state.shuffle(values)
        else:
            print("INFO: shuffle_index shuffle is disabled...")
        return values

    head = _span(0, num_samples)
    if num_samples == total_size:
        return head
    tail = _span(num_samples, total_size)
    return numpy.concatenate((head, tail))
523 |
524 |
525 | def _get_ltor_masks_and_position_ids(
526 | data: torch.Tensor,
527 | eod_token: int,
528 | reset_position_ids: bool,
529 | reset_attention_mask: bool,
530 | eod_mask_loss: bool,
531 | ):
532 | """Build masks and position id for left to right model.
533 |
534 | Args:
535 | data (torch.Tensor): The data tenor that holds the tokens from the dataset
536 |
537 | eod_token (int): ID of the token to that is considered the EOD
538 |
539 | reset_position_ids (bool): Switch to reset the document position ID's
540 |
541 | reset_attention_mask (bool): Switch to reset the attention mask
542 |
543 | eod_mask_loss (bool): Switch to enable the EOD mask loss
544 |
545 | Returns:
546 | torch.Tensor : Attention mask needed to be used for Attention
547 |
548 | torch.Tensor : The mask used for loss value during training
549 |
550 | torch.Tensor : The position ID's of the token
551 |
552 | """
553 |
554 | # Extract batch size and sequence length.
555 | seq_length = data.numel()
556 |
557 | attention_mask = torch.tril(torch.ones((seq_length, seq_length), device=data.device)).unsqueeze(
558 | 0
559 | )
560 |
561 | # Loss mask.
562 | loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
563 | if eod_mask_loss:
564 | loss_mask[data == eod_token] = 0.0
565 |
566 | # Position ids.
567 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
568 | # We need to clone as the ids will be modifed based on batch index.
569 | if reset_position_ids:
570 | position_ids = position_ids.clone()
571 |
572 | if reset_position_ids or reset_attention_mask:
573 |
574 | # Find indecies where EOD token is.
575 | eod_index = position_ids[data[b] == eod_token]
576 | # Detach indecies from positions if going to modify positions.
577 | if reset_position_ids:
578 | eod_index = eod_index.clone()
579 |
580 | # Loop through EOD indecies:
581 | prev_index = 0
582 | for j in range(eod_index.numel()):
583 | i = eod_index[j]
584 | # Mask attention loss.
585 | if reset_attention_mask:
586 | attention_mask[0, (i + 1) :, : (i + 1)] = 0
587 | # Reset positions.
588 | if reset_position_ids:
589 | position_ids[(i + 1) :] -= i + 1 - prev_index
590 | prev_index = i + 1
591 |
592 | # Convert attention mask to binary:
593 | attention_mask = attention_mask < 0.5
594 | attention_mask = attention_mask.float()
595 |
596 | return attention_mask, loss_mask, position_ids
--------------------------------------------------------------------------------
/code/finetune/core/datasets/indexed_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # Essentially re-written in entirety
7 |
8 | import logging
9 | import os
10 | import shutil
11 | import struct
12 | import time
13 | from enum import Enum
14 | from functools import lru_cache
15 | from itertools import accumulate
16 | from types import TracebackType
17 | from typing import List, Optional, Tuple, Type, Union
18 |
19 | import numpy
20 | import torch
21 |
22 | from core.datasets.utils import log_single_rank
23 |
logger = logging.getLogger(__name__)

# Magic prefix written at the start of every index (.idx) file: b"MMIDIDX" padded with two NUL
# bytes to 9 bytes, which is exactly how many bytes _IndexReader reads and compares.
_INDEX_HEADER = b"MMIDIDX" + b"\x00" * 2
27 |
28 |
class DType(Enum):
    """The NumPy data type Enum for writing/reading the MMapIndexedDataset indices

    Each member is named after a NumPy scalar type; its value is that type's integer code.
    """

    uint8 = 1
    int8 = 2
    int16 = 3
    int32 = 4
    int64 = 5
    float64 = 6
    float32 = 7
    uint16 = 8

    @classmethod
    def code_from_dtype(cls, value: Type[numpy.number]) -> int:
        """Get the code from the dtype

        Args:
            value (Type[numpy.number]): The dtype, e.g. numpy.int32

        Raises:
            KeyError: If the dtype has no code

        Returns:
            int: The code
        """
        return cls[value.__name__].value

    @classmethod
    def dtype_from_code(cls, value: int) -> Type[numpy.number]:
        """Get the dtype from the code

        Args:
            value (int): The code

        Raises:
            ValueError: If the code is unknown

        Returns:
            Type[numpy.number]: The dtype
        """
        return getattr(numpy, cls(value).name)

    @staticmethod
    def size(key: Union[int, Type[numpy.number]]) -> int:
        """Get the size of the dtype/code in bytes

        Args:
            key (Union[int, Type[numpy.number]]): The dtype, or an integer code (a Python int or
                a NumPy integer scalar)

        Raises:
            ValueError: If the key is neither dtype nor integer code

        Returns:
            int: The size of the dtype/code in bytes
        """
        if isinstance(key, (int, numpy.integer)):
            return DType.dtype_from_code(int(key))().itemsize
        # Check for a class first: instances (e.g. 3.5) have no __mro__, which previously
        # surfaced as AttributeError instead of the documented ValueError.
        if isinstance(key, type) and issubclass(key, numpy.number):
            return key().itemsize
        raise ValueError(f"Expected a DType code or a numpy.number subclass, got {key!r}")

    @staticmethod
    def optimal_dtype(cardinality: Optional[int]) -> Type[numpy.number]:
        """Get the dtype to use for an index of a certain cardinality

        Args:
            cardinality (Optional[int]): The number of elements to be indexed; None selects the
                wider type

        Returns:
            Type[numpy.number]: numpy.uint16 below 65500 elements (uint16 max is 65535),
                numpy.int32 otherwise
        """
        if cardinality is not None and cardinality < 65500:
            return numpy.uint16
        else:
            return numpy.int32
100 |
101 |
102 | class _IndexWriter(object):
103 | """Object class to write the index (.idx) file
104 |
105 | Args:
106 | idx_path (str): The path to the index file
107 |
108 | dtype (Type[numpy.number]): The dtype of the index file
109 | """
110 |
111 | def __init__(self, idx_path: str, dtype: Type[numpy.number]) -> None:
112 | self.idx_path = idx_path
113 | self.dtype = dtype
114 |
115 | def __enter__(self) -> "_IndexWriter":
116 | """Enter the context introduced by the 'with' keyword
117 |
118 | Returns:
119 | _IndexWriter: The instance
120 | """
121 | self.idx_writer = open(self.idx_path, "wb")
122 | # fixed, vestigial practice
123 | self.idx_writer.write(_INDEX_HEADER)
124 | # fixed, vestigial practice
125 | self.idx_writer.write(struct.pack(" Optional[bool]:
136 | """Exit the context introduced by the 'with' keyword
137 |
138 | Args:
139 | exc_type (Optional[Type[BaseException]]): Exception type
140 |
141 | exc_val (Optional[BaseException]): Exception value
142 |
143 | exc_tb (Optional[TracebackType]): Exception traceback object
144 |
145 | Returns:
146 | Optional[bool]: Whether to silence the exception
147 | """
148 | self.idx_writer.close()
149 |
150 | def write(
151 | self,
152 | sequence_lengths: List[int],
153 | sequence_modes: Optional[List[int]],
154 | document_indices: List[int],
155 | ) -> None:
156 | """Write the index (.idx) file
157 |
158 | Args:
159 | sequence_lengths (List[int]): The length of each sequence
160 |
161 | sequence_modes (Optional[List[int]]): The mode of each sequences
162 |
163 | document_indices (List[int]): The seqyebce indices demarcating the end of each document
164 | """
165 | sequence_pointers = self._sequence_pointers(sequence_lengths)
166 |
167 | # the number of sequences in the dataset
168 | sequence_count = len(sequence_lengths)
169 | self.idx_writer.write(struct.pack(" List[int]:
196 | """Build the sequence pointers per the sequence lengths and dtype size
197 |
198 | Args:
199 | sequence_lengths (List[int]): The length of each sequence
200 |
201 | Returns:
202 | List[int]: The pointer to the beginning of each sequence
203 | """
204 | itemsize = DType.size(self.dtype)
205 | curr_ptr = 0
206 | list_ptr = []
207 | for length in sequence_lengths:
208 | list_ptr.append(curr_ptr)
209 | curr_ptr += length * itemsize
210 | return list_ptr
211 |
212 |
213 | class _IndexReader(object):
214 | """Object class to read the index (.idx) file
215 |
216 | Args:
217 | idx_path (str): The path to the index file
218 |
219 | multimodal (bool): Whether the dataset is multimodal
220 | """
221 |
222 | def __init__(self, idx_path: str, multimodal: bool) -> None:
223 |
224 | log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} from {idx_path}")
225 |
226 | with open(idx_path, "rb") as stream:
227 | header = stream.read(9)
228 | assert header == _INDEX_HEADER, f"bad header, cannot read: {idx_path}"
229 |
230 | version = struct.unpack(" time elapsed: {t_end - t_beg:4f} seconds")
252 |
253 | log_single_rank(logger, logging.INFO, f"\tExtract the sequence pointers")
254 | t_beg = time.time()
255 | self.sequence_pointers = numpy.frombuffer(
256 | self.bin_buffer,
257 | dtype=numpy.int64,
258 | count=self.sequence_count,
259 | offset=offset + self.sequence_lengths.nbytes,
260 | )
261 | t_end = time.time()
262 | log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
263 |
264 | log_single_rank(logger, logging.INFO, f"\tExtract the document indices")
265 | t_beg = time.time()
266 | self.document_indices = numpy.frombuffer(
267 | self.bin_buffer,
268 | dtype=numpy.int64,
269 | count=self.document_count,
270 | offset=offset + self.sequence_lengths.nbytes + self.sequence_pointers.nbytes,
271 | )
272 | t_end = time.time()
273 | log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
274 |
275 | self.sequence_modes = None
276 | if multimodal:
277 | log_single_rank(logger, logging.INFO, f"\tExtract the sequence modes")
278 | t_beg = time.time()
279 | self.sequence_modes = numpy.frombuffer(
280 | self.bin_buffer,
281 | dtype=numpy.int8,
282 | count=self.sequence_count,
283 | offset=offset
284 | + self.sequence_lengths.nbytes
285 | + self.sequence_pointers.nbytes
286 | + self.document_indices.nbytes,
287 | )
288 | t_end = time.time()
289 | log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
290 |
291 | assert self.sequence_lengths.shape[0] == len(self)
292 | assert self.sequence_lengths.shape[0] == self.sequence_count
293 | assert self.sequence_lengths.shape[0] == self.document_indices[-1]
294 |
295 | log_single_rank(logger, logging.INFO, f"> total number of sequences: {len(self)}")
296 | log_single_rank(
297 | logger,
298 | logging.INFO,
299 | f"> total number of documents: {self.document_indices.shape[0] - 1}",
300 | )
301 |
302 | def __del__(self) -> None:
303 | """Clean up the object
304 | """
305 | self.bin_buffer_mmap._mmap.close()
306 | del self.bin_buffer_mmap
307 |
308 | def __len__(self) -> int:
309 | """Return the length of the dataset
310 |
311 | Returns:
312 | int: The length of the dataset
313 | """
314 | return self.sequence_count
315 |
316 | @lru_cache(maxsize=8)
317 | def __getitem__(self, idx: int) -> Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]:
318 | """Return the pointer, length, and mode at the index
319 |
320 | Args:
321 | idx (int): The index into the dataset
322 |
323 | Returns:
324 | Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: The pointer, length and mode at
325 | the index
326 | """
327 | return (
328 | self.sequence_pointers[idx],
329 | self.sequence_lengths[idx],
330 | self.sequence_modes[idx] if self.sequence_modes is not None else None,
331 | )
332 |
333 |
class MMapIndexedDataset(torch.utils.data.Dataset):
    """The low-level interface dataset class

    Token data is read from the data (.bin) file through a read-only numpy.memmap. Sequence
    lengths, byte pointers, document boundaries and (optionally) modes come from the index
    (.idx) file via _IndexReader.

    Args:
        path_prefix (str): The index (.idx) and data (.bin) prefix

        multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False.
    """

    def __init__(self, path_prefix: str, multimodal: bool = False) -> None:
        super().__init__()
        self.path_prefix = None
        self.multimodal = None

        self.index = None
        self.bin_buffer = None
        self.bin_buffer_mmap = None

        self.initialize(path_prefix, multimodal)

    def initialize(self, path_prefix: str, multimodal: bool) -> None:
        """Initialize the dataset

        This method is called by MMapIndexedDataset.__init__ during object creation and by
        MMapIndexedDataset.__setstate__ during un-pickling

        Args:
            path_prefix (str): The index (.idx) and data (.bin) prefix

            multimodal (bool): Whether the dataset is multimodal
        """
        self.path_prefix = path_prefix
        self.multimodal = multimodal
        self.index = _IndexReader(get_idx_path(self.path_prefix), self.multimodal)
        self.bin_buffer_mmap = numpy.memmap(get_bin_path(self.path_prefix), mode="r", order="C")
        self.bin_buffer = memoryview(self.bin_buffer_mmap)

    def __getstate__(self) -> Tuple[str, bool]:
        """Get the state during pickling

        Only the prefix and the multimodal flag are pickled; __setstate__ re-opens the files.

        Returns:
            Tuple[str, bool]: The state tuple
        """
        return self.path_prefix, self.multimodal

    def __setstate__(self, state: Tuple[str, bool]) -> None:
        """Set the state during un-pickling

        Args:
            state (Tuple[str, bool]): The state tuple
        """
        path_prefix, multimodal = state
        self.initialize(path_prefix, multimodal)

    def __del__(self) -> None:
        """Clean up the object: close the data-file memory map and drop the index
        """
        # __setstate__ bypasses __init__. If initialize() raised while un-pickling, the
        # attribute may not exist at all, so look it up defensively.
        if getattr(self, "bin_buffer_mmap", None) is not None:
            self.bin_buffer_mmap._mmap.close()
            del self.bin_buffer_mmap
            del self.index

    def __len__(self) -> int:
        """Return the length of the dataset i.e. the number of sequences in the index

        Returns:
            int: The length of the dataset
        """
        return len(self.index)

    def __getitem__(
        self, idx: Union[int, numpy.integer, slice]
    ) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
        """Return from the dataset

        Args:
            idx (Union[int, numpy.integer, slice]): The index or index slice into the dataset

        Raises:
            ValueError: When the index slice is non-contiguous

            TypeError: When the index is of an unexpected type

        Returns:
            Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and
                modes at the index or index slice. A slice yields a list with one array per
                sequence (an empty list for an empty slice).
        """
        if isinstance(idx, (int, numpy.integer)):
            sequence_pointer, sequence_length, sequence_mode = self.index[idx]
            sequence = numpy.frombuffer(
                self.bin_buffer,
                dtype=self.index.dtype,
                count=sequence_length,
                offset=sequence_pointer,
            )
            return (sequence, sequence_mode) if sequence_mode is not None else sequence
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            sequence_lengths = self.index.sequence_lengths[idx]
            sequence_modes = self.index.sequence_modes[idx] if self.multimodal else None
            if start >= stop:
                # Empty slice: numpy.split would otherwise yield one empty array rather than no
                # sequences, and sequence_pointers[start] is out of range when start == len(self).
                sequences = []
            else:
                # Accumulate in int64 so the running token total of a large slice cannot wrap
                # around in the (possibly narrower) integer dtype of sequence_lengths.
                sequence_ends = numpy.cumsum(sequence_lengths, dtype=numpy.int64)
                tokens = numpy.frombuffer(
                    self.bin_buffer,
                    dtype=self.index.dtype,
                    count=int(sequence_ends[-1]),
                    offset=int(self.index.sequence_pointers[start]),
                )
                # Split at every sequence end except the last one.
                sequences = numpy.split(tokens, sequence_ends[:-1].tolist())
            return (sequences, sequence_modes) if sequence_modes is not None else sequences
        else:
            raise TypeError("Unexpected type received for idx: {}".format(type(idx)))

    def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy.ndarray:
        """Retrieve a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.

        Args:
            idx (int): The index of the sequence

            offset (int, optional): Number of leading tokens to skip. Defaults to 0.

            length (Optional[int], optional): Number of tokens to read. Defaults to the rest of
                the sequence after 'offset'.

        Returns:
            numpy.ndarray: The tokens, or a (tokens, mode) tuple when the index has modes
        """
        # NOTE(review): offset + length is not checked against the sequence length.
        sequence_pointer, sequence_length, sequence_mode = self.index[idx]
        if length is None:
            length = sequence_length - offset
        sequence_pointer += offset * DType.size(self.index.dtype)
        sequence = numpy.frombuffer(
            self.bin_buffer, dtype=self.index.dtype, count=length, offset=sequence_pointer
        )
        return (sequence, sequence_mode) if sequence_mode is not None else sequence

    @property
    def sequence_lengths(self) -> numpy.ndarray:
        """Get the sequence lengths

        Returns:
            numpy.ndarray: The sequence lengths
        """
        return self.index.sequence_lengths

    @property
    def document_indices(self) -> numpy.ndarray:
        """Get the document indices

        Returns:
            numpy.ndarray: The document indices
        """
        return self.index.document_indices

    def get_document_indices(self) -> numpy.ndarray:
        """Get the document indices

        This method is slated for deprecation.

        Returns:
            numpy.ndarray: The document indices
        """
        return self.index.document_indices

    def set_document_indices(self, document_indices: numpy.ndarray) -> None:
        """Set the document indices

        This method is slated for deprecation.

        Args:
            document_indices (numpy.ndarray): The document indices
        """
        self.index.document_indices = document_indices

    @property
    def sequence_modes(self) -> numpy.ndarray:
        """Get the sequence modes

        Returns:
            numpy.ndarray: The sequence modes
        """
        return self.index.sequence_modes

    @staticmethod
    def exists(path_prefix: str) -> bool:
        """Return whether the MMapIndexedDataset exists on disk at the prefix

        Args:
            path_prefix (str): The prefix to the index (.idx) and data (.bin) files

        Returns:
            bool: Whether the MMapIndexedDataset exists on disk at the prefix
        """
        return os.path.exists(get_idx_path(path_prefix)) and os.path.exists(
            get_bin_path(path_prefix)
        )
525 |
526 |
class MMapIndexedDatasetBuilder(object):
    """Builder class for the MMapIndexedDataset class

    Tokens are written to the data (.bin) file as they are added. Per-sequence lengths, modes
    and document boundaries are kept in memory until finalize() writes the index (.idx) file.

    Args:
        bin_path (str): The path to the data (.bin) file

        dtype (Type[numpy.number], optional): The dtype tokens are stored as; it is also passed
            to the index writer. Defaults to numpy.int32.

        multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False.
    """

    def __init__(
        self, bin_path: str, dtype: Type[numpy.number] = numpy.int32, multimodal: bool = False
    ) -> None:
        self.data_file = open(bin_path, "wb")
        self.dtype = dtype
        self.multimodal = multimodal

        # Number of tokens in each sequence, in insertion order.
        self.sequence_lengths = []
        # Sequence count at each document boundary; always starts with 0.
        self.document_indices = [0]
        # Per-sequence mode, tracked only for multimodal datasets.
        self.sequence_modes = [] if self.multimodal else None

    def add_item(self, tensor: torch.Tensor, mode: int = 0) -> None:
        """Add a single item to the dataset

        Args:
            tensor (torch.Tensor): The item to add to the data file

            mode (int, optional): The mode for the item. Defaults to 0.
        """
        np_array = numpy.array(tensor.numpy(), dtype=self.dtype)
        self.data_file.write(np_array.tobytes(order="C"))
        self.sequence_lengths.append(np_array.size)
        if self.multimodal:
            self.sequence_modes.append(mode)

    def add_document(
        self, tensor: torch.Tensor, lengths: List[int], modes: Optional[List[int]] = None
    ) -> None:
        """Add an entire document to the dataset

        Args:
            tensor (torch.Tensor): The document to add

            lengths (List[int]): The lengths of each item in the document

            modes (Optional[List[int]], optional): The modes for each item in the document.
                Defaults to None, meaning mode 0 for every item.
        """
        np_array = numpy.array(tensor, dtype=self.dtype)
        self.data_file.write(np_array.tobytes(order="C"))
        self.sequence_lengths.extend(lengths)
        self.document_indices.append(len(self.sequence_lengths))
        if self.multimodal:
            # One default mode per item: the count comes from len(lengths), not the list itself.
            self.sequence_modes.extend(modes if modes is not None else [0] * len(lengths))

    def end_document(self) -> None:
        """Finalize the document, for use with MMapIndexedDatasetBuilder.add_item
        """
        self.document_indices.append(len(self.sequence_lengths))

    def add_index(self, path_prefix: str) -> None:
        """Add an entire MMapIndexedDataset to the dataset

        The other dataset's index entries are appended with document boundaries shifted by the
        current sequence count, and its data (.bin) file is copied byte-for-byte.

        Args:
            path_prefix (str): The index (.idx) and data (.bin) prefix
        """
        # Concatenate index
        index = _IndexReader(get_idx_path(path_prefix), multimodal=self.multimodal)
        assert index.dtype == self.dtype

        offset = len(self.sequence_lengths)
        self.sequence_lengths.extend(index.sequence_lengths)
        # Skip the leading 0 boundary of the appended index; it coincides with our last one.
        self.document_indices.extend((offset + index.document_indices)[1:])

        if self.multimodal:
            self.sequence_modes.extend(index.sequence_modes)

        # Concatenate data
        with open(get_bin_path(path_prefix), "rb") as f:
            shutil.copyfileobj(f, self.data_file)

    def finalize(self, idx_path: str) -> None:
        """Clean up and write the index (.idx) file

        Args:
            idx_path (str): The path to the index file
        """
        self.data_file.close()
        with _IndexWriter(idx_path, self.dtype) as writer:
            writer.write(self.sequence_lengths, self.sequence_modes, self.document_indices)
616 |
617 |
def get_idx_path(path_prefix: str) -> str:
    """Build the index-file path for a dataset prefix

    Args:
        path_prefix (str): Shared prefix of the dataset's .idx and .bin files

    Returns:
        str: 'path_prefix' with the ".idx" extension appended
    """
    return "".join((path_prefix, ".idx"))
628 |
629 |
def get_bin_path(path_prefix: str) -> str:
    """Build the data-file path for a dataset prefix

    Args:
        path_prefix (str): Shared prefix of the dataset's .idx and .bin files

    Returns:
        str: 'path_prefix' with the ".bin" extension appended
    """
    return "".join((path_prefix, ".bin"))
640 |
--------------------------------------------------------------------------------
/code/inference/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
4 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
5 | import re
6 | import random
7 | import uuid
8 | import copy
9 | from tqdm import tqdm
10 | from collections import Counter
11 | import argparse
12 | import numpy as np
13 | import torch
14 | import torchaudio
15 | from torchaudio.transforms import Resample
16 | import soundfile as sf
17 | from einops import rearrange
18 | from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
19 | from omegaconf import OmegaConf
20 | from codecmanipulator import CodecManipulator
21 | from mmtokenizer import _MMSentencePieceTokenizer
22 | from models.soundstream_hubert_new import SoundStream
23 | from vocoder import build_codec_model, process_audio
24 | from post_process_audio import replace_low_freq_with_energy_matched
25 |
26 |
# Command-line interface: parse arguments, validate prompt options and create output dirs.
parser = argparse.ArgumentParser()
# Model Configuration:
parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
parser.add_argument("--stage2_model", type=str, default="m-a-p/YuE-s2-1B-general", help="The model checkpoint path or identifier for the Stage 2 model.")
parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
parser.add_argument("--repetition_penalty", type=float, default=1.1, help="repetition_penalty ranges from 1.0 to 2.0 (or higher in some cases). It controls the diversity and coherence of the audio tokens generated. The higher the value, the greater the discouragement of repetition. Setting value to 1.0 means no penalty.")
parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during the generation.")
parser.add_argument("--stage2_batch_size", type=int, default=4, help="The batch size used in Stage 2 inference.")
# Prompt
parser.add_argument("--genre_txt", type=str, required=True, help="The file path to a text file containing genre tags that describe the musical style or characteristics (e.g., instrumental, genre, mood, vocal timbre, vocal gender). This is used as part of the generation prompt.")
parser.add_argument("--lyrics_txt", type=str, required=True, help="The file path to a text file containing the lyrics for the music generation. These lyrics will be processed and split into structured segments to guide the generation process.")
parser.add_argument("--use_audio_prompt", action="store_true", help="If set, the model will use an audio file as a prompt during generation. The audio file should be specified using --audio_prompt_path.")
parser.add_argument("--audio_prompt_path", type=str, default="", help="The file path to an audio file to use as a reference prompt when --use_audio_prompt is enabled.")
parser.add_argument("--prompt_start_time", type=float, default=0.0, help="The start time in seconds to extract the audio prompt from the given audio file.")
parser.add_argument("--prompt_end_time", type=float, default=30.0, help="The end time in seconds to extract the audio prompt from the given audio file.")
parser.add_argument("--use_dual_tracks_prompt", action="store_true", help="If set, the model will use dual tracks as a prompt during generation. The vocal and instrumental files should be specified using --vocal_track_prompt_path and --instrumental_track_prompt_path.")
parser.add_argument("--vocal_track_prompt_path", type=str, default="", help="The file path to a vocal track file to use as a reference prompt when --use_dual_tracks_prompt is enabled.")
parser.add_argument("--instrumental_track_prompt_path", type=str, default="", help="The file path to an instrumental track file to use as a reference prompt when --use_dual_tracks_prompt is enabled.")
# Output
parser.add_argument("--output_dir", type=str, default="./output", help="The directory where generated outputs will be saved.")
parser.add_argument("--keep_intermediate", action="store_true", help="If set, intermediate outputs will be saved during processing.")
parser.add_argument("--disable_offload_model", action="store_true", help="If set, the model will not be offloaded from the GPU to CPU after Stage 1 inference.")
parser.add_argument("--cuda_idx", type=int, default=0, help="The index of the CUDA device to use when a GPU is available.")
parser.add_argument("--seed", type=int, default=42, help="An integer value to reproduce generation.")
# Config for xcodec and upsampler
parser.add_argument('--basic_model_config', default='./xcodec_mini_infer/final_ckpt/config.yaml', help='YAML files for xcodec configurations.')
parser.add_argument('--resume_path', default='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', help='Path to the xcodec checkpoint.')
parser.add_argument('--config_path', type=str, default='./xcodec_mini_infer/decoders/config.yaml', help='Path to Vocos config file.')
parser.add_argument('--vocal_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
parser.add_argument('--inst_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')


args = parser.parse_args()
if args.use_audio_prompt and not args.audio_prompt_path:
    raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
# Both the vocal and the instrumental track are loaded when dual-track prompting is enabled,
# so a missing path for either one is an error (not only when both are missing).
if args.use_dual_tracks_prompt and not (args.vocal_track_prompt_path and args.instrumental_track_prompt_path):
    raise FileNotFoundError("Please offer dual tracks prompt filepath using '--vocal_track_prompt_path' and '--instrumental_track_prompt_path', when you enable '--use_dual_tracks_prompt'!")
stage1_model = args.stage1_model
stage2_model = args.stage2_model
cuda_idx = args.cuda_idx
max_new_tokens = args.max_new_tokens
stage1_output_dir = os.path.join(args.output_dir, "stage1")
# Join from output_dir directly: str.replace('stage1', 'stage2') on the stage-1 path would also
# rewrite any "stage1" that appears inside output_dir itself.
stage2_output_dir = os.path.join(args.output_dir, "stage2")
os.makedirs(stage1_output_dir, exist_ok=True)
os.makedirs(stage2_output_dir, exist_ok=True)
def seed_everything(seed=42):
    """Seed every random number generator used by the script and make cuDNN deterministic.

    Args:
        seed (int): Seed applied to Python's `random`, NumPy, PyTorch (CPU) and all CUDA
            devices. Defaults to 42.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Request deterministic cuDNN algorithms and disable benchmark-based kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
seed_everything(args.seed)
# load tokenizer and model
device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
model = AutoModelForCausalLM.from_pretrained(
    stage1_model,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # To enable flashattn, you have to install flash-attn
)
# Move the Stage 1 model to the selected device (GPU if available) and switch to eval mode.
model.to(device)
model.eval()

if torch.__version__ >= "2.0.0":
    model = torch.compile(model)

codectool = CodecManipulator("xcodec", 0, 1)
codectool_stage2 = CodecManipulator("xcodec", 0, 8)
model_config = OmegaConf.load(args.basic_model_config)
# NOTE(review): eval() executes the generator name from a user-selectable YAML file as Python
# code; only point --basic_model_config at trusted configs.
codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects; only load
# trusted checkpoints via --resume_path.
parameter_dict = torch.load(args.resume_path, map_location='cpu', weights_only=False)
# load_state_dict copies into the existing parameters, which already live on `device`,
# so a second .to(device) is not needed.
codec_model.load_state_dict(parameter_dict['codec_model'])
codec_model.eval()
105 |
class BlockTokenRangeProcessor(LogitsProcessor):
    """Logits processor that forbids generating any token id in [start_id, end_id).

    Args:
        start_id (int): First blocked token id (inclusive); expected to be non-negative.
        end_id (int): End of the blocked range (exclusive). Ids beyond the vocabulary are
            ignored.
    """

    def __init__(self, start_id, end_id):
        # Explicit list kept for backward compatibility with code that inspects it.
        self.blocked_token_ids = list(range(start_id, end_id))
        self.start_id = start_id
        self.end_id = end_id

    def __call__(self, input_ids, scores):
        # A basic slice writes the range in place. Indexing with the Python id list would
        # rebuild an index tensor on every decoding step.
        scores[:, self.start_id:self.end_id] = -float("inf")
        return scores
113 |
def load_audio_mono(filepath, sampling_rate=16000):
    """Load an audio file as a single-channel waveform at `sampling_rate` Hz.

    Args:
        filepath (str): Path of the audio file, read with `torchaudio.load`.
        sampling_rate (int): Desired sample rate in Hz. Defaults to 16000.

    Returns:
        torch.Tensor: The waveform averaged over dim 0 (kept as a size-1 axis), resampled
            only when the file's native rate differs from `sampling_rate`.
    """
    waveform, native_sr = torchaudio.load(filepath)
    # Down-mix by averaging the channel axis, keeping it as a singleton dimension.
    mono = waveform.mean(dim=0, keepdim=True)
    if native_sr == sampling_rate:
        return mono
    return Resample(orig_freq=native_sr, new_freq=sampling_rate)(mono)
123 |
def encode_audio(codec_model, audio_prompt, device, target_bw=0.5):
    """Encode a waveform into discrete codec codes without modifying the caller's tensor.

    Args:
        codec_model: Model exposing `encode(audio, target_bw=...)`.
        audio_prompt (torch.Tensor): Waveform. If it has fewer than 3 dims, one leading axis is
            added before encoding.
        device: Device the waveform is moved to before encoding.
        target_bw (float): Target bandwidth forwarded to `codec_model.encode`. Defaults to 0.5.

    Returns:
        numpy.ndarray: The encoder output with axes 0 and 1 swapped, cast to int16.
    """
    if audio_prompt.dim() < 3:
        # Out-of-place: the former in-place unsqueeze_ changed the shape of the caller's tensor.
        audio_prompt = audio_prompt.unsqueeze(0)
    with torch.no_grad():
        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=target_bw)
    raw_codes = raw_codes.transpose(0, 1)
    return raw_codes.cpu().numpy().astype(np.int16)
132 |
def split_lyrics(lyrics):
    """Split raw lyrics into sections headed by a bracketed tag such as [verse].

    A section begins at a tag made only of word characters (letters, digits, underscore) and
    extends up to the next '[' or the end of the text. Text before the first tag is dropped.

    Args:
        lyrics (str): The full lyrics text.

    Returns:
        list[str]: One string per section. Each is the tag in brackets, a newline, the
            whitespace-stripped body, then two newlines.
    """
    section_re = re.compile(r"\[(\w+)\](.*?)(?=\[|\Z)", re.DOTALL)
    sections = []
    for match in section_re.finditer(lyrics):
        tag, body = match.groups()
        sections.append(f"[{tag}]\n{body.strip()}\n\n")
    return sections
138 |
# Read the genre tags and lyrics and build the per-segment text prompts.
stage1_output_set = []
# Tips:
# genre tags support instrumental, genre, mood, vocal timbre and vocal gender
# all kinds of tags are needed
# Read as UTF-8 explicitly so non-ASCII tags/lyrics do not depend on the platform's locale.
with open(args.genre_txt, encoding="utf-8") as f:
    genres = f.read().strip()
with open(args.lyrics_txt, encoding="utf-8") as f:
    lyrics = split_lyrics(f.read())
# instruction
full_lyrics = "\n".join(lyrics)
prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
prompt_texts += lyrics


random_id = uuid.uuid4()
output_seq = None
# Here is suggested decoding config
top_p = 0.93
temperature = 1.0
repetition_penalty = args.repetition_penalty
# special tokens
start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
# prompt_texts[0] is the header prompt (the generation loop skips index 0), followed by one
# entry per lyrics segment. Bound by len(prompt_texts), not len(lyrics), so the final segment
# is not dropped when --run_n_segments >= the number of segments.
run_n_segments = min(args.run_n_segments + 1, len(prompt_texts))
165 | for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Stage1 inference...")):
166 | section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
167 | guidance_scale = 1.5 if i <=1 else 1.2
168 | if i==0:
169 | continue
170 | if i==1:
171 | if args.use_dual_tracks_prompt or args.use_audio_prompt:
172 | if args.use_dual_tracks_prompt:
173 | vocals_ids = load_audio_mono(args.vocal_track_prompt_path)
174 | instrumental_ids = load_audio_mono(args.instrumental_track_prompt_path)
175 | vocals_ids = encode_audio(codec_model, vocals_ids, device, target_bw=0.5)
176 | instrumental_ids = encode_audio(codec_model, instrumental_ids, device, target_bw=0.5)
177 | vocals_ids = codectool.npy2ids(vocals_ids[0])
178 | instrumental_ids = codectool.npy2ids(instrumental_ids[0])
179 | ids_segment_interleaved = rearrange([np.array(vocals_ids), np.array(instrumental_ids)], 'b n -> (n b)')
180 | audio_prompt_codec = ids_segment_interleaved[int(args.prompt_start_time*50*2): int(args.prompt_end_time*50*2)]
181 | audio_prompt_codec = audio_prompt_codec.tolist()
182 | elif args.use_audio_prompt:
183 | audio_prompt = load_audio_mono(args.audio_prompt_path)
184 | raw_codes = encode_audio(codec_model, audio_prompt, device, target_bw=0.5)
185 | # Format audio prompt
186 | code_ids = codectool.npy2ids(raw_codes[0])
187 | audio_prompt_codec = code_ids[int(args.prompt_start_time *50): int(args.prompt_end_time *50)] # 50 is tps of xcodec
188 | audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
189 | sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
190 | head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
191 | else:
192 | head_id = mmtokenizer.tokenize(prompt_texts[0])
193 | prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
194 | else:
195 | prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
196 |
197 | prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
198 | input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
199 | # Use window slicing in case output sequence exceeds the context of model
200 | max_context = 16384-max_new_tokens-1
201 | if input_ids.shape[-1] > max_context:
202 | print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
203 | input_ids = input_ids[:, -(max_context):]
204 | with torch.no_grad():
205 | output_seq = model.generate(
206 | input_ids=input_ids,
207 | max_new_tokens=max_new_tokens,
208 | min_new_tokens=100,
209 | do_sample=True,
210 | top_p=top_p,
211 | temperature=temperature,
212 | repetition_penalty=repetition_penalty,
213 | eos_token_id=mmtokenizer.eoa,
214 | pad_token_id=mmtokenizer.eoa,
215 | logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
216 | guidance_scale=guidance_scale,
217 | )
218 | if output_seq[0][-1].item() != mmtokenizer.eoa:
219 | tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
220 | output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
221 | if i > 1:
222 | raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
223 | else:
224 | raw_output = output_seq
225 |
# Save the raw stage-1 output and check that every audio span is closed.
ids = raw_output[0].cpu().numpy()
soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
if len(soa_idx)!=len(eoa_idx):
    raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')

vocals = []
instrumentals = []
# When an audio prompt was used, the first soa/eoa span is the reference
# audio, not generated music, so it is skipped.
range_begin = 1 if args.use_audio_prompt or args.use_dual_tracks_prompt else 0
for i in range(range_begin, len(soa_idx)):
    codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
    # Drop a leading 32016 token (the separator appended after soa in the
    # prompts above, presumably codectool.sep_ids). Guard against empty spans.
    if codec_ids.size > 0 and codec_ids[0] == 32016:
        codec_ids = codec_ids[1:]
    # Tokens interleave vocal/instrumental frames, so keep an even count.
    codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
    pair_ids = rearrange(codec_ids, "(n b) -> b n", b=2)
    vocals.append(codectool.ids2npy(pair_ids[0]))
    instrumentals.append(codectool.ids2npy(pair_ids[1]))
if not vocals:
    raise ValueError('stage 1 produced no audio segments to save')
vocals = np.concatenate(vocals, axis=1)
instrumentals = np.concatenate(instrumentals, axis=1)
save_stem = f"{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_{random_id}"
vocal_save_path = os.path.join(stage1_output_dir, f"{save_stem}_vtrack".replace('.', '@')+'.npy')
inst_save_path = os.path.join(stage1_output_dir, f"{save_stem}_itrack".replace('.', '@')+'.npy')
np.save(vocal_save_path, vocals)
np.save(inst_save_path, instrumentals)
stage1_output_set.append(vocal_save_path)
stage1_output_set.append(inst_save_path)
254 |
# offload model
# Move the stage-1 model off the GPU and release cached CUDA memory before the
# stage-2 model is loaded, unless args.disable_offload_model is set.
if not args.disable_offload_model:
    model.cpu()
    del model
    torch.cuda.empty_cache()

print("Stage 2 inference...")
# Stage-2 model in bfloat16 with FlashAttention-2, moved to `device` and put
# into eval mode.
model_stage2 = AutoModelForCausalLM.from_pretrained(
    stage2_model,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    # device_map="auto",
)
model_stage2.to(device)
model_stage2.eval()

# Compile the stage-2 model on torch 2.x.
# NOTE(review): relies on torch.__version__ comparing as a version rather than
# as a plain string — confirm for the torch versions you support.
if torch.__version__ >= "2.0.0":
    model_stage2 = torch.compile(model_stage2)
def stage2_generate(model, prompt, batch_size=16):
    """Teacher-force the stage-2 model over a stage-1 code sequence.

    The stage-1 codes in *prompt* are offset into the stage-2 token space. When
    ``batch_size > 1`` they are cut into ``batch_size`` consecutive windows of
    300 frames (6 s at 50 frames/s), which are decoded together as one batch;
    the prompt must therefore cover ``batch_size * 300`` frames. For every
    frame, the known stage-1 token is appended and the model must generate
    exactly 7 further tokens.

    Args:
        model: stage-2 causal LM exposing ``generate``.
        prompt: stage-1 code array accepted by ``codectool.unflatten``.
        batch_size: number of 300-frame windows decoded in parallel; values
            <= 1 decode the whole prompt as a single sequence.

    Returns:
        np.ndarray: 1-D array of generated tokens with the prompt removed; for
        batched decoding the rows are concatenated in window order.
    """
    codec_ids = codectool.offset_tok_ids(
        codectool.unflatten(prompt, n_quantizer=1),
        global_offset=codectool.global_offset,
        codebook_size=codectool.codebook_size,
        num_codebooks=codectool.num_codebooks,
    ).astype(np.int32)

    if batch_size > 1:
        windows = [codec_ids[:, row * 300:(row + 1) * 300] for row in range(batch_size)]
        codec_ids = np.concatenate(windows, axis=0)
        head = np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1))
        tail = np.tile([mmtokenizer.stage_2], (batch_size, 1))
        prompt_ids = np.concatenate([head, codec_ids, tail], axis=1)
    else:
        flat_prompt = np.concatenate([
            np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
            codec_ids.flatten(),  # (1, T) -> (T,)
            np.array([mmtokenizer.stage_2]),
        ]).astype(np.int32)
        prompt_ids = flat_prompt[np.newaxis, ...]

    codec_ids = torch.as_tensor(codec_ids).to(device)
    prompt_ids = torch.as_tensor(prompt_ids).to(device)
    len_prompt = prompt_ids.shape[-1]

    block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])

    # Teacher forcing: feed the known codebook-0 token, let the model fill in
    # the remaining 7 tokens of this frame, then move on to the next frame.
    for frame in range(codec_ids.shape[1]):
        prompt_ids = torch.cat([prompt_ids, codec_ids[:, frame:frame + 1]], dim=1)
        with torch.no_grad():
            stage2_output = model.generate(
                input_ids=prompt_ids,
                min_new_tokens=7,
                max_new_tokens=7,
                eos_token_id=mmtokenizer.eoa,
                pad_token_id=mmtokenizer.eoa,
                logits_processor=block_list,
            )
        n_new = stage2_output.shape[1] - prompt_ids.shape[1]
        assert n_new == 7, f"output new tokens={n_new}"
        prompt_ids = stage2_output

    generated = prompt_ids.cpu().numpy()
    if batch_size <= 1:
        return generated[0][len_prompt:]
    return np.concatenate(list(generated[:, len_prompt:]), axis=0)
def fix_invalid_codes(codes, low=0, high=1023):
    """Return a copy of the 2-D code array *codes* with out-of-range entries patched.

    In every row, each entry outside ``[low, high]`` is replaced by the most
    frequent in-range value of that row (ties go to the value seen first). If
    a row has no in-range value at all, the most frequent value of the whole
    row is used instead. This is a dirty fallback that may harm audio quality.
    """
    fixed = np.array(codes, copy=True)
    for row_idx, row in enumerate(fixed):
        invalid = (row < low) | (row > high)
        if not invalid.any():
            continue
        candidates = row[~invalid]
        if candidates.size == 0:
            candidates = row
        # Counter.most_common keeps first-seen order among equal counts.
        fill_value = Counter(candidates.tolist()).most_common(1)[0][0]
        fixed[row_idx, invalid] = fill_value
    return fixed


def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
    """Run stage 2 on every stage-1 ``.npy`` file and save the refined codes.

    Args:
        model: the stage-2 model passed to ``stage2_generate``.
        stage1_output_set: paths of stage-1 ``.npy`` code files.
        stage2_output_dir: directory for the stage-2 ``.npy`` files (same basenames).
        batch_size: maximum number of 6 s windows decoded in one call.

    Returns:
        list[str]: paths of the stage-2 files. Files that already existed are
        not recomputed but are still returned, so downstream decoding and
        mixing pick them up.
    """
    stage2_result = []
    for stage1_path in tqdm(stage1_output_set):
        output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_path))

        if os.path.exists(output_filename):
            print(f'{output_filename} stage2 has done.')
            stage2_result.append(output_filename)
            continue

        prompt = np.load(stage1_path).astype(np.int32)

        # Only whole 6 s windows (300 frames at 50 frames/s) are batched; any
        # remainder is handled separately below.
        output_duration = prompt.shape[-1] // 50 // 6 * 6
        num_batch = output_duration // 6
        full_len = output_duration * 50

        if num_batch <= batch_size:
            # Everything fits in a single batched call.
            output = stage2_generate(model, prompt[:, :full_len], batch_size=num_batch)
        else:
            # Process chunks of up to batch_size windows; the last chunk may be smaller.
            segments = []
            num_segments = (num_batch + batch_size - 1) // batch_size
            for seg in range(num_segments):
                start_idx = seg * batch_size * 300
                end_idx = min((seg + 1) * batch_size * 300, full_len)
                segments.append(
                    stage2_generate(
                        model,
                        prompt[:, start_idx:end_idx],
                        batch_size=(end_idx - start_idx) // 300,
                    )
                )
            output = np.concatenate(segments, axis=0)

        # Process the ending part of the prompt that does not fill a 6 s window.
        if full_len != prompt.shape[-1]:
            ending = stage2_generate(model, prompt[:, full_len:], batch_size=1)
            output = np.concatenate([output, ending], axis=0)
        output = codectool_stage2.ids2npy(output)

        np.save(output_filename, fix_invalid_codes(output))
        stage2_result.append(output_filename)
    return stage2_result
# Run stage 2 on all stage-1 track files and report the produced paths.
stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=args.stage2_batch_size)
print(stage2_result)
print('Stage 2 DONE.\n')
# convert audio tokens to audio
def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
    """Write *wav* to *path* as 16-bit signed PCM, creating the parent folder if needed.

    Args:
        wav: waveform tensor passed to ``torchaudio.save``
            (presumably (channels, samples) — TODO confirm).
        path: destination file path (str or path-like).
        sample_rate: sample rate in Hz.
        rescale: if True, scale the whole signal so its peak is at most 0.99;
            otherwise hard-clip samples to [-0.99, 0.99].
    """
    folder_path = os.path.dirname(path)
    # dirname() is '' for a bare filename, and os.makedirs('') would raise;
    # exist_ok avoids a race with concurrent creators.
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)
    limit = 0.99
    if rescale:
        max_val = wav.abs().max()
        wav = wav * min(limit / max_val, 1)
    else:
        wav = wav.clamp(-limit, limit)
    torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
# reconstruct tracks
# Decode each stage-2 code file with the codec model and save it at 16 kHz.
recons_output_dir = os.path.join(args.output_dir, "recons")
recons_mix_dir = os.path.join(recons_output_dir, 'mix')
os.makedirs(recons_mix_dir, exist_ok=True)
tracks = []
for npy in stage2_result:
    codec_result = np.load(npy)
    with torch.no_grad():
        decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
    decoded_waveform = decoded_waveform.cpu().squeeze(0)
    save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
    tracks.append(save_path)
    save_audio(decoded_waveform, save_path, 16000)
# mix tracks
# Sum each instrumental stem with its matching vocal stem (best effort: any
# failure for one pair is printed and the loop continues).
for inst_path in tracks:
    try:
        if inst_path.endswith(('.wav', '.mp3')) and '_itrack' in inst_path:
            # find pair
            vocal_path = inst_path.replace('_itrack', '_vtrack')
            if not os.path.exists(vocal_path):
                continue
            recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('_itrack', '_mixed'))
            instrumental_stem, sr = sf.read(inst_path)
            vocal_stem, _ = sf.read(vocal_path)
            mix_stem = vocal_stem + instrumental_stem
            sf.write(recons_mix, mix_stem, sr)
    except Exception as e:
        print(e)
# vocoder to upsample audios
vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)
vocoder_output_dir = os.path.join(args.output_dir, 'vocoder')
vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
os.makedirs(vocoder_mix_dir, exist_ok=True)
os.makedirs(vocoder_stems_dir, exist_ok=True)
for npy in stage2_result:
    if '_itrack' in npy:
        # Process instrumental
        instrumental_output = process_audio(
            npy,
            os.path.join(vocoder_stems_dir, 'itrack.mp3'),
            args.rescale,
            args,
            inst_decoder,
            codec_model
        )
    else:
        # Process vocal
        vocal_output = process_audio(
            npy,
            os.path.join(vocoder_stems_dir, 'vtrack.mp3'),
            args.rescale,
            args,
            vocal_decoder,
            codec_model
        )
# mix tracks
# Resolve the output path first so the failure message below can always use it.
vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
try:
    mix_output = instrumental_output + vocal_output
    save_audio(mix_output, vocoder_mix, 44100, args.rescale)
    print(f"Created mix: {vocoder_mix}")
except RuntimeError as e:
    print(e)
    print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
else:
    # Post process (only when the vocoder mix was actually written)
    replace_low_freq_with_energy_matched(
        a_file=recons_mix,     # 16 kHz codec reconstruction
        b_file=vocoder_mix,    # vocoder mix, saved above at 44.1 kHz
        c_file=os.path.join(args.output_dir, os.path.basename(recons_mix)),
        cutoff_freq=5500.0
    )
492 |
--------------------------------------------------------------------------------
/code/finetune/core/datasets/helpers.cpp:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
2 |
3 | /* Helper methods for fast index mapping builds */
4 |
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 |
namespace py = pybind11;
using namespace std;

// Documents containing a sentence longer than this many tokens are skipped
// when building sample mappings (see build_mapping_impl).
const int32_t LONG_SENTENCE_LEN = 512;
18 |
void build_blending_indices(py::array_t &dataset_index,
                            py::array_t &dataset_sample_index,
                            const py::array_t &weights,
                            const int32_t num_datasets,
                            const int64_t size, const bool verbose)
{
    /* Given multiple datasets and a weighting array, assign each of `size`
       blended samples to a dataset so that the achieved ratios follow the
       weights.

       For each sample, the dataset whose (weight * samples_so_far - used)
       deficit is largest is chosen. dataset_index[k] receives the chosen
       dataset and dataset_sample_index[k] the running sample index within
       that dataset. Both output arrays are written in place.

       NOTE(review): the template arguments of py::array_t and static_cast
       appear to have been stripped from this copy of the file (e.g.
       `static_cast(sample_idx)`); restore them from upstream before
       compiling. */

    if (verbose)
    {
        std::cout << "> building indices for blended datasets ..." << std::endl;
    }

    // Get the pointer access without the checks.
    auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
    auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
    auto weights_ptr = weights.unchecked<1>();

    // Initialize buffer for number of samples used for each dataset.
    // NOTE(review): this is a variable-length array (a compiler extension,
    // not standard C++) on the stack; std::vector would be portable.
    int64_t current_samples[num_datasets];
    for (int64_t i = 0; i < num_datasets; ++i)
    {
        current_samples[i] = 0;
    }

    // For each sample:
    for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx)
    {

        // Determine where the max error in sampling is happening.
        // Clamped to 1 so the very first sample still compares by weight.
        auto sample_idx_double = std::max(static_cast(sample_idx), 1.0);
        int64_t max_error_index = 0;
        double max_error = weights_ptr[0] * sample_idx_double -
                           static_cast(current_samples[0]);
        for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx)
        {
            double error = weights_ptr[dataset_idx] * sample_idx_double -
                           static_cast(current_samples[dataset_idx]);
            // Strict '>' keeps the lowest dataset index on ties.
            if (error > max_error)
            {
                max_error = error;
                max_error_index = dataset_idx;
            }
        }

        // Populate the indices.
        dataset_index_ptr[sample_idx] = static_cast(max_error_index);
        dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

        // Update the total samples.
        current_samples[max_error_index] += 1;
    }

    // print info
    if (verbose)
    {
        std::cout << " > sample ratios:" << std::endl;
        for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx)
        {
            auto ratio = static_cast(current_samples[dataset_idx]) /
                         static_cast(size);
            std::cout << " dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
        }
    }
}
85 |
py::array build_sample_idx(const py::array_t &sizes_,
                           const py::array_t &doc_idx_,
                           const int32_t seq_length,
                           const int32_t num_epochs,
                           const int64_t tokens_per_epoch)
{
    /* Sample index (sample_idx) for GPT-like datasets: the documents listed
       in doc_idx are treated as one flattened token stream and samples are
       cut from it.

       Returns an int32 array of shape [num_samples + 1, 2]. Row k holds
       (index into doc_idx, token offset within that document) where sample
       k starts. Each sample spans seq_length + 1 tokens, and the next sample
       starts on the last token of the previous one (one-token overlap).

       NOTE(review): the template arguments of py::array_t, static_cast and
       std::vector appear to have been stripped from this copy of the file
       (e.g. `std::vector{num_samples + 1, 2}`); restore them from upstream
       before compiling. */

    // Consistency checks.
    assert(seq_length > 1);
    assert(num_epochs > 0);
    assert(tokens_per_epoch > 1);

    // Remove bound checks.
    auto sizes = sizes_.unchecked<1>();
    auto doc_idx = doc_idx_.unchecked<1>();

    // Mapping and its length (1D). The -1 accounts for the one-token overlap
    // between consecutive samples.
    int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
    int32_t *sample_idx = new int32_t[2 * (num_samples + 1)];

    // Index into sample_idx.
    int64_t sample_index = 0;
    // Index into doc_idx.
    int64_t doc_idx_index = 0;
    // Beginning offset for each document.
    int32_t doc_offset = 0;
    // Start with first document and no offset.
    // NOTE(review): doc_idx_index (int64_t) is narrowed into an int32_t slot,
    // which presumably assumes fewer than 2^31 doc_idx entries — confirm.
    sample_idx[2 * sample_index] = doc_idx_index;
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;

    while (sample_index <= num_samples)
    {
        // Start with a fresh sequence.
        int32_t remaining_seq_length = seq_length + 1;
        while (remaining_seq_length != 0)
        {
            // Get the document length.
            auto doc_id = doc_idx[doc_idx_index];
            auto doc_length = sizes[doc_id] - doc_offset;
            // And add it to the current sequence.
            remaining_seq_length -= doc_length;
            // If we have more than a full sequence, adjust offset and set
            // remaining length to zero so we return from the while loop.
            // The -1 makes the next sample start on this sample's last token,
            // matching the -1 in the num_samples computation above.
            if (remaining_seq_length <= 0)
            {
                doc_offset += (remaining_seq_length + doc_length - 1);
                remaining_seq_length = 0;
            }
            else
            {
                // Otherwise, start from the beginning of the next document.
                ++doc_idx_index;
                doc_offset = 0;
            }
        }
        // Record the sequence.
        sample_idx[2 * sample_index] = doc_idx_index;
        sample_idx[2 * sample_index + 1] = doc_offset;
        ++sample_index;
    }

    // Method to deallocate memory: the capsule owns sample_idx and frees it
    // when the returned numpy array is released.
    py::capsule free_when_done(sample_idx, [](void *mem_)
                               {
        int32_t *mem = reinterpret_cast(mem_);
        delete[] mem; });

    // Return the numpy array.
    const auto byte_size = sizeof(int32_t);
    return py::array(std::vector{num_samples + 1, 2}, // shape
                     {2 * byte_size, byte_size},       // C-style contiguous strides
                     sample_idx,                       // the data pointer
                     free_when_done);                  // numpy array references
}
168 |
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937 &rand32_gen)
{
    /* Choose the target length of the next training sample.

       With short_seq_ratio == 0 the generator is not advanced and max_length
       is returned. Otherwise exactly one number is drawn: when it is a
       multiple of short_seq_ratio (about one draw in short_seq_ratio) a short
       length in [2, max_length] is derived from it, else max_length is
       returned. */
    if (short_seq_ratio != 0)
    {
        const auto draw = rand32_gen();
        const bool make_short = (draw % short_seq_ratio) == 0;
        if (make_short)
        {
            return 2 + draw % (max_length - 1);
        }
    }
    return max_length;
}
185 |
// NOTE(review): the template parameter list (presumably the DocIdx index type
// used below) and the template arguments of py::array_t, static_cast,
// std::numeric_limits and std::vector appear to have been stripped from this
// copy of the file; restore them from upstream before compiling.
template
py::array build_mapping_impl(const py::array_t &docs_,
                             const py::array_t &sizes_,
                             const int32_t num_epochs,
                             const uint64_t max_num_samples,
                             const int32_t max_seq_length,
                             const double short_seq_prob,
                             const int32_t seed,
                             const bool verbose,
                             const int32_t min_num_sent)
{
    /* Build a mapping of (start-index, end-index, sequence-length) where
       start and end index are the indices of the sentences in the sample
       and sequence-length is the target sequence length.

       docs_[d] .. docs_[d + 1] is the sentence range of document d, and
       sizes_[s] is the token length of sentence s. Documents with fewer than
       min_num_sent sentences, or containing a sentence longer than
       LONG_SENTENCE_LEN, are skipped. The loop runs twice with the same seed:
       the first pass only counts samples so the buffer can be allocated, and
       the second pass fills it. The result is shuffled and returned as a
       [num_samples, 3] array that owns its buffer. */

    // Consistency checks.
    assert(num_epochs > 0);
    assert(max_seq_length > 1);
    assert(short_seq_prob >= 0.0);
    assert(short_seq_prob <= 1.0);
    assert(seed > 0);

    // Remove bound checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();

    // For efficiency, convert the probability to an integer ratio so a short
    // sample can be chosen with a modulo test on the integer generator.
    int32_t short_seq_ratio = 0;
    if (short_seq_prob > 0)
    {
        short_seq_ratio = static_cast(round(1.0 / short_seq_prob));
    }

    if (verbose)
    {
        const auto sent_start_index = docs[0];
        const auto sent_end_index = docs[docs_.shape(0) - 1];
        const auto num_sentences = sent_end_index - sent_start_index;
        cout << " using:" << endl
             << std::flush;
        cout << " number of documents: " << docs_.shape(0) - 1 << endl
             << std::flush;
        cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl
             << std::flush;
        cout << " total number of sentences: " << num_sentences << endl
             << std::flush;
        cout << " number of epochs: " << num_epochs << endl
             << std::flush;
        cout << " maximum number of samples: " << max_num_samples << endl
             << std::flush;
        cout << " maximum sequence length: " << max_seq_length << endl
             << std::flush;
        cout << " short sequence probability: " << short_seq_prob << endl
             << std::flush;
        cout << " short sequence ration (1/prob): " << short_seq_ratio << endl
             << std::flush;
        cout << " seed: " << seed << endl
             << std::flush;
    }

    // Mapping and its length (1D).
    int64_t num_samples = -1;
    DocIdx *maps = NULL;

    // Perform two iterations, in the first iteration get the size
    // and allocate memory and in the second iteration populate the map.
    bool second = false;
    for (int32_t iteration = 0; iteration < 2; ++iteration)
    {

        // Set the seed so both iterations produce the same results.
        std::mt19937 rand32_gen(seed);

        // Set the flag on second iteration.
        second = (iteration == 1);

        // Counters:
        uint64_t empty_docs = 0;
        uint64_t one_sent_docs = 0;
        uint64_t long_sent_docs = 0;

        // Current map index.
        uint64_t map_index = 0;

        // For each epoch:
        for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
        {
            // The sample cap is only checked at epoch boundaries, so the last
            // epoch processed may run past max_num_samples.
            if (map_index >= max_num_samples)
            {
                if (verbose && (!second))
                {
                    cout << " reached " << max_num_samples << " samples after "
                         << epoch << " epochs ..." << endl
                         << std::flush;
                }
                break;
            }
            // For each document:
            for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
            {

                // Document sentences are in [sent_index_first, sent_index_last)
                const auto sent_index_first = docs[doc];
                const auto sent_index_last = docs[doc + 1];

                // At the beginning of the document previous index is the
                // start index.
                auto prev_start_index = sent_index_first;

                // Remaining sentences in this document.
                auto num_remain_sent = sent_index_last - sent_index_first;

                // Some bookkeeping (first pass, first epoch only).
                if ((epoch == 0) && (!second))
                {
                    if (num_remain_sent == 0)
                    {
                        ++empty_docs;
                    }
                    if (num_remain_sent == 1)
                    {
                        ++one_sent_docs;
                    }
                }

                // Detect documents with long sentences.
                bool contains_long_sentence = false;
                if (num_remain_sent > 1)
                {
                    for (auto sent_index = sent_index_first;
                         sent_index < sent_index_last; ++sent_index)
                    {
                        if (sizes[sent_index] > LONG_SENTENCE_LEN)
                        {
                            if ((epoch == 0) && (!second))
                            {
                                ++long_sent_docs;
                            }
                            contains_long_sentence = true;
                            break;
                        }
                    }
                }

                // If the document has at least min_num_sent sentences and no
                // long sentence.
                if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
                {

                    // Set values.
                    auto seq_len = int32_t{0};
                    auto num_sent = int32_t{0};
                    auto target_seq_len = get_target_sample_len(short_seq_ratio,
                                                                max_seq_length,
                                                                rand32_gen);

                    // Loop through sentences.
                    for (auto sent_index = sent_index_first;
                         sent_index < sent_index_last; ++sent_index)
                    {

                        // Add the size and number of sentences.
                        seq_len += sizes[sent_index];
                        ++num_sent;
                        --num_remain_sent;

                        // Close the current sample if either
                        //   - it reached the target length, more than one
                        //     sentence is still left in the document, and it
                        //     holds at least min_num_sent sentences; or
                        //   - we reached the end of the document.
                        if (((seq_len >= target_seq_len) &&
                             (num_remain_sent > 1) &&
                             (num_sent >= min_num_sent)) ||
                            (num_remain_sent == 0))
                        {

                            // Check for overflow.
                            if ((3 * map_index + 2) >
                                std::numeric_limits::max())
                            {
                                cout << "number of samples exceeded maximum "
                                     << "allowed by type int64: "
                                     << std::numeric_limits::max()
                                     << endl;
                                throw std::overflow_error("Number of samples");
                            }

                            // Populate the map (second pass only).
                            if (second)
                            {
                                const auto map_index_0 = 3 * map_index;
                                maps[map_index_0] = static_cast(prev_start_index);
                                maps[map_index_0 + 1] = static_cast(sent_index + 1);
                                maps[map_index_0 + 2] = static_cast(target_seq_len);
                            }

                            // Update indices / counters.
                            ++map_index;
                            prev_start_index = sent_index + 1;
                            target_seq_len = get_target_sample_len(short_seq_ratio,
                                                                   max_seq_length,
                                                                   rand32_gen);
                            seq_len = 0;
                            num_sent = 0;
                        }

                    } // for (auto sent_index=sent_index_first; ...
                } // if (num_remain_sent >= min_num_sent && !contains_long_sentence)
            } // for (int doc=0; doc < num_docs; ++doc) {
        } // for (int epoch=0; epoch < num_epochs; ++epoch) {

        if (!second)
        {
            if (verbose)
            {
                cout << " number of empty documents: " << empty_docs << endl
                     << std::flush;
                cout << " number of documents with one sentence: " << one_sent_docs << endl
                     << std::flush;
                cout << " number of documents with long sentences: " << long_sent_docs << endl
                     << std::flush;
                cout << " will create mapping for " << map_index << " samples" << endl
                     << std::flush;
            }
            // Allocate exactly the number of samples counted in the first pass.
            assert(maps == NULL);
            assert(num_samples < 0);
            maps = new DocIdx[3 * map_index];
            num_samples = static_cast(map_index);
        }

    } // for (int iteration=0; iteration < 2; ++iteration) {

    // Shuffle the (start, end, length) triples in place (Fisher-Yates).
    // We need a 64 bit random number generator as we might have more
    // than 2 billion samples.
    std::mt19937_64 rand64_gen(seed + 1);
    for (auto i = (num_samples - 1); i > 0; --i)
    {
        const auto j = static_cast(rand64_gen() % (i + 1));
        const auto i0 = 3 * i;
        const auto j0 = 3 * j;
        // Swap values.
        swap(maps[i0], maps[j0]);
        swap(maps[i0 + 1], maps[j0 + 1]);
        swap(maps[i0 + 2], maps[j0 + 2]);
    }

    // Method to deallocate memory: the capsule frees maps when the returned
    // numpy array is released.
    py::capsule free_when_done(maps, [](void *mem_)
                               {
        DocIdx *mem = reinterpret_cast(mem_);
        delete[] mem; });

    // Return the numpy array.
    const auto byte_size = sizeof(DocIdx);
    return py::array(std::vector{num_samples, 3}, // shape
                     {3 * byte_size, byte_size},   // C-style contiguous strides
                     maps,                         // the data pointer
                     free_when_done);              // numpy array references
}
446 |
py::array build_mapping(const py::array_t &docs_,
                        const py::array_t &sizes_,
                        const int num_epochs,
                        const uint64_t max_num_samples,
                        const int max_seq_length,
                        const double short_seq_prob,
                        const int seed,
                        const bool verbose,
                        const int32_t min_num_sent)
{
    /* Forward all arguments to build_mapping_impl, choosing the index type of
       the returned mapping from the number of sentences in sizes_.

       NOTE(review): the template arguments were stripped from this copy of
       the file (`std::numeric_limits::max()`, `build_mapping_impl(...)`).
       Judging by the verbose messages, the first branch is meant to use
       uint64 indices and the second uint32; restore them from upstream
       before compiling. */

    if (sizes_.size() > std::numeric_limits::max())
    {
        if (verbose)
        {
            cout << " using uint64 for data mapping..." << endl
                 << std::flush;
        }
        return build_mapping_impl(docs_, sizes_, num_epochs,
                                  max_num_samples, max_seq_length,
                                  short_seq_prob, seed, verbose,
                                  min_num_sent);
    }
    else
    {
        if (verbose)
        {
            cout << " using uint32 for data mapping..." << endl
                 << std::flush;
        }
        return build_mapping_impl(docs_, sizes_, num_epochs,
                                  max_num_samples, max_seq_length,
                                  short_seq_prob, seed, verbose,
                                  min_num_sent);
    }
}
483 |
484 | template
485 | py::array build_blocks_mapping_impl(const py::array_t &docs_,
486 | const py::array_t &sizes_,
487 | const py::array_t &titles_sizes_,
488 | const int32_t num_epochs,
489 | const uint64_t max_num_samples,
490 | const int32_t max_seq_length,
491 | const int32_t seed,
492 | const bool verbose,
493 | const bool use_one_sent_blocks)
494 | {
495 | /* Build a mapping of (start-index, end-index, sequence-length) where
496 | start and end index are the indices of the sentences in the sample
497 | and sequence-length is the target sequence length.
498 | */
499 |
500 | // Consistency checks.
501 | assert(num_epochs > 0);
502 | assert(max_seq_length > 1);
503 | assert(seed > 0);
504 |
505 | // Remove bound checks.
506 | auto docs = docs_.unchecked<1>();
507 | auto sizes = sizes_.unchecked<1>();
508 | auto titles_sizes = titles_sizes_.unchecked<1>();
509 |
510 | if (verbose)
511 | {
512 | const auto sent_start_index = docs[0];
513 | const auto sent_end_index = docs[docs_.shape(0) - 1];
514 | const auto num_sentences = sent_end_index - sent_start_index;
515 | cout << " using:" << endl
516 | << std::flush;
517 | cout << " number of documents: " << docs_.shape(0) - 1 << endl
518 | << std::flush;
519 | cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl
520 | << std::flush;
521 | cout << " total number of sentences: " << num_sentences << endl
522 | << std::flush;
523 | cout << " number of epochs: " << num_epochs << endl
524 | << std::flush;
525 | cout << " maximum number of samples: " << max_num_samples << endl
526 | << std::flush;
527 | cout << " maximum sequence length: " << max_seq_length << endl
528 | << std::flush;
529 | cout << " seed: " << seed << endl
530 | << std::flush;
531 | }
532 |
533 | // Mapping and its length (1D).
534 | int64_t num_samples = -1;
535 | DocIdx *maps = NULL;
536 |
537 | // Acceptable number of sentences per block.
538 | int min_num_sent = 2;
539 | if (use_one_sent_blocks)
540 | {
541 | min_num_sent = 1;
542 | }
543 |
544 | // Perform two iterations, in the first iteration get the size
545 | // and allocate memory and in the second iteration populate the map.
546 | bool second = false;
547 | for (int32_t iteration = 0; iteration < 2; ++iteration)
548 | {
549 |
550 | // Set the flag on second iteration.
551 | second = (iteration == 1);
552 |
553 | // Current map index.
554 | uint64_t map_index = 0;
555 |
556 | uint64_t empty_docs = 0;
557 | uint64_t one_sent_docs = 0;
558 | uint64_t long_sent_docs = 0;
559 | // For each epoch:
560 | for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
561 | {
562 | // assign every block a unique id
563 | int32_t block_id = 0;
564 |
565 | if (map_index >= max_num_samples)
566 | {
567 | if (verbose && (!second))
568 | {
569 | cout << " reached " << max_num_samples << " samples after "
570 | << epoch << " epochs ..." << endl
571 | << std::flush;
572 | }
573 | break;
574 | }
575 | // For each document:
576 | for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
577 | {
578 |
579 | // Document sentences are in [sent_index_first, sent_index_last)
580 | const auto sent_index_first = docs[doc];
581 | const auto sent_index_last = docs[doc + 1];
582 | const auto target_seq_len = max_seq_length - titles_sizes[doc];
583 |
584 | // At the begining of the document previous index is the
585 | // start index.
586 | auto prev_start_index = sent_index_first;
587 |
588 | // Remaining documents.
589 | auto num_remain_sent = sent_index_last - sent_index_first;
590 |
591 | // Some bookkeeping
592 | if ((epoch == 0) && (!second))
593 | {
594 | if (num_remain_sent == 0)
595 | {
596 | ++empty_docs;
597 | }
598 | if (num_remain_sent == 1)
599 | {
600 | ++one_sent_docs;
601 | }
602 | }
603 | // Detect documents with long sentences.
604 | bool contains_long_sentence = false;
605 | if (num_remain_sent >= min_num_sent)
606 | {
607 | for (auto sent_index = sent_index_first;
608 | sent_index < sent_index_last; ++sent_index)
609 | {
610 | if (sizes[sent_index] > LONG_SENTENCE_LEN)
611 | {
612 | if ((epoch == 0) && (!second))
613 | {
614 | ++long_sent_docs;
615 | }
616 | contains_long_sentence = true;
617 | break;
618 | }
619 | }
620 | }
621 | // If we have enough sentences and no long sentences.
622 | if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
623 | {
624 |
625 | // Set values.
626 | auto seq_len = int32_t{0};
627 | auto num_sent = int32_t{0};
628 |
629 | // Loop through sentences.
630 | for (auto sent_index = sent_index_first;
631 | sent_index < sent_index_last; ++sent_index)
632 | {
633 |
634 | // Add the size and number of sentences.
635 | seq_len += sizes[sent_index];
636 | ++num_sent;
637 | --num_remain_sent;
638 |
639 | // If we have reached the target length.
640 | // and there are an acceptable number of sentences left
641 | // and if we have at least the minimum number of sentences.
642 | // or if we have reached end of the document.
643 | if (((seq_len >= target_seq_len) &&
644 | (num_remain_sent >= min_num_sent) &&
645 | (num_sent >= min_num_sent)) ||
646 | (num_remain_sent == 0))
647 | {
648 |
649 | // Populate the map.
650 | if (second)
651 | {
652 | const auto map_index_0 = 4 * map_index;
653 | // Each sample has 4 items: the starting sentence index, ending sentence index,
654 | // the index of the document from which the block comes (used for fetching titles)
655 | // and the unique id of the block (used for creating block indexes)
656 |
657 | maps[map_index_0] = static_cast(prev_start_index);
658 | maps[map_index_0 + 1] = static_cast(sent_index + 1);
659 | maps[map_index_0 + 2] = static_cast(doc);
660 | maps[map_index_0 + 3] = static_cast(block_id);
661 | }
662 |
663 | // Update indices / counters.
664 | ++map_index;
665 | ++block_id;
666 | prev_start_index = sent_index + 1;
667 | seq_len = 0;
668 | num_sent = 0;
669 | }
670 | } // for (auto sent_index=sent_index_first; ...
671 | } // if (num_remain_sent > 1) {
672 | } // for (int doc=0; doc < num_docs; ++doc) {
673 | } // for (int epoch=0; epoch < num_epochs; ++epoch) {
674 |
675 | if (!second)
676 | {
677 | if (verbose)
678 | {
679 | cout << " number of empty documents: " << empty_docs << endl
680 | << std::flush;
681 | cout << " number of documents with one sentence: " << one_sent_docs << endl
682 | << std::flush;
683 | cout << " number of documents with long sentences: " << long_sent_docs << endl
684 | << std::flush;
685 | cout << " will create mapping for " << map_index << " samples" << endl
686 | << std::flush;
687 | }
688 | assert(maps == NULL);
689 | assert(num_samples < 0);
690 | maps = new DocIdx[4 * map_index];
691 | num_samples = static_cast(map_index);
692 | }
693 |
694 | } // for (int iteration=0; iteration < 2; ++iteration) {
695 |
696 | // Shuffle.
697 | // We need a 64 bit random number generator as we might have more
698 | // than 2 billion samples.
699 | std::mt19937_64 rand64_gen(seed + 1);
700 | for (auto i = (num_samples - 1); i > 0; --i)
701 | {
702 | const auto j = static_cast(rand64_gen() % (i + 1));
703 | const auto i0 = 4 * i;
704 | const auto j0 = 4 * j;
705 | // Swap values.
706 | swap(maps[i0], maps[j0]);
707 | swap(maps[i0 + 1], maps[j0 + 1]);
708 | swap(maps[i0 + 2], maps[j0 + 2]);
709 | swap(maps[i0 + 3], maps[j0 + 3]);
710 | }
711 |
712 | // Method to deallocate memory.
713 | py::capsule free_when_done(maps, [](void *mem_)
714 | {
715 | DocIdx *mem = reinterpret_cast(mem_);
716 | delete[] mem; });
717 |
718 | // Return the numpy array.
719 | const auto byte_size = sizeof(DocIdx);
720 | return py::array(std::vector{num_samples, 4}, // shape
721 | {4 * byte_size, byte_size}, // C-style contiguous strides
722 | maps, // the data pointer
723 | free_when_done); // numpy array references
724 | }
725 |
726 | py::array build_blocks_mapping(const py::array_t &docs_,
727 | const py::array_t &sizes_,
728 | const py::array_t &titles_sizes_,
729 | const int num_epochs,
730 | const uint64_t max_num_samples,
731 | const int max_seq_length,
732 | const int seed,
733 | const bool verbose,
734 | const bool use_one_sent_blocks)
735 | {
736 |
737 | if (sizes_.size() > std::numeric_limits::max())
738 | {
739 | if (verbose)
740 | {
741 | cout << " using uint64 for data mapping..." << endl
742 | << std::flush;
743 | }
744 | return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_,
745 | num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
746 | }
747 | else
748 | {
749 | if (verbose)
750 | {
751 | cout << " using uint32 for data mapping..." << endl
752 | << std::flush;
753 | }
754 | return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_,
755 | num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
756 | }
757 | }
758 |
759 | PYBIND11_MODULE(helpers, m)
760 | {
761 | m.def("build_mapping", &build_mapping);
762 | m.def("build_blocks_mapping", &build_blocks_mapping);
763 | m.def("build_sample_idx", &build_sample_idx);
764 | m.def("build_blending_indices", &build_blending_indices);
765 | }
766 |
--------------------------------------------------------------------------------