├── .gitignore
├── LICENSE
├── README.md
├── audio_prepro.py
├── coala.py
├── coala
│   ├── id2token_top_1000.json
│   └── scaler_top_1000.pkl
├── main.py
├── requirements.txt
├── utils.py
├── wavegan.py
└── word2wave.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .gitignore
2 | **venv/
3 | .vscode
4 | output*
5 | **/__pycache__/*
6 | coala/models/
7 | **.tar
8 | /wavegan/
9 | *.png
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Ilaria Manco
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Word2Wave
2 | 
3 | Word2Wave is a simple method for text-controlled GAN audio generation. You can either follow the setup instructions below and use the source code and CLI provided in this repo, or play around in the Colab notebook provided. Note that, in both cases, you will need to train a WaveGAN model first. You can also hear some examples [here](https://ilariamanco.com/word2wave/).
4 | 
5 | 
6 | Colab playground [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1c9DdSN_oiv0rcL9SH-W8-jfhcQf6iVYy?usp=sharing)
7 | 
8 | ## Setup
9 | 
10 | First, clone the repository:
11 | ```bash
12 | git clone https://www.github.com/ilaria-manco/word2wave
13 | ```
14 | 
15 | Create a virtual environment, activate it and install the requirements:
16 | ```bash
17 | cd word2wave
18 | python3 -m venv /path/to/venv/
19 | source /path/to/venv/bin/activate
20 | pip install -r requirements.txt
21 | ```
22 | 
23 | ### WaveGAN generator
24 | Word2Wave requires a pre-trained WaveGAN generator. In my experiments, I trained my own on the [Freesound Loop Dataset](https://zenodo.org/record/3967852#.YIlF931KhhE), using [this implementation](https://github.com/mostafaelaraby/wavegan-pytorch). To download the FSL dataset, run:
25 | 
26 | ```bash
27 | $ wget https://zenodo.org/record/3967852/files/FSL10K.zip?download=1
28 | ```
29 | 
30 | and then train following the instructions in the WaveGAN repo.
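31 | 
32 | Note that `wget` may save the archive under its query-string name (`FSL10K.zip?download=1`); if so, rename it before extracting, e.g. `mv 'FSL10K.zip?download=1' FSL10K.zip && unzip FSL10K.zip`.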
33 | 
34 | Once trained, place the model in the `wavegan` folder:
35 | 
36 | ```
37 | 📂wavegan
38 |  ┗ 📜gan_.tar
39 | ```
40 | 
41 | ### Pre-trained COALA encoders
42 | You'll need to download the pre-trained weights for the COALA tag and audio encoders from the official [repo](https://github.com/xavierfav/coala). Note that the repo provides weights for the model trained with different configurations (e.g. different weights in the loss components). For more details on this, you can refer to the original code and paper. To download the model weights, run the following commands (or the equivalent for the desired model configuration):
43 | 
44 | ```bash
45 | $ wget https://raw.githubusercontent.com/xavierfav/coala/master/saved_models/dual_ae_c/audio_encoder_epoch_200.pt
46 | $ wget https://raw.githubusercontent.com/xavierfav/coala/master/saved_models/dual_ae_c/tag_encoder_epoch_200.pt
47 | ```
48 | 
49 | Once downloaded, place them in the `coala/models` folder:
50 | ```
51 | 📂coala
52 |  ┗ 📂models
53 |    ┗ 📂dual_ae_c
54 |      ┣ 📜audio_encoder_epoch_200.pt
55 |      ┗ 📜tag_encoder_epoch_200.pt
56 | ```
57 | 
58 | ## How to use
59 | For text-to-audio generation using the default parameters, simply do
60 | 
61 | ```bash
62 | $ python main.py "text prompt" --wavegan_path <path_to_wavegan_model> --output_dir <path_to_output_dir>
63 | ```
64 | 
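65 | Under the hood, `main.py` freezes every parameter except the latent vector and optimises it until the COALA embedding of the generated audio is close to that of the text prompt. The snippet below is a condensed sketch of that loop (the `Args` class is a minimal stand-in for the argparse namespace, and a trained checkpoint is assumed at the default `wavegan/gan_fs_loop_32.tar`):
66 | 
67 | ```python
68 | import torch
69 | from word2wave import Word2Wave
70 | 
71 | class Args:  # minimal stand-in for main.py's argparse namespace
72 |     coala_model_name = "dual_ae_c"
73 |     wavegan_path = "wavegan/gan_fs_loop_32.tar"
74 | 
75 | word2wave = Word2Wave(Args())
76 | word2wave.cuda()  # a CUDA device is assumed, as in main.py
77 | 
78 | # only the latent vector is optimised; WaveGAN and COALA stay frozen
79 | for name, param in word2wave.named_parameters():
80 |     if name != "latents":
81 |         param.requires_grad = False
82 | 
83 | optimizer = torch.optim.Adam([word2wave.latents], lr=0.04)
84 | for step in range(1000):
85 |     audio, loss = word2wave("water")  # waveform + COALA cosine distance
86 |     optimizer.zero_grad()
87 |     loss.backward()
88 |     optimizer.step()
89 | ```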
90 | 
91 | ## Citations
92 | Some of the code in this repo is adapted from the official [COALA repo](https://github.com/xavierfav/coala) and @mostafaelaraby's [PyTorch implementation](https://github.com/mostafaelaraby/wavegan-pytorch) of the WaveGAN model.
93 | 
94 | ```bibtex
95 | @inproceedings{donahue2018adversarial,
96 |   title={Adversarial Audio Synthesis},
97 |   author={Donahue, Chris and McAuley, Julian and Puckette, Miller},
98 |   booktitle={International Conference on Learning Representations},
99 |   year={2018}
100 | }
101 | ```
102 | 
103 | ```bibtex
104 | @article{favory2020coala,
105 |   title={Coala: Co-aligned autoencoders for learning semantically enriched audio representations},
106 |   author={Favory, Xavier and Drossos, Konstantinos and Virtanen, Tuomas and Serra, Xavier},
107 |   journal={arXiv preprint arXiv:2006.08386},
108 |   year={2020}
109 | }
110 | ```
111 | 
--------------------------------------------------------------------------------
/audio_prepro.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import torch
3 | import torchaudio
4 | from torchaudio import transforms as T
5 | from matplotlib import pyplot as plt
6 | 
7 | n_fft = 1024
8 | win_length = None
9 | hop_length = 512
10 | n_mels = 96
11 | sample_rate = 16000
12 | 
13 | mel_spectrogram = T.MelSpectrogram(
14 |     sample_rate=sample_rate,
15 |     n_fft=n_fft,
16 |     win_length=win_length,
17 |     hop_length=hop_length,
18 |     center=True,
19 |     pad_mode="reflect",
20 |     power=1.0,
21 |     norm='slaney',
22 |     onesided=True,
23 |     n_mels=n_mels,
24 |     window_fn=torch.hamming_window
25 | )
26 | 
27 | def resample(source_sr, target_sr):
28 |     resample_transform = T.Resample(source_sr, target_sr)
29 |     return resample_transform
30 | 
31 | def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None):
32 |     fig, axs = plt.subplots(1, 1)
33 |     axs.set_title(title or 'Spectrogram (db)')
34 |     axs.set_ylabel(ylabel)
35 |     axs.set_xlabel('frame')
36 |     im = axs.imshow(librosa.power_to_db(spec), origin='lower', aspect=aspect)
37 |     if xmax:
38 |         axs.set_xlim((0, xmax))
39 |     fig.colorbar(im, ax=axs)
40 |     plt.show(block=False)
41 | 
42 |     plt.savefig("spec.png")
43 | 
44 | def pad(tensor, sample_rate):
45 |     z = torch.zeros(10*sample_rate, dtype=torch.float32)
46 |     z[:tensor.size(0)] = tensor
47 |     z = z + 5*1e-4*torch.rand(z.size(0))
48 |     return z
49 | 
50 | def preprocess_audio(audio="/content/test_file.wav", transform=mel_spectrogram):
51 |     if isinstance(audio, str):
52 |         audio, sr = torchaudio.load(audio)
53 |         audio = resample(sr, sample_rate)(audio)
54 |         # downmix to mono
55 |         audio = torch.mean(audio, dim=0)
56 |     else:
57 |         pass
58 |     # audio = audio[:sample_rate]
59 |     audio = pad(audio, sample_rate)
60 |     if transform is not None:
61 |         audio = transform(audio)[:96, :96]
62 |     audio = torch.log(audio + torch.finfo(torch.float32).eps)
63 |     return audio
64 | 
--------------------------------------------------------------------------------
/coala.py:
--------------------------------------------------------------------------------
1 | """ Code from the original COALA implementation: https://github.com/xavierfav/coala."""
2 | 
3 | import torch
4 | from torch import nn
5 | from torch.nn import Sequential, Linear, Dropout, ReLU, Sigmoid, Conv2d, ConvTranspose2d, BatchNorm1d, BatchNorm2d, LeakyReLU
6 | 
7 | class Flatten(nn.Module):
8 |     def forward(self, input):
9 |         return input.view(input.size(0), -1)
10 | 
11 | 
12 | class UnFlatten(nn.Module):
13 |     def forward(self, input, size=128):
14 |         return input.view(input.size(0), size, 3, 3)
15 | 
16 | 
17 | class AudioEncoder(nn.Module):
18 |     def __init__(self):
19 |         super(AudioEncoder, self).__init__()
20 | 
21 |         self.audio_encoder = Sequential(
22 |             Conv2d(1, 128, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
23 |             BatchNorm2d(128),
24 |             ReLU(),  # 128x48x48
25 |             Dropout(.25),
26 |             Conv2d(128, 128, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
27 |             BatchNorm2d(128),
28 |             ReLU(),  # 128x24x24
29 |             Dropout(.25),
30 |             Conv2d(128, 128, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
31 |             BatchNorm2d(128),
32 |             ReLU(),  # 128x12x12
33 |             Dropout(.25),
34 |             Conv2d(128, 128, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
35 |             BatchNorm2d(128),
36 |             ReLU(),  # 128x6x6
37 |             Dropout(.25),
38 |             Conv2d(128, 128, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
39 |             BatchNorm2d(128),
40 |             ReLU(),  # 128x3x3
41 |             Dropout(.25),
42 |             Flatten(),
43 |         )
44 | 
45 |         self.fc_audio = Sequential(
46 |             Linear(1152, 1152, bias=False),
47 |             Dropout(0.25),
48 |         )
49 | 
50 |     def forward(self, x):
51 |         z = self.audio_encoder(x)
52 |         z_d = self.fc_audio(z)
53 |         return z, z_d
54 | 
55 | 
56 | class TagEncoder(nn.Module):
57 |     def __init__(self):
58 |         super(TagEncoder, self).__init__()
59 | 
60 |         self.tag_encoder = Sequential(
61 |             Linear(1000, 512),
62 |             BatchNorm1d(512),
63 |             ReLU(),
64 |             Dropout(.25),
65 |             Linear(512, 512),
66 |             BatchNorm1d(512),
67 |             ReLU(),
68 |             Dropout(.25),
69 |             Linear(512, 1152),
70 |             BatchNorm1d(1152),
71 |             ReLU(),
72 |             Dropout(.25),
73 |         )
74 | 
75 |         self.fc_tag = Sequential(
76 |             Linear(1152, 1152, bias=False),
77 |             Dropout(.25),
78 |         )
79 | 
80 |     def forward(self, tags):
81 |         z = self.tag_encoder(tags)
82 |         z_d = self.fc_tag(z)
83 |         return z, z_d
84 | 
--------------------------------------------------------------------------------
/coala/id2token_top_1000.json:
--------------------------------------------------------------------------------
1 | {"0": "close", "1": "closing", "2": "drawer", "3": "foley", "4": "kitchen", "5": "open", "6": "opening", "7": "bell", "8": "chime", "9": "cymbal", "10": "finger", "11": "hand", "12": "percussion", "13": "crunch", "14": "movement", "15": "rustle", "16": "spooky", "17": "magazine", "18": "page", "19": "paper", 
"20": "turn", "21": "hit", "22": "orchestral", "23": "shaker", "24": "bird", "25": "drum", "26": "kick", "27": "sample", "28": "snare", "29": "bassdrum", "30": "multisample", "31": "click", "32": "computer", "33": "game", "34": "interface", "35": "mouse", "36": "office", "37": "sound", "38": "acoustic", "39": "analog", "40": "fantasy", "41": "feedback", "42": "film", "43": "microphone", "44": "movie", "45": "retro", "46": "signal", "47": "soundesign", "48": "soundtrack", "49": "synth", "50": "synthesizer", "51": "vintage", "52": "alien", "53": "death", "54": "insect", "55": "pain", "56": "squelch", "57": "wet", "58": "bas", "59": "beat", "60": "dance", "61": "dark", "62": "electric", "63": "electro", "64": "loop", "65": "button", "66": "pres", "67": "switch", "68": "tick", "69": "ambiance", "70": "ambience", "71": "ambient", "72": "atmosphere", "73": "cinematic", "74": "looping", "75": "piano", "76": "funk", "77": "groove", "78": "animal", "79": "cry", "80": "digital", "81": "gameboy", "82": "glitch", "83": "sfx", "84": "synthesized", "85": "broken", "86": "drone", "87": "futuristic", "88": "noise", "89": "gun", "90": "military", "91": "rifle", "92": "shot", "93": "bang", "94": "cartoon", "95": "fx", "96": "clap", "97": "clapping", "98": "female", "99": "sexy", "100": "woman", "101": "creepy", "102": "drama", "103": "fear", "104": "ghost", "105": "haunted", "106": "horror", "107": "scary", "108": "suspense", "109": "terrifying", "110": "terror", "111": "thrill", "112": "clip", "113": "edit", "114": "transition", "115": "monster", "116": "reverb", "117": "voice", "118": "bark", "119": "dog", "120": "human", "121": "belch", "122": "burp", "123": "audio", "124": "breath", "125": "groan", "126": "male", "127": "moan", "128": "vocal", "129": "vocalization", "130": "zombie", "131": "machine", "132": "videogame", "133": "fiction", "134": "science", "135": "shortwave", "136": "square", "137": "wave", "138": "gabber", "139": "hardcore", "140": "kickdrum", "141": "techno", "142": "body", "143": "bubble", "144": "liquid", "145": "bpm", "146": "hip", "147": "hiphop", "148": "hop", "149": "fart", "150": "car", "151": "metal", "152": "resonance", "153": "rim", "154": "ring", "155": "wheel", "156": "action", "157": "attack", "158": "classic", "159": "fire", "160": "firing", "161": "galaxy", "162": "laser", "163": "shoot", "164": "shooting", "165": "space", "166": "spaceship", "167": "strike", "168": "video", "169": "weapon", "170": "cat", "171": "meow", "172": "pet", "173": "halloween", "174": "old", "175": "silly", "176": "groovy", "177": "perc", "178": "rhythm", "179": "man", "180": "mouth", "181": "guitar", "182": "jump", "183": "hinge", "184": "squeaky", "185": "short", "186": "arcade", "187": "abstract", "188": "crazy", "189": "electronic", "190": "freaky", "191": "random", "192": "strange", "193": "weird", "194": "battle", "195": "blade", "196": "clang", "197": "clank", "198": "combat", "199": "fight", "200": "medieval", "201": "pipe", "202": "sliding", "203": "sword", "204": "high", "205": "quality", "206": "beam", "207": "knock", "208": "thud", "209": "flick", "210": "low", "211": "phaser", "212": "buzz", "213": "ground", "214": "hum", "215": "effect", "216": "explode", "217": "impact", "218": "thump", "219": "echo", "220": "drop", "221": "fall", "222": "soft", "223": "texture", "224": "bag", "225": "plastic", "226": "sinister", "227": "knife", "228": "engine", "229": "motor", "230": "english", "231": "speak", "232": "girl", "233": "spoken", "234": "scream", "235": "screaming", "236": "filter", 
"237": "industrial", "238": "mic", "239": "slide", "240": "ufo", "241": "ethnic", "242": "harp", "243": "melodic", "244": "melody", "245": "riff", "246": "string", "247": "drip", "248": "splash", "249": "water", "250": "indoor", "251": "box", "252": "metallic", "253": "small", "254": "object", "255": "gunshot", "256": "metalic", "257": "spring", "258": "beep", "259": "bloop", "260": "gamesound", "261": "plop", "262": "hat", "263": "door", "264": "tech", "265": "unlock", "266": "gate", "267": "oneshot", "268": "shut", "269": "slam", "270": "chiptune", "271": "cool", "272": "school", "273": "tune", "274": "baby", "275": "child", "276": "design", "277": "synthetic", "278": "sub", "279": "distortion", "280": "big", "281": "deep", "282": "heavy", "283": "delay", "284": "fast", "285": "owi", "286": "scrape", "287": "slow", "288": "steel", "289": "burst", "290": "hard", "291": "ringing", "292": "ting", "293": "scratching", "294": "scraping", "295": "power", "296": "bottle", "297": "container", "298": "gold", "299": "footstep", "300": "stair", "301": "swish", "302": "army", "303": "creature", "304": "long", "305": "bassline", "306": "funky", "307": "soundeffect", "308": "speech", "309": "cough", "310": "future", "311": "amsterdam", "312": "time", "313": "coin", "314": "foot", "315": "rock", "316": "running", "317": "shaking", "318": "step", "319": "stone", "320": "walking", "321": "creak", "322": "squeak", "323": "wood", "324": "wooden", "325": "stop", "326": "start", "327": "medium", "328": "crowd", "329": "person", "330": "room", "331": "cut", "332": "rip", "333": "tear", "334": "evil", "335": "laugh", "336": "laughing", "337": "robot", "338": "robotic", "339": "break", "340": "breaking", "341": "glas", "342": "percussive", "343": "chain", "344": "iron", "345": "rattle", "346": "ding", "347": "key", "348": "stretch", "349": "cup", "350": "pouring", "351": "ignition", "352": "lighter", "353": "pan", "354": "growl", "355": "house", "356": "nature", "357": "rain", "358": "storm", "359": "thunder", "360": "alert", "361": "cell", "362": "notification", "363": "phone", "364": "ringtone", "365": "tone", "366": "android", "367": "automation", "368": "cyborg", "369": "droid", "370": "mechanical", "371": "artificial", "372": "electromechanic", "373": "machinery", "374": "breakbeat", "375": "beatbox", "376": "magic", "377": "spell", "378": "electricity", "379": "processed", "380": "roar", "381": "acid", "382": "chord", "383": "club", "384": "dub", "385": "rave", "386": "trance", "387": "station", "388": "forest", "389": "item", "390": "pitch", "391": "rubber", "392": "zoom", "393": "wav", "394": "his", "395": "tape", "396": "beast", "397": "experimental", "398": "lofi", "399": "blast", "400": "bomb", "401": "boom", "402": "explosion", "403": "explosive", "404": "firework", "405": "pistol", "406": "soldier", "407": "train", "408": "woosh", "409": "scifi", "410": "chair", "411": "floor", "412": "moving", "413": "tile", "414": "beer", "415": "food", "416": "garbage", "417": "trash", "418": "air", "419": "breathing", "420": "dripping", "421": "faucet", "422": "home", "423": "name", "424": "talk", "425": "word", "426": "sink", "427": "pop", "428": "table", "429": "creaky", "430": "handle", "431": "brush", "432": "falling", "433": "dish", "434": "plate", "435": "ceramic", "436": "fork", "437": "clean", "438": "household", "439": "rub", "440": "chip", "441": "dirty", "442": "sweep", "443": "cleaner", "444": "window", "445": "empty", "446": "spray", "447": "swoosh", "448": "pot", "449": "creaking", "450": 
"squeaking", "451": "music", "452": "drink", "453": "soda", "454": "shake", "455": "soundfx", "456": "street", "457": "bd", "458": "ride", "459": "hh", "460": "tom", "461": "roll", "462": "chinese", "463": "gong", "464": "violin", "465": "pad", "466": "soundscape", "467": "keyboard", "468": "flute", "469": "snap", "470": "fill", "471": "pour", "472": "loud", "473": "scratch", "474": "spanish", "475": "dj", "476": "layer", "477": "swosh", "478": "whoosh", "479": "clicking", "480": "whistle", "481": "wolf", "482": "swing", "483": "group", "484": "speaking", "485": "talking", "486": "war", "487": "reload", "488": "simple", "489": "ui", "490": "edm", "491": "background", "492": "sounddesign", "493": "quantized", "494": "recording", "495": "popping", "496": "blood", "497": "gore", "498": "squish", "499": "board", "500": "menu", "501": "punch", "502": "error", "503": "factory", "504": "line", "505": "shatter", "506": "wobble", "507": "psy", "508": "desktop", "509": "startup", "510": "typing", "511": "move", "512": "ping", "513": "bathroom", "514": "passage", "515": "mono", "516": "jazz", "517": "phrase", "518": "single", "519": "stereo", "520": "tambourine", "521": "eerie", "522": "vehicle", "523": "tribal", "524": "shoe", "525": "cold", "526": "freeze", "527": "intro", "528": "grunt", "529": "applause", "530": "snapping", "531": "book", "532": "pencil", "533": "eating", "534": "hardstyle", "535": "screech", "536": "aman", "537": "pitched", "538": "dropping", "539": "cardboard", "540": "shower", "541": "ice", "542": "cute", "543": "happy", "544": "kid", "545": "young", "546": "crackle", "547": "noisy", "548": "record", "549": "turntable", "550": "vinyl", "551": "dry", "552": "funny", "553": "siren", "554": "flame", "555": "aluminum", "556": "crush", "557": "camera", "558": "drumloop", "559": "bras", "560": "bullet", "561": "electronica", "562": "bounce", "563": "closed", "564": "lock", "565": "acapella", "566": "fly", "567": "singing", "568": "vox", "569": "howl", "570": "kalimba", "571": "xylophone", "572": "blip", "573": "dropped", "574": "vacuum", "575": "zip", "576": "undead", "577": "golpe", "578": "live", "579": "rythm", "580": "note", "581": "vocoder", "582": "plucked", "583": "city", "584": "craft", "585": "cloth", "586": "analogue", "587": "raw", "588": "vibrato", "589": "hihat", "590": "crash", "591": "disgusting", "592": "gros", "593": "stick", "594": "natural", "595": "giggle", "596": "tension", "597": "edited", "598": "layered", "599": "recorded", "600": "friction", "601": "grind", "602": "shuffle", "603": "transformation", "604": "select", "605": "announcement", "606": "call", "607": "bike", "608": "drive", "609": "piezo", "610": "smash", "611": "wind", "612": "drumkit", "613": "ga", "614": "drag", "615": "dubstep", "616": "atmospheric", "617": "dnb", "618": "die", "619": "warning", "620": "slap", "621": "work", "622": "laughter", "623": "leaf", "624": "distorted", "625": "audience", "626": "bar", "627": "cheer", "628": "improvised", "629": "request", "630": "radio", "631": "static", "632": "loopable", "633": "chirp", "634": "whisper", "635": "dramatic", "636": "rise", "637": "tap", "638": "impulse", "639": "fighting", "640": "shout", "641": "command", "642": "test", "643": "mallet", "644": "pedal", "645": "push", "646": "sustain", "647": "flip", "648": "rap", "649": "bowl", "650": "clink", "651": "money", "652": "real", "653": "realistic", "654": "currency", "655": "rhythmic", "656": "korg", "657": "snarl", "658": "angry", "659": "bone", "660": "flesh", "661": "gaming", "662": 
"bit", "663": "fm", "664": "compressed", "665": "shutter", "666": "whip", "667": "light", "668": "mobile", "669": "telephone", "670": "sharp", "671": "lid", "672": "farm", "673": "gravel", "674": "outdoor", "675": "walk", "676": "one", "677": "construction", "678": "hardware", "679": "tool", "680": "toy", "681": "balloon", "682": "harmonica", "683": "instrument", "684": "horn", "685": "studio", "686": "large", "687": "yamaha", "688": "tin", "689": "extreme", "690": "application", "691": "bright", "692": "gui", "693": "rumble", "694": "crack", "695": "remix", "696": "ship", "697": "organic", "698": "animation", "699": "blowing", "700": "american", "701": "hi", "702": "jungle", "703": "pressure", "704": "minimal", "705": "fun", "706": "epic", "707": "ball", "708": "jar", "709": "crunchy", "710": "spoon", "711": "sine", "712": "pc", "713": "next", "714": "text", "715": "white", "716": "alarm", "717": "clock", "718": "pull", "719": "hall", "720": "desk", "721": "detuned", "722": "blow", "723": "match", "724": "oscillator", "725": "lfo", "726": "synthesi", "727": "waveform", "728": "clatter", "729": "modular", "730": "bu", "731": "field", "732": "tube", "733": "hammer", "734": "wine", "735": "run", "736": "velocity", "737": "hurt", "738": "tabla", "739": "zap", "740": "fist", "741": "knocking", "742": "moog", "743": "cracking", "744": "zipper", "745": "elevator", "746": "pong", "747": "atari", "748": "boing", "749": "saw", "750": "fi", "751": "sci", "752": "drill", "753": "punchy", "754": "tight", "755": "ripping", "756": "harsh", "757": "concrete", "758": "jingle", "759": "smack", "760": "snow", "761": "djembe", "762": "contrabas", "763": "organ", "764": "splat", "765": "generator", "766": "tree", "767": "coffee", "768": "sport", "769": "new", "770": "bongo", "771": "conga", "772": "barking", "773": "yell", "774": "metro", "775": "bleep", "776": "indiedev", "777": "audacity", "778": "gameaudio", "779": "gamedev", "780": "granular", "781": "datum", "782": "stab", "783": "portuguese", "784": "number", "785": "block", "786": "annoying", "787": "pickup", "788": "boy", "789": "double", "790": "bath", "791": "letter", "792": "ableton", "793": "lead", "794": "convolution", "795": "ir", "796": "response", "797": "wah", "798": "tv", "799": "formant", "800": "tearing", "801": "dragon", "802": "rolling", "803": "free", "804": "triangle", "805": "generated", "806": "japanese", "807": "trumpet", "808": "spark", "809": "set", "810": "weather", "811": "rough", "812": "training", "813": "toilet", "814": "destruction", "815": "base", "816": "rpg", "817": "cd", "818": "pulse", "819": "bend", "820": "bending", "821": "bent", "822": "circuit", "823": "comedy", "824": "drain", "825": "morph", "826": "stinger", "827": "transformer", "828": "casio", "829": "sampled", "830": "bos", "831": "indie", "832": "sampling", "833": "orchestra", "834": "fuzzy", "835": "metronome", "836": "slash", "837": "sampler", "838": "reaktor", "839": "midi", "840": "chippy", "841": "breakcore", "842": "idm", "843": "reverse", "844": "vibrate", "845": "mix", "846": "slouse", "847": "rubbish", "848": "roland", "849": "event", "850": "staccato", "851": "woodwind", "852": "instrumental", "853": "clash", "854": "pack", "855": "draw", "856": "pen", "857": "underground", "858": "alphabet", "859": "sequence", "860": "god", "861": "cutlery", "862": "vst", "863": "kit", "864": "resonant", "865": "folk", "866": "african", "867": "psytrance", "868": "dialogue", "869": "sigh", "870": "tracker", "871": "root", "872": "sabian", "873": "filtered", "874": 
"synthetizer", "875": "effected", "876": "harmony", "877": "bamboo", "878": "puzzle", "879": "interaction", "880": "classical", "881": "nylon", "882": "collection", "883": "contact", "884": "drumset", "885": "cz", "886": "czech", "887": "pansk", "888": "panska", "889": "prepared", "890": "cowbell", "891": "custom", "892": "unprocessed", "893": "sax", "894": "saxophone", "895": "sonido", "896": "humor", "897": "pro", "898": "pluck", "899": "wire", "900": "fat", "901": "struck", "902": "china", "903": "harmonic", "904": "nintendo", "905": "wub", "906": "zildjian", "907": "punk", "908": "howling", "909": "ddrm", "910": "deckardsdream", "911": "odd", "912": "repetition", "913": "public", "914": "subtractive", "915": "flatulation", "916": "flatulence", "917": "tama", "918": "valve", "919": "syllable", "920": "compmusic", "921": "poot", "922": "experiment", "923": "sandyrb", "924": "freesound", "925": "viru", "926": "image", "927": "malfunction", "928": "skrillex", "929": "mumble", "930": "ludwig", "931": "pizzicato", "932": "accent", "933": "cello", "934": "ti", "935": "carnatic", "936": "technica", "937": "barcelona", "938": "elsodelesculture", "939": "reggae", "940": "gamelan", "941": "goa", "942": "idiophone", "943": "qmul", "944": "cityring", "945": "portugal", "946": "spankmyfilth", "947": "shure", "948": "mezzoforte", "949": "tenuto", "950": "mridangam", "951": "musical", "952": "bassoon", "953": "idradio", "954": "sonsdebarcelona", "955": "rasta", "956": "snr", "957": "kingkorg", "958": "gamedeveloping", "959": "indiegamedev", "960": "digitopia", "961": "porto", "962": "sica", "963": "netherland", "964": "noizdumpster", "965": "thelowerrhythm", "966": "tlr", "967": "germany", "968": "iitm", "969": "multi", "970": "delphidebrain", "971": "crossword", "972": "gridplay", "973": "solving", "974": "electribe", "975": "indiegame", "976": "extended", "977": "technique", "978": "daxophone", "979": "phoneme", "980": "nord", "981": "rekombinacje", "982": "chordophone", "983": "viola", "984": "aerophone", "985": "texttospeech", "986": "ppg", "987": "mojomill", "988": "beyer", "989": "electrovoice", "990": "josephson", "991": "suicidity", "992": "tony", "993": "lesson", "994": "auditory", "995": "metalwork", "996": "spiccato", "997": "oldtrombone", "998": "davood", "999": "bunt"} -------------------------------------------------------------------------------- /coala/scaler_top_1000.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilaria-manco/word2wave/d375f1a3134c96fe386932931b0e9ce9bd6240d5/coala/scaler_top_1000.pkl -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import librosa 4 | import os 5 | import json 6 | import logging 7 | import numpy as np 8 | import torchaudio 9 | 10 | from word2wave import Word2Wave 11 | 12 | logging.basicConfig(level = logging.INFO) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | def play_my_words(text_prompt, args): 17 | word2wave = Word2Wave(args) 18 | word2wave.cuda() 19 | 20 | for name, param in word2wave.named_parameters(): 21 | # if name != "latents" and "generator" not in name: 22 | if name != "latents": 23 | param.requires_grad = False 24 | 25 | optimizer = torch.optim.Adam( 26 | params=[word2wave.latents], 27 | lr=args.lr, 28 | betas=(0.9, 0.999) 29 | ) 30 | 31 | i = 0 32 | 33 | _, words_in_dict, 
words_not_in_dict = word2wave.tokenize_text(text_prompt)
34 |     if not words_in_dict:
35 |         raise Exception("All the words in the text prompt are out-of-vocabulary, please try with another prompt")
36 |     elif words_not_in_dict:
37 |         missing_words = ", ".join(words_not_in_dict)
38 |         logging.info("Out-of-vocabulary words found, ignoring: \"{}\"".format(missing_words))
39 |     logging.info("Making sounds to match the following text: {}".format(" ".join(words_in_dict)))
40 | 
41 |     while i < args.steps:
42 |         audio, loss = word2wave(text_prompt)
43 | 
44 |         optimizer.zero_grad()
45 |         loss.backward()
46 |         optimizer.step()
47 | 
48 |         if i % 100 == 0:
49 |             print(f'Step {i}', f'|| Loss: {loss.data.cpu().numpy()[0]}')
50 |             # print(word2wave.latents)
51 | 
52 |         if loss <= args.threshold:
53 |             break
54 | 
55 |         i += 1
56 | 
57 |     audio_to_save = np.array(audio.detach().cpu().numpy())
58 |     librosa.output.write_wav(os.path.join(args.output_dir, text_prompt + ".wav"), audio_to_save, args.sample_rate)
59 | 
60 |     if loss > args.threshold:
61 |         logging.info("The optimisation failed to generate audio that is sufficiently similar to the given prompt. You may wish to try again.")
62 | 
63 | 
64 | if __name__ == "__main__":
65 |     parser = argparse.ArgumentParser()
66 |     parser.add_argument("text_prompt", type=str, nargs="?", default="water", help="text prompt to guide the audio generation")
67 |     parser.add_argument("--lr", type=float, default=0.04, help="learning rate")
68 |     parser.add_argument("--steps", type=int, default=10000, help="number of optimization steps")
69 |     parser.add_argument("--coala_model_name", type=str, default="dual_e_c", help="coala model name (can be one of [dual_e_c, dual_ae_c])")
70 |     parser.add_argument("--wavegan_path", type=str, default="wavegan/gan_fs_loop_32.tar", help="path to the pretrained wavegan model")
71 |     parser.add_argument("--threshold", type=float, default=0.15, help="threshold below which optimisation stops")
72 |     parser.add_argument("--batch", action="store_true", help="whether to run a batch of experiments with all tags")
73 |     parser.add_argument("--output_dir", type=str, default="output_new", help="path to store results")
74 |     parser.add_argument("--sample_rate", type=int, default=16000)
75 | 
76 | 
77 |     args = parser.parse_args()
78 | 
79 |     if not os.path.exists(args.output_dir):
80 |         os.mkdir(args.output_dir)
81 | 
82 |     if args.batch:
83 |         id2tag = json.load(open('coala/id2token_top_1000.json', 'r'))
84 |         for id, tag in id2tag.items():
85 |             play_my_words(tag, args)
86 | 
87 |     else:
88 |         play_my_words(args.text_prompt, args)
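89 | 
90 | # Example invocation (assumes a trained WaveGAN checkpoint at the default
91 | # --wavegan_path and the COALA weights under coala/models/):
92 | #   python main.py "water drip" --steps 2000 --output_dir output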
4 | """ 5 | def __call__(self, tensor): 6 | for ch in tensor: 7 | scale = 1.0 / (ch.max(dim=0)[0] - ch.min(dim=0)[0]) 8 | ch.mul_(scale).sub_(ch.min(dim=0)[0]) 9 | return tensor 10 | 11 | def sample_noise(size, latent_dim=100): 12 | noise = torch.FloatTensor(size, latent_dim) 13 | noise.data.normal_() 14 | return noise 15 | 16 | def latent_space_interpolation(generator, n_samples=10, source=None, target=None): 17 | if source is None and target is None: 18 | random_samples = sample_noise(2, 100) 19 | source = random_samples[0] 20 | target = random_samples[1] 21 | with torch.no_grad(): 22 | interpolated_z = [] 23 | for alpha in np.linspace(0, 1, n_samples): 24 | interpolation = alpha * source + ((1 - alpha) * target) 25 | interpolated_z.append(interpolation) 26 | 27 | interpolated_z = torch.stack(interpolated_z) 28 | generated_audio = generator(interpolated_z) 29 | return generated_audio 30 | -------------------------------------------------------------------------------- /wavegan.py: -------------------------------------------------------------------------------- 1 | """Code adapted from https://github.com/mostafaelaraby/wavegan-pytorch""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import Parameter 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | 9 | 10 | class Transpose1dLayer(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride, 17 | padding=11, 18 | upsample=None, 19 | output_padding=1, 20 | use_batch_norm=False, 21 | ): 22 | super(Transpose1dLayer, self).__init__() 23 | self.upsample = upsample 24 | reflection_pad = nn.ConstantPad1d(kernel_size // 2, value=0) 25 | conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, stride) 26 | conv1d.weight.data.normal_(0.0, 0.02) 27 | Conv1dTrans = nn.ConvTranspose1d( 28 | in_channels, out_channels, kernel_size, stride, padding, output_padding 29 | ) 30 | batch_norm = nn.BatchNorm1d(out_channels) 31 | if self.upsample: 32 | operation_list = [reflection_pad, conv1d] 33 | else: 34 | operation_list = [Conv1dTrans] 35 | 36 | if use_batch_norm: 37 | operation_list.append(batch_norm) 38 | self.transpose_ops = nn.Sequential(*operation_list) 39 | 40 | def forward(self, x): 41 | if self.upsample: 42 | x = nn.functional.interpolate(x, scale_factor=self.upsample, mode="nearest") 43 | return self.transpose_ops(x) 44 | 45 | 46 | class Conv1D(nn.Module): 47 | def __init__( 48 | self, 49 | input_channels, 50 | output_channels, 51 | kernel_size, 52 | alpha=0.2, 53 | shift_factor=2, 54 | stride=4, 55 | padding=11, 56 | use_batch_norm=False, 57 | drop_prob=0, 58 | ): 59 | super(Conv1D, self).__init__() 60 | self.conv1d = nn.Conv1d( 61 | input_channels, output_channels, kernel_size, stride=stride, padding=padding 62 | ) 63 | self.batch_norm = nn.BatchNorm1d(output_channels) 64 | self.phase_shuffle = PhaseShuffle(shift_factor) 65 | self.alpha = alpha 66 | self.use_batch_norm = use_batch_norm 67 | self.use_phase_shuffle = shift_factor == 0 68 | self.use_drop = drop_prob > 0 69 | self.dropout = nn.Dropout2d(drop_prob) 70 | 71 | def forward(self, x): 72 | x = self.conv1d(x) 73 | if self.use_batch_norm: 74 | x = self.batch_norm(x) 75 | x = F.leaky_relu(x, negative_slope=self.alpha) 76 | if self.use_phase_shuffle: 77 | x = self.phase_shuffle(x) 78 | if self.use_drop: 79 | x = self.dropout(x) 80 | return x 81 | 82 | 83 | class PhaseShuffle(nn.Module): 84 | """ 85 | Performs phase shuffling, i.e. 
shifting feature axis of a 3D tensor 86 | by a random integer in {-n, n} and performing reflection padding where 87 | necessary. 88 | """ 89 | 90 | def __init__(self, shift_factor): 91 | super(PhaseShuffle, self).__init__() 92 | self.shift_factor = shift_factor 93 | 94 | def forward(self, x): 95 | if self.shift_factor == 0: 96 | return x 97 | # uniform in (L, R) 98 | k_list = ( 99 | torch.Tensor(x.shape[0]).random_(0, 2 * self.shift_factor + 1) 100 | - self.shift_factor 101 | ) 102 | k_list = k_list.numpy().astype(int) 103 | 104 | # Combine sample indices into lists so that less shuffle operations 105 | # need to be performed 106 | k_map = {} 107 | for idx, k in enumerate(k_list): 108 | k = int(k) 109 | if k not in k_map: 110 | k_map[k] = [] 111 | k_map[k].append(idx) 112 | 113 | # Make a copy of x for our output 114 | x_shuffle = x.clone() 115 | 116 | # Apply shuffle to each sample 117 | for k, idxs in k_map.items(): 118 | if k > 0: 119 | x_shuffle[idxs] = F.pad(x[idxs][..., :-k], (k, 0), mode="reflect") 120 | else: 121 | x_shuffle[idxs] = F.pad(x[idxs][..., -k:], (0, -k), mode="reflect") 122 | 123 | assert x_shuffle.shape == x.shape, "{}, {}".format(x_shuffle.shape, x.shape) 124 | return x_shuffle 125 | 126 | 127 | class WaveGANGenerator(nn.Module): 128 | def __init__( 129 | self, 130 | model_size=64, 131 | ngpus=1, 132 | num_channels=1, 133 | verbose=False, 134 | upsample=True, 135 | slice_len=16384, 136 | use_batch_norm=False, 137 | ): 138 | super(WaveGANGenerator, self).__init__() 139 | assert slice_len in [16384, 32768, 65536] # used to predict longer utterances 140 | 141 | self.ngpus = ngpus 142 | self.model_size = model_size # d 143 | self.num_channels = num_channels # c 144 | latent_dim = 100 145 | self.verbose = verbose 146 | self.use_batch_norm = use_batch_norm 147 | 148 | self.dim_mul = 16 if slice_len == 16384 else 32 149 | 150 | self.fc1 = nn.Linear(latent_dim, 4 * 4 * model_size * self.dim_mul) 151 | self.bn1 = nn.BatchNorm1d(num_features=model_size * self.dim_mul) 152 | 153 | stride = 4 154 | if upsample: 155 | stride = 1 156 | upsample = 4 157 | 158 | deconv_layers = [ 159 | Transpose1dLayer( 160 | self.dim_mul * model_size, 161 | (self.dim_mul * model_size) // 2, 162 | 25, 163 | stride, 164 | upsample=upsample, 165 | use_batch_norm=use_batch_norm, 166 | ), 167 | Transpose1dLayer( 168 | (self.dim_mul * model_size) // 2, 169 | (self.dim_mul * model_size) // 4, 170 | 25, 171 | stride, 172 | upsample=upsample, 173 | use_batch_norm=use_batch_norm, 174 | ), 175 | Transpose1dLayer( 176 | (self.dim_mul * model_size) // 4, 177 | (self.dim_mul * model_size) // 8, 178 | 25, 179 | stride, 180 | upsample=upsample, 181 | use_batch_norm=use_batch_norm, 182 | ), 183 | Transpose1dLayer( 184 | (self.dim_mul * model_size) // 8, 185 | (self.dim_mul * model_size) // 16, 186 | 25, 187 | stride, 188 | upsample=upsample, 189 | use_batch_norm=use_batch_norm, 190 | ), 191 | ] 192 | 193 | if slice_len == 16384: 194 | deconv_layers.append( 195 | Transpose1dLayer( 196 | (self.dim_mul * model_size) // 16, 197 | num_channels, 198 | 25, 199 | stride, 200 | upsample=upsample, 201 | ) 202 | ) 203 | elif slice_len == 32768: 204 | deconv_layers += [ 205 | Transpose1dLayer( 206 | (self.dim_mul * model_size) // 16, 207 | model_size, 208 | 25, 209 | stride, 210 | upsample=upsample, 211 | use_batch_norm=use_batch_norm, 212 | ), 213 | Transpose1dLayer(model_size, num_channels, 25, 2, upsample=upsample), 214 | ] 215 | elif slice_len == 65536: 216 | deconv_layers += [ 217 | Transpose1dLayer( 218 | (self.dim_mul * 
model_size) // 16, 219 | model_size, 220 | 25, 221 | stride, 222 | upsample=upsample, 223 | use_batch_norm=use_batch_norm, 224 | ), 225 | Transpose1dLayer( 226 | model_size, num_channels, 25, stride, upsample=upsample 227 | ), 228 | ] 229 | else: 230 | raise ValueError("slice_len {} value is not supported".format(slice_len)) 231 | 232 | self.deconv_list = nn.ModuleList(deconv_layers) 233 | for m in self.modules(): 234 | if isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear): 235 | nn.init.kaiming_normal_(m.weight.data) 236 | 237 | def forward(self, x): 238 | x = self.fc1(x).view(-1, self.dim_mul * self.model_size, 16) 239 | if self.use_batch_norm: 240 | x = self.bn1(x) 241 | x = F.relu(x) 242 | if self.verbose: 243 | print(x.shape) 244 | 245 | for deconv in self.deconv_list[:-1]: 246 | x = F.relu(deconv(x)) 247 | if self.verbose: 248 | print(x.shape) 249 | output = torch.tanh(self.deconv_list[-1](x)) 250 | return output 251 | 252 | 253 | class WaveGANDiscriminator(nn.Module): 254 | def __init__( 255 | self, 256 | model_size=64, 257 | ngpus=1, 258 | num_channels=1, 259 | shift_factor=2, 260 | alpha=0.2, 261 | verbose=False, 262 | slice_len=16384, 263 | use_batch_norm=False, 264 | ): 265 | super(WaveGANDiscriminator, self).__init__() 266 | assert slice_len in [16384, 32768, 65536] # used to predict longer utterances 267 | 268 | self.model_size = model_size # d 269 | self.ngpus = ngpus 270 | self.use_batch_norm = use_batch_norm 271 | self.num_channels = num_channels # c 272 | self.shift_factor = shift_factor # n 273 | self.alpha = alpha 274 | self.verbose = verbose 275 | 276 | conv_layers = [ 277 | Conv1D( 278 | num_channels, 279 | model_size, 280 | 25, 281 | stride=4, 282 | padding=11, 283 | use_batch_norm=use_batch_norm, 284 | alpha=alpha, 285 | shift_factor=shift_factor, 286 | ), 287 | Conv1D( 288 | model_size, 289 | 2 * model_size, 290 | 25, 291 | stride=4, 292 | padding=11, 293 | use_batch_norm=use_batch_norm, 294 | alpha=alpha, 295 | shift_factor=shift_factor, 296 | ), 297 | Conv1D( 298 | 2 * model_size, 299 | 4 * model_size, 300 | 25, 301 | stride=4, 302 | padding=11, 303 | use_batch_norm=use_batch_norm, 304 | alpha=alpha, 305 | shift_factor=shift_factor, 306 | ), 307 | Conv1D( 308 | 4 * model_size, 309 | 8 * model_size, 310 | 25, 311 | stride=4, 312 | padding=11, 313 | use_batch_norm=use_batch_norm, 314 | alpha=alpha, 315 | shift_factor=shift_factor, 316 | ), 317 | Conv1D( 318 | 8 * model_size, 319 | 16 * model_size, 320 | 25, 321 | stride=4, 322 | padding=11, 323 | use_batch_norm=use_batch_norm, 324 | alpha=alpha, 325 | shift_factor=0 if slice_len == 16384 else shift_factor, 326 | ), 327 | ] 328 | self.fc_input_size = 256 * model_size 329 | if slice_len == 32768: 330 | conv_layers.append( 331 | Conv1D( 332 | 16 * model_size, 333 | 32 * model_size, 334 | 25, 335 | stride=2, 336 | padding=11, 337 | use_batch_norm=use_batch_norm, 338 | alpha=alpha, 339 | shift_factor=0, 340 | ) 341 | ) 342 | self.fc_input_size = 480 * model_size 343 | elif slice_len == 65536: 344 | conv_layers.append( 345 | Conv1D( 346 | 16 * model_size, 347 | 32 * model_size, 348 | 25, 349 | stride=4, 350 | padding=11, 351 | use_batch_norm=use_batch_norm, 352 | alpha=alpha, 353 | shift_factor=0, 354 | ) 355 | ) 356 | self.fc_input_size = 512 * model_size 357 | 358 | self.conv_layers = nn.ModuleList(conv_layers) 359 | 360 | self.fc1 = nn.Linear(self.fc_input_size, 1) 361 | 362 | for m in self.modules(): 363 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear): 364 | nn.init.kaiming_normal_(m.weight.data) 
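365 | 
366 |     # With the default slice_len of 16384, the five stride-4 convolutions above
367 |     # reduce the input to 16 frames of 16 * model_size channels, so the flattened
368 |     # vector fed to fc1 has 16 * 16 * model_size = 256 * model_size features.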
369 | 
370 |     def forward(self, x):
371 |         for conv in self.conv_layers:
372 |             x = conv(x)
373 |             if self.verbose:
374 |                 print(x.shape)
375 |         x = x.view(-1, self.fc_input_size)
376 |         if self.verbose:
377 |             print(x.shape)
378 | 
379 |         return self.fc1(x)
380 | 
--------------------------------------------------------------------------------
/word2wave.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import logging
5 | from urllib.request import urlretrieve
6 | 
7 | import torch
8 | from torch import nn
9 | 
10 | from wavegan import WaveGANGenerator
11 | from coala import TagEncoder, AudioEncoder
12 | from audio_prepro import preprocess_audio
13 | 
14 | logging.basicConfig(level = logging.INFO)
15 | 
16 | class Word2Wave(nn.Module):
17 |     def __init__(self, args):
18 |         super(Word2Wave, self).__init__()
19 |         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 |         self.coala_model_name = args.coala_model_name
21 |         self.wavegan_path = args.wavegan_path
22 | 
23 |         self.load_wavegan()
24 |         self.load_coala()
25 |         self.init_latents()
26 | 
27 |     def load_wavegan(self, slice_len=16384, model_size=32):
28 |         self.generator = WaveGANGenerator(slice_len=slice_len, model_size=model_size, use_batch_norm=False, num_channels=1)
29 |         checkpoint = torch.load(self.wavegan_path, map_location=self.device)
30 |         self.generator.load_state_dict(checkpoint['generator'])
31 | 
32 |     def load_coala(self):
33 |         coala_path = os.path.join("coala/models", self.coala_model_name)
34 |         tag_encoder_url = "https://raw.githubusercontent.com/xavierfav/coala/master/saved_models/{}/tag_encoder_epoch_200.pt".format(self.coala_model_name)
35 |         audio_encoder_url = "https://raw.githubusercontent.com/xavierfav/coala/master/saved_models/{}/audio_encoder_epoch_200.pt".format(self.coala_model_name)
36 |         tag_encoder_path = os.path.join(coala_path, os.path.basename(tag_encoder_url))
37 |         audio_encoder_path = os.path.join(coala_path, os.path.basename(audio_encoder_url))
38 |         # download the pretrained COALA weights if they are not available locally
39 |         if not os.path.exists(coala_path):
40 |             os.makedirs(coala_path)
41 |             logging.info("Downloading COALA model weights from {}".format(audio_encoder_url))
42 |             urlretrieve(tag_encoder_url, tag_encoder_path)
43 |             urlretrieve(audio_encoder_url, audio_encoder_path)
44 | 
45 |         self.tag_encoder = TagEncoder()
46 |         self.tag_encoder.load_state_dict(torch.load(tag_encoder_path))
47 |         self.tag_encoder.eval()
48 | 
49 |         self.audio_encoder = AudioEncoder()
50 |         self.audio_encoder.load_state_dict(torch.load(audio_encoder_path))
51 |         self.audio_encoder.eval()
52 | 
53 |         id2tag = json.load(open('coala/id2token_top_1000.json', 'r'))
54 |         self.tag2id = {tag: id for id, tag in id2tag.items()}
55 | 
56 |     def init_latents(self, size=1, latent_dim=100):
57 |         noise = torch.FloatTensor(size, latent_dim)
58 |         noise.data.normal_()
59 |         self.latents = torch.nn.Parameter(noise)
60 | 
61 |     def tokenize_text(self, text_prompt):
62 |         words_not_in_dict = [word for word in text_prompt.split(" ") if word not in self.tag2id.keys()]
63 |         words_in_dict = [word for word in text_prompt.split(" ") if word in self.tag2id.keys()]
64 |         tokenized_text = [int(self.tag2id[word]) for word in words_in_dict]
65 |         return tokenized_text, words_in_dict, words_not_in_dict
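66 | 
67 |     # e.g. tokenize_text("water splash") -> ([249, 248], ["water", "splash"], []),
68 |     # since id2token_top_1000.json maps "249" to "water" and "248" to "splash"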
69 | 
70 |     def encode_text(self, text_prompt):
71 |         word_ids, _, _ = self.tokenize_text(text_prompt)
72 |         sentence_embedding = torch.zeros(1152).to(self.device)
73 | 
74 |         tag_vector = torch.zeros(len(word_ids), 1000).to(self.device)
75 |         for index, word in enumerate(word_ids):
76 |             tag_vector[index, word] = 1
77 | 
78 |         embedding, embedding_d = self.tag_encoder(tag_vector)
79 |         sentence_embedding = embedding_d.mean(dim=0)
80 |         return sentence_embedding
81 | 
82 |     def encode_audio(self, audio):
83 |         x = preprocess_audio(audio).to(self.device)
84 |         scaler = pickle.load(open('coala/scaler_top_1000.pkl', 'rb'))
85 |         x *= torch.tensor(scaler.scale_).to(self.device)
86 |         x += torch.tensor(scaler.min_).to(self.device)
87 |         x = torch.clamp(x, scaler.feature_range[0], scaler.feature_range[1])
88 |         embedding, embedding_d = self.audio_encoder(x.unsqueeze(0).unsqueeze(0))
89 |         return embedding_d
90 | 
91 |     def latent_space_interpolation(self, latents=None, n_samples=1):
92 |         if latents is None:
93 |             z_test = torch.randn(2, 100)
94 |         else:
95 |             z_test = latents
96 |         interpolates = []
97 |         for alpha in torch.linspace(0, 1, n_samples):
98 |             interpolate_vec = alpha * z_test[0] + ((1 - alpha) * z_test[1])
99 |             interpolates.append(interpolate_vec)
100 |         interpolates = torch.stack(interpolates)
101 |         generated_audio = self.generator(interpolates)
102 |         return generated_audio
103 | 
104 |     def synthesise_audio(self, noise):
105 |         generated_audio = self.generator(noise).view(-1)
106 |         return generated_audio
107 | 
108 |     def coala_loss(self, audio, text):
109 |         text_embedding = self.encode_text(text)
110 |         audio_embedding = self.encode_audio(audio)
111 | 
112 |         text_embedding = text_embedding / text_embedding.norm()
113 |         audio_embedding = audio_embedding / audio_embedding.norm()
114 | 
115 |         cos_dist = (1 - audio_embedding @ text_embedding.t()) / 2
116 | 
117 |         return cos_dist
118 | 
119 |     def forward(self, text):
120 |         audio = self.generator(self.latents).view(-1)
121 |         loss = self.coala_loss(audio, text)
122 |         return audio, loss
123 | 
--------------------------------------------------------------------------------