├── .gitignore ├── README.md ├── demos ├── 808 bass_evolution.wav ├── 808 bass_final.wav ├── bird_evolution.wav ├── bird_final.wav ├── hihat_evolution.wav ├── hihat_final.wav ├── human scream_evolution.wav ├── human scream_final.wav ├── kick drum_evolution.wav ├── kick drum_final.wav ├── piano_evolution.wav ├── piano_final.wav ├── rain_evolution.wav ├── rain_final.wav ├── whistling_evolution.wav └── whistling_final.wav ├── genfx.py ├── index.html ├── misc ├── evolution.png └── logo.svg └── textsynth.py /.gitignore: -------------------------------------------------------------------------------- 1 | # custom 2 | timbreCLIP/ 3 | data/ 4 | artefacts/ 5 | .DS_Store 6 | # all DS_Store files 7 | **/.DS_Store 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![logo](misc/logo.svg) 2 | 3 | 4 | Generate synthesizer sounds from text prompts with a simple evolutionary algorithm. 5 | 6 | **Audio examples here: https://erl-j.github.io/textsynth/** 7 | 8 | 9 | Synth: https://github.com/torchsynth/torchsynth 10 | 11 | Audio-Text cross modal embedding: https://github.com/LAION-AI/CLAP 12 | 13 | ## How it works 14 | Start with randomly initialized synthesizer sounds. Each iteration, the current synthesizer sounds are evaluated on how well they match the text prompt. The best sounds are then combined and mutated to generate new sounds for the next iteration. 200 generations w/ 50 samples takes about ~20s on a 3090 (not tested on CPU). 15 | 16 | 17 | example 18 | 19 | ## Future work 20 | 21 | - Install guide / requirements.txt 22 | - Diversity preservation. 23 | - Open ended exploration. 24 | - RL 25 | - Neural nets? 
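## Loop sketch

For reference, the select-crossover-mutate cycle described under "How it works" boils down to a few tensor operations. The sketch below is condensed from `textsynth.py`: `render`, `embed_audio` and `text_embedding` are stand-ins for the torchsynth `Voice` wrapper and the CLAP helpers defined in that script, and the keyword defaults mirror the constants used there. The real script additionally pins the keyboard MIDI note and logs per-generation records, which are omitted here.

```python
import math
import torch

def evolve(render, embed_audio, text_embedding, n_params,
           pop_size=50, generations=200, mutation_rate=0.1, temperature=0.1):
    # Population of unconstrained parameter vectors; a sine activation squashes
    # them into [0, 1] before they are handed to the synthesizer.
    p = torch.rand(pop_size, n_params) * 2 * math.pi
    for _ in range(generations):
        audio = render((torch.sin(p) + 1.0) / 2.0)
        audio = audio / audio.abs().max(dim=1, keepdim=True).values       # peak normalize
        similarity = torch.nn.functional.cosine_similarity(
            embed_audio(audio), text_embedding)                           # fitness vs. prompt
        weights = torch.softmax(similarity / temperature, dim=0)
        a = p[torch.multinomial(weights, pop_size, replacement=True)]     # fitness-proportional
        b = p[torch.multinomial(weights, pop_size, replacement=True)]     # parent selection
        mask = torch.rand_like(a) > 0.5                                   # uniform crossover
        p = a * mask + b * (~mask) + torch.randn_like(a) * mutation_rate  # Gaussian mutation
    return p, similarity
```

There is no requirements file yet (see Future work); the imports in `textsynth.py` and `genfx.py` imply at least `torch`, `torchsynth`, `laion_clap`, `pedalboard`, `librosa`, `soundfile`, `numpy`, `pandas`, `seaborn`, `matplotlib`, `scikit-learn` and `tqdm`, plus IPython for the notebook-style `#%%` cells.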
26 | 27 | ## Cite 28 | 29 | If you find this project useful, please cite: 30 | 31 | ```bibtex 32 | @software{ 33 | textsynth, 34 | author = {Nicolas Jonason}, 35 | title = {TextSynth: Generate synthesizer sounds from text prompts with a simple evolutionary algorithm}, 36 | month = oct, 37 | year = 2023, 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /demos/808 bass_evolution.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/808 bass_evolution.wav -------------------------------------------------------------------------------- /demos/808 bass_final.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/808 bass_final.wav -------------------------------------------------------------------------------- /demos/bird_evolution.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/bird_evolution.wav -------------------------------------------------------------------------------- /demos/bird_final.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/bird_final.wav -------------------------------------------------------------------------------- /demos/hihat_evolution.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/hihat_evolution.wav -------------------------------------------------------------------------------- /demos/hihat_final.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/hihat_final.wav -------------------------------------------------------------------------------- /demos/human scream_evolution.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/human scream_evolution.wav -------------------------------------------------------------------------------- /demos/human scream_final.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/human scream_final.wav -------------------------------------------------------------------------------- /demos/kick drum_evolution.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/kick drum_evolution.wav -------------------------------------------------------------------------------- /demos/kick drum_final.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/kick drum_final.wav -------------------------------------------------------------------------------- /demos/piano_evolution.wav:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/piano_evolution.wav -------------------------------------------------------------------------------- /demos/piano_final.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/piano_final.wav -------------------------------------------------------------------------------- /demos/rain_evolution.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/rain_evolution.wav -------------------------------------------------------------------------------- /demos/rain_final.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/rain_final.wav -------------------------------------------------------------------------------- /demos/whistling_evolution.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/whistling_evolution.wav -------------------------------------------------------------------------------- /demos/whistling_final.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/demos/whistling_final.wav -------------------------------------------------------------------------------- /genfx.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import torch 3 | import numpy as np 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | import torch 7 | from torchsynth.synth import Voice,SynthConfig 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from IPython.display import clear_output 11 | from sklearn.cluster import KMeans 12 | from tqdm import tqdm 13 | import IPython.display as display 14 | import pandas as pd 15 | import laion_clap 16 | import librosa 17 | from pedalboard import ( 18 | Pedalboard, 19 | Chorus, 20 | Reverb, 21 | PitchShift, 22 | Delay, 23 | Compressor, 24 | Distortion, 25 | LadderFilter, 26 | Mix, 27 | Gain, 28 | ) 29 | 30 | 31 | SAMPLE_RATE=48000 32 | 33 | # Run on the GPU if it's available 34 | if torch.cuda.is_available(): 35 | device = torch.device("cuda:3") 36 | else: 37 | device = torch.device("cpu") 38 | 39 | # disable gradient calculation 40 | torch.set_grad_enabled(False) 41 | 42 | def play(a): 43 | display.display(display.Audio(a, rate=SAMPLE_RATE)) 44 | 45 | 46 | class EffectChain: 47 | def __init__(self): 48 | self.parameters = [ 49 | "drive_db", 50 | "reverb_room_size", 51 | "reverb_damping", 52 | "reverb_mix", 53 | "chorus_rate", 54 | "chorus_depth", 55 | "chorus_mix", 56 | "delay_time", 57 | "delay_feedback", 58 | "delay_mix", 59 | "compressor_threshold", 60 | "compressor_ratio", 61 | "octave_up_mix", 62 | "octave_down_mix", 63 | "low_pass_cutoff", 64 | "high_pass_cutoff", 65 | ] 66 | 67 | def get_n_parameters(self): 68 | return len(self.parameters) 69 | 70 | def tensor2pedalboard(self, tensor): 71 | 72 | # parameter 2 value 73 | p2v = dict(zip(self.parameters, tensor)) 74 | 75 | board = Pedalboard( 
76 | [ 77 | Chorus( 78 | rate_hz=p2v["chorus_rate"] * 2, 79 | depth=p2v["chorus_depth"], 80 | mix=p2v["chorus_mix"], 81 | ), 82 | Reverb( 83 | room_size=p2v["reverb_room_size"], 84 | damping=p2v["reverb_damping"], 85 | dry_level=1 - p2v["reverb_mix"], 86 | wet_level=p2v["reverb_mix"], 87 | ), 88 | Delay( 89 | delay_seconds=p2v["delay_time"], 90 | feedback=p2v["delay_feedback"], 91 | mix=p2v["delay_mix"], 92 | ), 93 | Compressor( 94 | threshold_db=-p2v["compressor_threshold"] * 10, 95 | ratio=1.0 + p2v["compressor_ratio"] * 100.0, 96 | ), 97 | Distortion(drive_db=p2v["drive_db"] * 50), 98 | Mix( 99 | [ 100 | Pedalboard( 101 | [ 102 | PitchShift(semitones=12), 103 | Gain(gain_db=-40 + 40 * p2v["octave_up_mix"]), 104 | ] 105 | ), 106 | Pedalboard( 107 | [ 108 | PitchShift(semitones=-12), 109 | Gain(gain_db=-40 + 40 * p2v["octave_down_mix"]), 110 | ] 111 | ), 112 | Gain(gain_db=0), 113 | ] 114 | ), 115 | LadderFilter( 116 | mode=LadderFilter.Mode.HPF12, 117 | cutoff_hz=p2v["high_pass_cutoff"] * 16000, 118 | ), 119 | LadderFilter( 120 | mode=LadderFilter.Mode.LPF12, 121 | cutoff_hz=p2v["low_pass_cutoff"] * 16000, 122 | ), 123 | ] 124 | ) 125 | 126 | return board 127 | 128 | def __call__(self, audio, tensor): 129 | board = self.tensor2pedalboard(tensor) 130 | return board(audio, sample_rate=SAMPLE_RATE) 131 | 132 | 133 | CLIP_DURATION=1 134 | BATCH_SIZE=40 135 | N_PARENTS=64 136 | 137 | N_SAVED_PER_TARGET=16 138 | 139 | MUTATION_RATE=0.01 140 | TEMPERATURE=2.0 141 | 142 | MIDI_F0=53 143 | 144 | 145 | # initialize population 146 | activation = lambda x: 0.0000001 + 0.9999 * (torch.cos(x * np.pi * 2) + 1) / 2 147 | effect_chain = EffectChain() 148 | 149 | 150 | source_path = "./data/nylon.wav" 151 | source_audio = librosa.load(source_path, sr=SAMPLE_RATE, duration=CLIP_DURATION)[0] 152 | 153 | dummy_p=torch.randn((BATCH_SIZE, effect_chain.get_n_parameters())) 154 | 155 | clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=device) 156 | clap_model.load_ckpt() # download the default pretrained checkpoint. 
157 | 158 | def embed_audio(audio_data): 159 | audio_embed = clap_model.get_audio_embedding_from_data(x = audio_data, use_tensor=True) 160 | return audio_embed 161 | 162 | def embed_text(text_data): 163 | text_embed = clap_model.get_text_embedding(text_data, use_tensor=True) 164 | return text_embed 165 | 166 | # random p 167 | p = torch.rand(dummy_p.shape).to(device)*2*np.pi 168 | 169 | def mutate(p): 170 | #mask=(torch.rand(p.shape,device=p.device)<MUTATION_RATE) 171 | p += torch.randn(p.shape,device=p.device)*MUTATION_RATE 172 | return p 173 | 174 | 175 | def crossover(p1,p2): 176 | mask = torch.rand(p1.shape,device=p1.device)>0.5 177 | return p1*mask+p2*(~mask) 178 | 179 | #%% 180 | PROMPT = "a bass guitar" 181 | # random p 182 | p = torch.rand(dummy_p.shape).to(device)*2*np.pi 183 | records = [] 184 | zt = embed_text([PROMPT,PROMPT])[:1] 185 | 186 | generation=0 187 | while True: 188 | # synthesize 189 | audio = [ 190 | effect_chain(source_audio, activation(p[i])) 191 | for i in range(BATCH_SIZE) 192 | ] 193 | # turn into tensor 194 | audio = torch.tensor(audio).float().to(device) 195 | # peak normalize each sample 196 | peaks = torch.max(torch.abs(audio),dim=1,keepdim=True)[0] 197 | audio = audio/peaks 198 | # embed audio 199 | za = embed_audio(audio) 200 | 201 | # novelty search 202 | 203 | # 204 | # get fitness by measuring similarity to target 205 | similarity = torch.nn.functional.cosine_similarity(za,zt) 206 | for b in range(BATCH_SIZE): 207 | records.append({"generation":generation,"similarity":similarity[b].item(),"p":p[b].detach().cpu().numpy()}) 208 | 209 | fitness = torch.softmax(similarity/TEMPERATURE,dim=0) 210 | 211 | p1 = p[torch.multinomial(fitness,BATCH_SIZE,replacement=True)] 212 | p2 = p[torch.multinomial(fitness,BATCH_SIZE,replacement=True)] 213 | 214 | p = crossover(p1,p2) 215 | 216 | p = mutate(p) 217 | 218 | generation+=1 219 | 220 | if generation%1==0: 221 | 222 | # clear output 223 | clear_output(wait=True) 224 | # plot sorted fitness 225 | plt.plot(torch.sort(fitness).values.detach().cpu().numpy()) 226 | plt.show() 227 | 228 | # show scatter plot of similarity 229 | sns.scatterplot(data=pd.DataFrame(records),x="generation",y="similarity",alpha=0.5) 230 | plt.show() 231 | # play audio of samples sorted by similarity 232 | play(audio[torch.argsort(-similarity)].flatten().detach().cpu().numpy()) 233 | 234 | 235 | # %% 236 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | textsynth logo 9 |

10 | Generate synthesizer sounds from text prompts with a simple evolutionary algorithm. 11 |

12 | GitHub link 13 | 14 |

15 | 16 | 17 |

18 |

19 | Start with randomly initialized synthesizer sounds. Each iteration, the current synthesizer sounds are evaluated 20 | on how well they match the text prompt. The best sounds are then combined and mutated to generate new sounds for 21 | the next iteration. 200 generations with 50 samples takes about 20 seconds on a 3090 (not tested on CPU). 22 |

23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /misc/evolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erl-j/textsynth/f6e2ef6268c0c2d43eb67b6b1db232ca4cf2a089/misc/evolution.png -------------------------------------------------------------------------------- /misc/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /textsynth.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import torch 3 | import numpy as np 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | import torch 7 | from torchsynth.synth import Voice,SynthConfig 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from IPython.display import clear_output 11 | from sklearn.cluster import KMeans 12 | from tqdm import tqdm 13 | import IPython.display as display 14 | import pandas as pd 15 | import laion_clap 16 | import soundfile as sf 17 | 18 | SAMPLE_RATE=48000 19 | 20 | # Run on the GPU if it's available 21 | if torch.cuda.is_available(): 22 | device = torch.device("cuda:3") 23 | else: 24 | device = torch.device("cpu") 25 | 26 | # disable gradient calculation 27 | torch.set_grad_enabled(False) 28 | 29 | def play(a): 30 | display.display(display.Audio(a, rate=SAMPLE_RATE)) 31 | 32 | class SynthWrapper(): 33 | def __init__(self,synth): 34 | self.synth=synth 35 | self.dummy_parameter_dict=self.synth.get_parameters() 36 | 37 | def parameterdict2tensor(self,parameters): 38 | out=[] 39 | for p in parameters.values(): 40 | out.append(p.data) 41 | return torch.stack(out,dim=-1) 42 | 43 | def tensor2parameterdict(self,tensor): 44 | parameter_dict = self.dummy_parameter_dict.copy() 45 | for i,key in enumerate(parameter_dict.keys()): 46 | parameter_dict[key].data=tensor[:,i] 47 | return parameter_dict 48 | 49 | def from_0to1(self,parameterdict): 50 | for key in parameterdict.keys(): 51 | parameterdict[key].data=parameterdict[key].from_0to1() 52 | return parameterdict 53 | 54 | def synthesize(self,tensor,MIDI_F0=None): 55 | with torch.no_grad(): 56 | parameter_dict=self.tensor2parameterdict(tensor) 57 | #parameter_dict=self.from_0to1(parameter_dict) 58 | if MIDI_F0 is not None: 59 | parameter_dict[('keyboard', 'midi_f0')].data=parameter_dict[('keyboard', 'midi_f0')].data*0.0+MIDI_F0/127.0 60 | self.synth.freeze_parameters(parameter_dict) 61 | return self.synth.output() 62 | 63 | def get_number_of_parameters(self,): 64 | return len(self.dummy_parameter_dict.keys()) 65 | 66 | def get_parameter_tensor(self,): 67 | return self.parameterdict2tensor(self.synth.get_parameters()) 68 | 69 | CLIP_DURATION=1 70 | BATCH_SIZE=50 71 | N_PARENTS=64 72 | 73 | N_SAVED_PER_TARGET=16 74 | 75 | MUTATION_RATE=0.1 76 | TEMPERATURE=0.1 77 | 78 | MIDI_F0=53 79 | 80 | config = SynthConfig(batch_size=BATCH_SIZE,sample_rate=SAMPLE_RATE,reproducible=False,buffer_size_seconds=CLIP_DURATION) 81 | synth = SynthWrapper(Voice(config).to(device)) 82 | dummy_p=synth.get_parameter_tensor() 83 | activation = lambda x: (torch.sin(x)+1.0)/2.0 #torch.nn.functional.sigmoid(x)# 84 | 85 | clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=device) 86 | clap_model.load_ckpt() # download the default pretrained checkpoint. 
87 | 88 | def embed_audio(audio_data): 89 | audio_embed = clap_model.get_audio_embedding_from_data(x = audio_data, use_tensor=True) 90 | return audio_embed 91 | 92 | def embed_text(text_data): 93 | text_embed = clap_model.get_text_embedding(text_data, use_tensor=True) 94 | return text_embed 95 | 96 | # random p 97 | p = torch.rand(dummy_p.shape).to(device)*2*np.pi 98 | 99 | def mutate(p): 100 | p += torch.randn(p.shape,device=p.device)*MUTATION_RATE 101 | return p 102 | 103 | def crossover(p1,p2): 104 | mask = torch.rand(p1.shape,device=p1.device)>0.5 105 | return p1*mask+p2*(~mask) 106 | 107 | prompts=[ 108 | "kick drum", 109 | "human scream", 110 | "808 bass", 111 | "piano", 112 | "hihat", 113 | "whistling", 114 | "rain", 115 | "bird", 116 | ] 117 | for PROMPT in prompts: 118 | # random p 119 | p = torch.rand(dummy_p.shape).to(device)*2*np.pi 120 | records = [] 121 | zt = embed_text([PROMPT,PROMPT])[:1] 122 | 123 | generation=0 124 | best_audio = [] 125 | for gens in tqdm(range(200)): 126 | # synthesize 127 | audio = synth.synthesize(activation(p),MIDI_F0) 128 | # peak normalize each sample 129 | peaks = torch.max(torch.abs(audio),dim=1,keepdim=True)[0] 130 | audio = audio/peaks 131 | # embed audio 132 | za = embed_audio(audio) 133 | 134 | # TODO: diversity preservation 135 | 136 | # get fitness by measuring similarity to target 137 | similarity = torch.nn.functional.cosine_similarity(za,zt) 138 | for b in range(BATCH_SIZE): 139 | records.append({"generation":generation,"similarity":similarity[b].item(),"p":p[b].detach().cpu().numpy()}) 140 | 141 | fitness = torch.softmax(similarity/TEMPERATURE,dim=0) 142 | 143 | p1 = p[torch.multinomial(fitness,BATCH_SIZE,replacement=True)] 144 | p2 = p[torch.multinomial(fitness,BATCH_SIZE,replacement=True)] 145 | 146 | p = crossover(p1,p2) 147 | 148 | p = mutate(p) 149 | 150 | generation+=1 151 | 152 | # save best audio 153 | best_audio.append(audio[torch.argmax(fitness)].detach().cpu().numpy()) 154 | 155 | if generation%100==0: 156 | 157 | # clear output 158 | clear_output(wait=True) 159 | # plot sorted fitness 160 | plt.plot(torch.sort(fitness).values.detach().cpu().numpy()) 161 | plt.show() 162 | 163 | # show scatter plot of similarity 164 | sns.scatterplot(data=pd.DataFrame(records),x="generation",y="similarity",alpha=0.5) 165 | plt.show() 166 | # play audio of samples sorted by similarity 167 | play(audio[torch.argsort(-similarity)].flatten().detach().cpu().numpy()) 168 | 169 | sf.write(f"./results/{PROMPT}_final.wav",audio[torch.argsort(-similarity)].flatten().detach().cpu().numpy(),SAMPLE_RATE) 170 | 171 | df = pd.DataFrame(records) 172 | 173 | # edit together all best audio with 0.25 second each 174 | audio = np.array(best_audio)[:,:int(SAMPLE_RATE*0.5)].flatten() 175 | 176 | sf.write(f"./results/{PROMPT}_evolution.wav",audio,SAMPLE_RATE) --------------------------------------------------------------------------------
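Both scripts leave diversity preservation as an open hook (`# TODO: diversity preservation` in `textsynth.py`, the empty `# novelty search` block in `genfx.py`, and the item in the README's Future work list). Purely as an illustration of one way that hook could be filled, and not something the repository implements, the fitness could blend prompt similarity with a k-nearest-neighbour novelty bonus computed on the CLAP embeddings of the current population:

```python
import torch

def fitness_with_novelty(similarity, audio_embeddings, novelty_weight=0.1, k=5):
    """Hypothetical fitness shaping: prompt similarity plus a kNN novelty bonus.

    similarity:       (B,) cosine similarity of each candidate to the text prompt
    audio_embeddings: (B, D) CLAP audio embeddings of the current population
    """
    z = torch.nn.functional.normalize(audio_embeddings, dim=-1)
    pairwise_distance = 1.0 - z @ z.T                     # cosine distance between candidates
    # Mean distance to the k nearest neighbours, skipping self (distance 0).
    knn = torch.topk(pairwise_distance, k + 1, dim=-1, largest=False).values[:, 1:]
    novelty = knn.mean(dim=-1)
    return similarity + novelty_weight * novelty
```

The blended score would simply replace `similarity` before the softmax selection step; `novelty_weight` and `k` are made-up knobs that control how strongly the population is pushed away from converging on a single sound.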