├── experiment ├── __init__.py ├── allocate.py ├── experiment1.py └── experiment2.py ├── tests ├── test_models.py ├── test_datasets.py ├── test_utils.py ├── test_audio_env.py ├── test_room_types.py ├── test_agent.py └── test_transforms.py ├── src ├── __init__.py ├── audio_room │ ├── envs │ │ ├── __init__.py │ │ └── audio_env.py │ └── __init__.py ├── constants.py ├── audio_processing.py ├── room_types.py ├── plot_runs.py ├── utils.py ├── models.py ├── agent.py ├── datasets.py └── transforms.py ├── otoworld.png ├── notebooks ├── df_sdr └── df_sir ├── sounds ├── car │ └── car_horn.wav ├── siren │ └── siren.wav ├── samples │ ├── male │ │ ├── 051a050a.wav │ │ ├── 051a050b.wav │ │ └── 051a050c.wav │ └── female │ │ ├── 050a050a.wav │ │ ├── 050a050b.wav │ │ └── 050a050c.wav └── phone │ └── cellphone_ringing.wav ├── requirements.txt ├── LICENSE ├── .gitignore └── README.md /experiment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" -------------------------------------------------------------------------------- /src/audio_room/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .audio_env import AudioEnv 2 | -------------------------------------------------------------------------------- /otoworld.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/otoworld.png -------------------------------------------------------------------------------- /notebooks/df_sdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/notebooks/df_sdr -------------------------------------------------------------------------------- /notebooks/df_sir: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/notebooks/df_sir -------------------------------------------------------------------------------- /sounds/car/car_horn.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/car/car_horn.wav -------------------------------------------------------------------------------- /sounds/siren/siren.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/siren/siren.wav -------------------------------------------------------------------------------- /sounds/samples/male/051a050a.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/male/051a050a.wav -------------------------------------------------------------------------------- /sounds/samples/male/051a050b.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/male/051a050b.wav -------------------------------------------------------------------------------- 
/sounds/samples/male/051a050c.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/male/051a050c.wav -------------------------------------------------------------------------------- /sounds/phone/cellphone_ringing.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/phone/cellphone_ringing.wav -------------------------------------------------------------------------------- /sounds/samples/female/050a050a.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/female/050a050a.wav -------------------------------------------------------------------------------- /sounds/samples/female/050a050b.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/female/050a050b.wav -------------------------------------------------------------------------------- /sounds/samples/female/050a050c.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/female/050a050c.wav -------------------------------------------------------------------------------- /src/audio_room/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id="audio-room-v0", entry_point="audio_room.envs:AudioEnv", 5 | ) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyroomacoustics==0.4.1 2 | gym 3 | ffprobe 4 | jupyter 5 | git+git://github.com/nussl/nussl 6 | soxbindings 7 | numba==0.48.0 8 | seaborn 9 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | # NOTE: to run, need to cd into tests/, then run pytest 2 | import sys 3 | sys.path.append("../src") 4 | sys.path.append('../experiment') 5 | 6 | import room_types 7 | import agent 8 | import constants 9 | import datasets 10 | import experiment1 11 | 12 | def test_buffer_data(): 13 | # TODO 14 | pass -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # NOTE: to run, need to cd into tests/, then run pytest 2 | import sys 3 | sys.path.append("../src") 4 | 5 | import numpy as np 6 | 7 | import utils 8 | import constants 9 | 10 | 11 | def test_choose_random_files(): 12 | # test a range of different numbers of files 13 | for _ in range(5): 14 | num_sources = np.random.randint(1, 5) 15 | # choose_random_files takes a {folder: num_files} mapping 16 | paths = utils.choose_random_files({constants.DIR_FEMALE: num_sources}) 17 | 18 | assert(len(paths) == num_sources) 19 | 20 | # ensure we collect from correct folder 21 | random_file = np.random.choice(paths) 22 | assert( 23 | random_file.startswith(constants.DIR_FEMALE) \ 24 | or random_file.startswith(constants.DIR_MALE) 25 | ) 26 | 27 | # .wav or .mp3, etc.
(need to be audio files) 28 | assert(random_file.endswith(constants.AUDIO_EXTENSION)) 29 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # source data 4 | DIR_MALE = "../sounds/dry_recordings/dev/051/" 5 | DIR_FEMALE = "../sounds/dry_recordings/dev/050/" 6 | DIR_CAR = '../sounds/car/' 7 | DIR_PHONE = '../sounds/phone/' 8 | AUDIO_EXTENSION = ".wav" 9 | 10 | # saved data during experiment 11 | DATA_PATH = "../data" 12 | DIR_PREV_STATES = os.path.join(DATA_PATH, 'prev_states/') 13 | DIR_NEW_STATES = os.path.join(DATA_PATH, 'new_states/') 14 | DIR_DATASET_ITEMS = os.path.join(DATA_PATH, 'dataset_items/') 15 | MODEL_SAVE_PATH = '../models/' 16 | DIST_URL = "init_dist_to_target.p" 17 | STEPS_URL = "steps_to_completion.p" 18 | REWARD_URL = "rewards_per_episode.p" 19 | PRETRAIN_PATH = '../models/pretrained.pth' 20 | 21 | # audio stuff 22 | RESAMPLE_RATE = 8000 23 | 24 | # env stuff 25 | DIST_BTWN_EARS = 0.15 26 | 27 | # max and min values of exploration rate 28 | MAX_EPSILON = 0.9 29 | MIN_EPSILON = 0.1 30 | 31 | # reward structure (keep as floats) 32 | STEP_PENALTY = -0.5 33 | TURN_OFF_REWARD = 100.0 34 | ORIENT_PENALTY = -0.1 35 | 36 | # dataset 37 | MAX_BUFFER_ITEMS = 10000 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 pseeth 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /tests/test_audio_env.py: -------------------------------------------------------------------------------- 1 | # NOTE: to run, need to cd into tests/, then run pytest 2 | import sys 3 | sys.path.append("../src") 4 | 5 | import gym 6 | import numpy as np 7 | 8 | import audio_room 9 | import utils 10 | import room_types 11 | 12 | def test_audio_env(): 13 | # Shoebox Room 14 | room = room_types.ShoeBox(x_length=10, y_length=10) 15 | 16 | agent_loc = np.array([3, 8]) 17 | 18 | # Set up the audio/gym environment 19 | env = gym.make( 20 | "audio-room-v0", 21 | room_config=room.generate(), 22 | agent_loc=agent_loc, 23 | corners=room.corners, 24 | max_order=10, 25 | step_size=1.0, 26 | acceptable_radius=0.5, 27 | ) 28 | 29 | # store initial room obj 30 | init_room = env.room 31 | 32 | # test step (taking actions) 33 | # remember: 0,0 is at the bottom left 34 | env.step(action=0) # step left 35 | assert(np.allclose(env.agent_loc, np.array([2, 8]))) 36 | env.step(action=1) # step right 37 | assert(np.allclose(env.agent_loc, np.array([3, 8]))) 38 | env.step(action=2) # step up 39 | assert(np.allclose(env.agent_loc, np.array([3, 9]))) 40 | env.step(action=3) # step down 41 | assert(np.allclose(env.agent_loc, np.array([3, 8]))) 42 | 43 | # test move function 44 | env._move_agent([5, 5]) 45 | assert(np.allclose(env.agent_loc, [5, 5])) 46 | 47 | # ensure the room is the same dimensions 48 | # even though it's a different room object 49 | new_room = env.room 50 | for idx, wall in enumerate(init_room.walls): 51 | assert(np.allclose(wall.corners, new_room.walls[idx].corners)) 52 | 53 | # TODO: test reset 54 | 55 | 56 | -------------------------------------------------------------------------------- /src/audio_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def ipd_ild_features(stft_data, sampling_rate=8000): 6 | """ 7 | Computes interphase difference (IPD) and interlevel difference (ILD) for a 8 | stereo spectrogram. 9 | Args: 10 | stft_data (Torch tensor): Tensor of shape batch_size, time_frames, mag+phase, channels, sources 11 | sampling_rate (int): The rate at which data is sampled.
Default = 8000 Hz 12 | Returns: 13 | ipd (``Torch tensor``): Interphase difference between two channels 14 | ild (``Torch tensor``): Interlevel difference between two channels 15 | vol (``Torch tensor``): Summed magnitudes (volume) of the two channels 16 | """ 17 | 18 | # The mag and phase are concatenated together, so each of them will have half the dimensionality 19 | mag_phase_dim = stft_data.shape[2]//2 20 | 21 | # Separate the data by channels 22 | stft_ch_one = stft_data[:, :, :, 0] 23 | stft_ch_two = stft_data[:, :, :, 1] 24 | 25 | # Calculate ILD over the magnitudes 26 | 27 | # Extract the magnitudes from the stft data 28 | stft_ch_one_mag = stft_ch_one[:, :, 0:mag_phase_dim] 29 | stft_ch_two_mag = stft_ch_two[:, :, 0:mag_phase_dim] 30 | vol = torch.abs(stft_ch_one_mag) + torch.abs(stft_ch_two_mag) 31 | ild = torch.abs(stft_ch_one_mag) / (torch.abs(stft_ch_two_mag) + 1e-4) 32 | ild = 20 * torch.log10(ild + 1e-8) 33 | 34 | # Extract the phase from the stft data 35 | phase_ch_one = stft_ch_one[..., -mag_phase_dim:] 36 | phase_ch_two = stft_ch_two[..., -mag_phase_dim:] 37 | 38 | ipd = torch.fmod(phase_ch_one - phase_ch_two, np.pi) 39 | 40 | # Output shape of ILD and IPD = [batch_size, time_frames, mag_phase_dim, sources] 41 | return ipd, ild, vol 42 | 43 | -------------------------------------------------------------------------------- /tests/test_room_types.py: -------------------------------------------------------------------------------- 1 | # NOTE: to run, need to cd into tests/, then run pytest 2 | import sys 3 | sys.path.append("../src") 4 | 5 | from pyroomacoustics import ShoeBox, Room 6 | 7 | import room_types 8 | import constants 9 | 10 | 11 | def test_polygon_num_sides(): 12 | """ 13 | Test Polygon room class with different number of sides 14 | """ 15 | for num_sides in range(3, 11): 16 | room = room_types.Polygon(n=num_sides, r=2) 17 | points = room.generate() 18 | 19 | # only x and y-coordinates for 2d polygon 20 | assert(len(points) == 2) 21 | 22 | # create pra room 23 | pra_room = Room.from_corners(points, fs=constants.RESAMPLE_RATE) 24 | 25 | # assert center is inside of room 26 | assert(pra_room.is_inside([room.x_center, room.y_center])) 27 | 28 | # apparently polygons don't need to be convex 29 | # pra_room.convex_hull() 30 | # assert(len(pra_room.obstructing_walls) == 0) 31 | 32 | 33 | def test_shoebox(): 34 | """ 35 | Testing (our) ShoeBox room class 36 | """ 37 | # 2d shoebox room is a rectangle (4 walls) 38 | room = room_types.ShoeBox() 39 | points = room.generate() 40 | 41 | # sanity, rectangles only have two different lengths 42 | assert(len(points) == 2) 43 | 44 | # ensure class variables equal points returned by generate 45 | assert(room.x_length == points[0]) 46 | assert(room.y_length == points[1]) 47 | 48 | # sanity check on the area 49 | assert((room.x_length * room.y_length) == (points[0] * points[1])) 50 | 51 | # create pra room 52 | pra_room = ShoeBox(points) 53 | 54 | # test whether it is a convex hull (this should be ensured by pra) 55 | pra_room.convex_hull() 56 | assert(len(pra_room.obstructing_walls) == 0) 57 | -------------------------------------------------------------------------------- /tests/test_agent.py: -------------------------------------------------------------------------------- 1 | # NOTE: to run, need to cd into tests/, then run pytest 2 | import sys 3 | sys.path.append("../src") 4 | 5 | import gym 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import torch 9 | 10 | import room_types 11 | import agent 12 | import audio_room 13 | import utils 14 | import constants 15 | import nussl 16 |
from datasets import BufferData 17 | from agent import RandomAgent 18 | 19 | 20 | def test_experiment_shoebox(): 21 | """ 22 | Testing a run with ShoeBox room 23 | 24 | TODO 25 | """ 26 | # Shoebox Room 27 | room = room_types.ShoeBox(x_length=10, y_length=10) 28 | 29 | agent_loc = np.array([3, 8]) 30 | 31 | # Set up the gym environment 32 | env = gym.make( 33 | "audio-room-v0", 34 | room_config=room.generate(), 35 | agent_loc=agent_loc, 36 | corners=room.corners, 37 | max_order=10, 38 | step_size=1.0, 39 | acceptable_radius=0.8, 40 | ) 41 | 42 | # create buffer data folders 43 | utils.create_buffer_data_folders() 44 | 45 | tfm = nussl.datasets.transforms.Compose([ 46 | nussl.datasets.transforms.GetAudio(mix_key='new_state'), 47 | nussl.datasets.transforms.ToSeparationModel(), 48 | nussl.datasets.transforms.GetExcerpt(excerpt_length=32000, tf_keys=['mix_audio'], time_dim=1), 49 | ]) 50 | 51 | # create dataset object (subclass of nussl.datasets.BaseDataset) 52 | dataset = BufferData(folder=constants.DIR_DATASET_ITEMS, to_disk=True, transform=tfm) 53 | 54 | # Load the agent class 55 | a = agent.RandomAgent(env=env, dataset=dataset, episodes=2, max_steps=10, plot_reward_vs_steps=False) 56 | a.fit() 57 | 58 | # what should we assert? 59 | #assert() 60 | 61 | 62 | def test_experiment_polygon(): 63 | """ 64 | Testing a run with Polygon room 65 | 66 | TODO 67 | """ 68 | # Uncomment for Polygon Room 69 | room = room_types.Polygon(n=6, r=2, x_center=5, y_center=5) 70 | -------------------------------------------------------------------------------- /src/room_types.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | """ 4 | Currently, we are only working with rooms that are convex and 2D. 5 | """ 6 | 7 | class Polygon: 8 | def __init__(self, n, r, x_center=0, y_center=0, theta=0): 9 | """ 10 | This class represents a polygon room. 11 | pyroomacoustics.Room.from_corners() is called to create a polygon room. 12 | 13 | Args: 14 | n (int): Number of vertices of the polygon 15 | r (int): Radius of the polygon 16 | x_center (int): x coordinate of center of polygon 17 | y_center (int): y coordinate of center of polygon 18 | theta (float): rotation offset (in radians) applied to the vertex angles 19 | """ 20 | self.n = n 21 | self.r = r 22 | self.x_center = x_center 23 | self.y_center = y_center 24 | assert(theta >= 0.0) 25 | self.theta = theta 26 | self.corners = True 27 | 28 | def generate(self): 29 | """ 30 | This function generates a polygon and returns the points 31 | 32 | Returns: 33 | x_points (List[np.array]): x-coordinates of points that make up the room 34 | y_points (List[np.array]): y-coordinates of points that make up the room 35 | """ 36 | numbers = np.array([i for i in range(self.n)]) 37 | 38 | x_points = ( 39 | self.r * (2 * np.cos((2 * np.pi * numbers) / self.n + self.theta)) + self.x_center 40 | ) 41 | y_points = ( 42 | self.r * (2 * np.sin((2 * np.pi * numbers) / self.n + self.theta)) + self.y_center 43 | ) 44 | 45 | return [x_points, y_points] 46 | 47 | 48 | class ShoeBox: 49 | def __init__(self, x_length=10, y_length=10): 50 | """ 51 | This class represents a shoe box (rectangular) room. It is a wrapper 52 | for the pyroomacoustics.Room.ShoeBox class.
We are sticking with 53 | 2D rooms (4 walls) for now, though PRA supports 3D rooms (6 walls) 54 | 55 | Args: 56 | x_length (float): the horizontal length of the room 57 | y_length (float): the vertical length of the room 58 | """ 59 | self.x_length = x_length 60 | self.y_length = y_length 61 | self.corners = False 62 | 63 | def generate(self): 64 | """ 65 | This function generates a shoebox and returns the points 66 | 67 | Returns: 68 | List[int]: x_length and y_length 69 | """ 70 | return [self.x_length, self.y_length] 71 | -------------------------------------------------------------------------------- /experiment/allocate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Script to wait for an open GPU, then to run the job. Use with 5 | `tsp` (task-spooler) for a common workflow, by queuing a job 6 | that looks for a GPU and then runs the actual job. 7 | 8 | Make this executable: chmod +x allocate.py 9 | 10 | This stacks jobs on a single GPU if memory is available and the 11 | other process belongs to you. Otherwise, it finds a completely 12 | unused GPU to run your command on. 13 | 14 | Usage: 15 | 16 | # allocates 1 gpu for train script 17 | ./allocate.py 1 python train.py ... 18 | # allocates 2 gpus for train script 19 | ./allocate.py 2 python train.py ... 20 | 21 | Requirements: 22 | 23 | pip install nvgpu 24 | """ 25 | 26 | import subprocess 27 | import argparse 28 | import time 29 | import logging 30 | import sys 31 | import os, pwd 32 | import nvgpu 33 | from nvgpu.list_gpus import device_statuses 34 | 35 | logging.basicConfig(level=logging.INFO) 36 | mem_threshold = 50 37 | 38 | def run(cmd): 39 | print(cmd) 40 | subprocess.run([cmd], shell=True) 41 | 42 | def _allocate_gpu(num_gpus): 43 | current_user = pwd.getpwuid(os.getuid()).pw_name 44 | gpu_info = nvgpu.gpu_info() 45 | device_info = device_statuses() 46 | 47 | # assume nothing is available 48 | completely_available = [False for _ in gpu_info] 49 | same_user_available = [False for _ in gpu_info] 50 | 51 | for i, (_info, _device) in enumerate(zip(gpu_info, device_info)): 52 | completely_available[i] = _device['is_available'] 53 | if _info['mem_used_percent'] < mem_threshold and current_user in _device['users']: 54 | same_user_available[i] = True 55 | 56 | available_gpus = same_user_available 57 | if sum(same_user_available) == 0: 58 | available_gpus = completely_available 59 | 60 | available_gpus = [i for i, val in enumerate(available_gpus) if val] 61 | 62 | return available_gpus[:num_gpus] 63 | 64 | if __name__ == "__main__": 65 | args = sys.argv 66 | 67 | num_gpus = int(sys.argv[1]) 68 | cmd = sys.argv[2:] 69 | 70 | available_gpus = _allocate_gpu(num_gpus) 71 | 72 | while len(available_gpus) < num_gpus: 73 | logging.info("Waiting for available GPUs. 
Checking again in 30 seconds.") 74 | available_gpus = _allocate_gpu(num_gpus) 75 | time.sleep(30) 76 | 77 | available_gpus = ','.join(map(str, available_gpus)) 78 | CUDA_VISIBLE_DEVICES = f'CUDA_VISIBLE_DEVICES={available_gpus}' 79 | cmd = ' '.join(cmd) 80 | cmd = f"{CUDA_VISIBLE_DEVICES} {cmd}" 81 | run(cmd) 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vscode stuff 132 | .vscode/ 133 | 134 | # dstore 135 | *.DS_Store 136 | 137 | # ide stuff 138 | .idea/ 139 | 140 | # nussl tutorial 141 | nussl-tutorial.ipynb 142 | refactor-test.ipynb 143 | 144 | # data 145 | data/ 146 | experiment/runs/ 147 | models/ 148 | sounds/dry_recordings 149 | 150 | *tfevents* 151 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../src") 3 | 4 | from pyroomacoustics import ShoeBox, Room 5 | 6 | 7 | import gym 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import torch 11 | import room_types 12 | import agent 13 | import audio_room 14 | import utils 15 | import constants 16 | import nussl 17 | from datasets import BufferData 18 | import time 19 | import audio_processing 20 | from models import RnnAgent 21 | import transforms 22 | from transforms import GetExcerpt 23 | 24 | """ 25 | Takes too long to run and isn't a great test. Assertion will be added in GetExcerpt to ensure 26 | the mixes for the prev and new state are the same length 27 | """ 28 | 29 | def test_mix_lengths(): 30 | pass 31 | # # Shoebox Room 32 | # room = room_types.ShoeBox(x_length=10, y_length=10) 33 | 34 | # # Uncomment for Polygon Room 35 | # # room = room_types.Polygon(n=6, r=2, x_center=5, y_center=5) 36 | 37 | # agent_loc = np.array([3, 8]) 38 | 39 | # # Set up the gym environment 40 | # env = gym.make( 41 | # "audio-room-v0", 42 | # room_config=room.generate(), 43 | # agent_loc=agent_loc, 44 | # corners=room.corners, 45 | # max_order=10, 46 | # step_size=1.0, 47 | # acceptable_radius=0.8, 48 | # ) 49 | 50 | # # create buffer data folders 51 | # utils.create_buffer_data_folders() 52 | 53 | # tfm = transforms.Compose([ 54 | # transforms.GetAudio(mix_key=['prev_state', 'new_state']), 55 | # transforms.ToSeparationModel(), 56 | # transforms.GetExcerpt(excerpt_length=32000, 57 | # tf_keys=['mix_audio_prev_state', 'mix_audio_new_state'], time_dim=1), 58 | # ]) 59 | 60 | # dataset = BufferData(folder=constants.DIR_DATASET_ITEMS, to_disk=False, transform=tfm) 61 | 62 | # # run really short experiment to generate data 63 | # # Define the relevant dictionaries 64 | # env_config = {'env': env, 'dataset': dataset, 'episodes': 3, 'max_steps': 3, 'plot_reward_vs_steps': False, 65 | # 'stable_update_freq': 3, 'epsilon': 0.7, 'save_freq': 1} 66 | # dataset_config = {'batch_size': 3, 'num_updates': 2, 'save_path': '../models/'} 67 | # rnn_agent = RnnAgent(env_config=env_config, dataset_config=dataset_config) 68 | 69 | # rnn_agent.fit() 70 | 71 | # # test mix lengths, want to be equal (THIS MAY TAKE A WHILE) 72 | # # this test is not perfect, may get lucky and have all dataset items be same length, but it's already expensive to compute 73 | # for t in tfm.transforms: 74 | # lengths = [] 75 | # print(len(dataset)) 76 | # for i 
in range(len(dataset)): 77 | # if isinstance(t, GetExcerpt): 78 | # data = t(dataset[i]) 79 | # for k, v in data.items(): 80 | # if k in t.time_frequency_keys: 81 | # lengths.append(v.size()) 82 | # print(lengths) 83 | # assert(lengths[0] == lengths[1]) 84 | 85 | # test_mix_lengths() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# OtoWorld

2 | 3 | OtoWorld is an interactive environment in which agents must learn to listen in order to solve navigational tasks. The purpose of OtoWorld is to facilitate reinforcement learning research in computer audition, where agents must learn to listen to the world around them to navigate. 4 | 5 | **Note:** Currently the focus is on audio source separation. 6 | 7 | OtoWorld is built on three open source libraries: OpenAI [`gym`](https://gym.openai.com/) for environment and agent interaction, [`pyroomacoustics`](https://github.com/LCAV/pyroomacoustics) for ray-tracing and acoustics simulation, and [`nussl`](https://github.com/nussl/nussl) for training deep computer audition models. OtoWorld is the audio analogue of GridWorld, a simple navigation game. OtoWorld can be easily extended to more complex environments and games. 8 | 9 | To solve one episode of OtoWorld, an agent must move towards each sounding source in the auditory scene and "turn it off". The agent receives no other input than the current sound of the room. The sources are placed randomly within the room and can vary in number. The agent receives a reward for turning off a source. 10 | 11 | [Read the OtoWorld Paper here](https://arxiv.org/abs/2007.06123) 12 |
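Under the hood, an OtoWorld game is a registered `gym` environment wrapping a simulated `pyroomacoustics` room. As a quick orientation, the following is a minimal sketch condensed from `experiment/experiment1.py` (the experiment scripts add `src/` to `sys.path`; run them from inside `experiment/` so the relative sound paths resolve):

```python
import gym
import room_types
import audio_room  # importing this registers "audio-room-v0" with gym

# An 8x8 shoebox (rectangular) room; room_types.Polygon is also available
room = room_types.ShoeBox(x_length=8, y_length=8)

# Draw one random source file from each of these folders
source_folders_dict = {'../sounds/phone/': 1,
                       '../sounds/siren/': 1}

env = gym.make(
    "audio-room-v0",
    room_config=room.generate(),
    source_folders_dict=source_folders_dict,
    corners=room.corners,
    max_order=10,           # reflection order of the acoustic simulation
    step_size=0.5,          # how far the agent moves per action
    acceptable_radius=1.0,  # how close the agent must be to turn a source off
    absorption=1.0,
)
```

See `experiment/experiment1.py` and the tutorial notebook for the full training loop built on top of this environment.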
13 | 14 | 15 | ![OtoWorld Environment](otoworld.png) 16 | 17 | 18 | ## Installation 19 | Clone the repository: 20 | ``` 21 | git clone https://github.com/pseeth/otoworld.git 22 | ``` 23 | Create a conda environment: 24 | ``` 25 | conda create -n otoworld python==3.7 26 | ``` 27 | Activate the environment: 28 | ``` 29 | conda activate otoworld 30 | ``` 31 | Install requirements: 32 | ``` 33 | pip install -r requirements.txt 34 | ``` 35 | Install ffmpeg from the conda distribution (Note: the PyPI distribution of ffmpeg is outdated): 36 | ``` 37 | conda install ffmpeg 38 | ``` 39 | If using a **CUDA-enabled GPU (highly recommended)**, install PyTorch `1.4` from the official source: 40 | ``` 41 | pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html 42 | ``` 43 | otherwise: 44 | ``` 45 | pip install torch==1.4.0 torchvision==0.5.0 46 | ``` 47 | 48 | You may ignore warnings about certain dependencies for now. 49 | 50 | 51 | ## Additional Installation Notes - Linux 52 | * Linux users may need to install the `libsndfile` sound file library if it is not present on the system. It can be done using the following command: 53 | ``` 54 | sudo apt-get install libsndfile1 55 | ``` 56 | 57 | This should take care of a common `musdb` error. 58 | 59 | ## Demo and Tutorial 60 | You can get familiar with OtoWorld using our tutorial notebook: [Tutorial Notebook](https://github.com/pseeth/otoworld/blob/master/notebooks/tutorial.ipynb). 61 | 62 | Run: 63 | ``` 64 | jupyter notebook 65 | ``` 66 | and navigate to `notebooks/tutorial.ipynb`. 67 | 68 | ## Experiments 69 | You can view (and run) examples of experiments: 70 | ``` 71 | cd experiment/ 72 | 73 | python experiment1.py 74 | ``` 75 | 76 | Please create your own experiments and see if you can win OtoWorld! You will need a GPU running CUDA to be able to perform any meaningful experiments. 77 | 78 | ## Is It Running Properly? 79 | You should see a message indicating the experiment is running, such as this: 80 | ``` 81 | ------------------------------ 82 | - Starting to Fit Agent 83 | ------------------------------- 84 | ``` 85 | 86 | You may get a warning about `SoX`. Ignore this for now. You're good to go!
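The experiment scripts also log rewards and losses to TensorBoard via a `SummaryWriter` pointed at `experiment/runs/` (see `experiment1.py`). Assuming TensorBoard is installed alongside PyTorch, one way to monitor a run from inside `experiment/`:

```
tensorboard --logdir runs --port 6006
```

Port `6006` matches the default that `src/plot_runs.py` queries when exporting reward curves to CSV.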
87 | 88 | ## Citing 89 | ``` 90 | @inproceedings{otoworld, 91 | author = {Omkar Ranadive and Grant Gasser and David Terpay and Prem Seetharaman}, 92 | title = {OtoWorld: Towards Learning to Separate by Learning to Move}, 93 | booktitle = {Self Supervision in Audio and Speech Workshop, 37th International Conference on Machine Learning ({ICML} 2020), Vienna, Austria}, 94 | year = {2020} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /experiment/experiment1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../src/") 3 | from datetime import datetime 4 | import os 5 | 6 | import gym 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import torch 10 | import logging 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | import room_types 14 | import agent 15 | import audio_room 16 | import utils 17 | import constants 18 | import nussl 19 | from datasets import BufferData 20 | import time 21 | import audio_processing 22 | from models import RnnAgent 23 | import transforms 24 | 25 | import warnings 26 | warnings.filterwarnings("ignore") 27 | 28 | """ 29 | One of our main experiments for the OtoWorld introductory paper 30 | """ 31 | 32 | # Shoebox Room 33 | nussl.utils.seed(0) 34 | room = room_types.ShoeBox(x_length=8, y_length=8) 35 | 36 | # Uncomment for Polygon Room 37 | #room = room_types.Polygon(n=6, r=2, x_center=5, y_center=5) 38 | 39 | 40 | source_folders_dict = {'../sounds/phone/': 1, 41 | '../sounds/siren/': 1} 42 | 43 | # Set up the gym environment 44 | env = gym.make( 45 | "audio-room-v0", 46 | room_config=room.generate(), 47 | source_folders_dict=source_folders_dict, 48 | corners=room.corners, 49 | max_order=10, 50 | step_size=.5, 51 | acceptable_radius=1.0, 52 | absorption=1.0, 53 | ) 54 | env.seed(0) 55 | 56 | # create buffer data folders 57 | utils.create_buffer_data_folders() 58 | 59 | # fixing lengths 60 | tfm = transforms.Compose([ 61 | transforms.GetAudio(mix_key=['prev_state', 'new_state']), 62 | transforms.ToSeparationModel(), 63 | transforms.GetExcerpt(excerpt_length=32000, 64 | tf_keys=['mix_audio_prev_state'], time_dim=1), 65 | transforms.GetExcerpt(excerpt_length=32000, 66 | tf_keys=['mix_audio_new_state'], time_dim=1) 67 | ]) 68 | 69 | # create dataset object (subclass of nussl.datasets.BaseDataset) 70 | dataset = BufferData( 71 | folder=constants.DIR_DATASET_ITEMS, 72 | to_disk=True, 73 | transform=tfm 74 | ) 75 | 76 | # define tensorboard writer, name the experiment!
77 | exp_name = 'pretrain-150eps' 78 | exp_id = '{}_{}'.format(exp_name, datetime.now().strftime('%d_%m_%Y-%H_%M_%S')) 79 | writer = SummaryWriter('runs/{}'.format(exp_id)) 80 | 81 | 82 | # Define the relevant dictionaries 83 | env_config = { 84 | 'env': env, 85 | 'dataset': dataset, 86 | 'episodes': 150, 87 | 'max_steps': 1000, 88 | 'stable_update_freq': 150, 89 | 'save_freq': 1, 90 | 'play_audio': False, 91 | 'show_room': False, 92 | 'writer': writer, 93 | 'dense': True, 94 | 'decay_rate': 0.0002, # trial and error 95 | 'decay_per_ep': True 96 | } 97 | 98 | save_path = os.path.join(constants.MODEL_SAVE_PATH, exp_name) 99 | dataset_config = { 100 | 'batch_size': 10, 101 | 'num_updates': 2, 102 | 'save_path': save_path 103 | } 104 | 105 | # clear save_path folder for each experiment 106 | utils.clear_models_folder(save_path) 107 | 108 | rnn_config = { 109 | 'bidirectional': True, 110 | 'dropout': 0.3, 111 | 'filter_length': 256, 112 | 'hidden_size': 50, 113 | 'hop_length': 64, 114 | 'mask_activation': ['softmax'], 115 | 'mask_complex': False, 116 | 'mix_key': 'mix_audio', 117 | 'normalization_class': 'BatchNorm', 118 | 'num_audio_channels': 1, 119 | 'num_filters': 256, 120 | 'num_layers': 1, 121 | 'num_sources': 2, 122 | 'rnn_type': 'lstm', 123 | 'window_type': 'sqrt_hann', 124 | } 125 | 126 | stft_config = { 127 | 'hop_length': 64, 128 | 'num_filters': 256, 129 | 'direction': 'transform', 130 | 'window_type': 'sqrt_hann' 131 | } 132 | 133 | rnn_agent = RnnAgent( 134 | env_config=env_config, 135 | dataset_config=dataset_config, 136 | rnn_config=rnn_config, 137 | stft_config=stft_config, 138 | learning_rate=.001, 139 | pretrained=True 140 | ) 141 | torch.autograd.set_detect_anomaly(True) 142 | rnn_agent.fit() 143 | 144 | -------------------------------------------------------------------------------- /experiment/experiment2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../src/") 3 | from datetime import datetime 4 | import os 5 | 6 | import gym 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import torch 10 | import logging 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | import room_types 14 | import agent 15 | import audio_room 16 | import utils 17 | import constants 18 | import nussl 19 | from datasets import BufferData 20 | import time 21 | import audio_processing 22 | from models import RnnAgent 23 | import transforms 24 | 25 | import warnings 26 | warnings.filterwarnings("ignore") 27 | 28 | """ 29 | One of our main experiments for OtoWorld introductory paper 30 | """ 31 | 32 | # Shoebox Room 33 | nussl.utils.seed(7) 34 | room = room_types.ShoeBox(x_length=8, y_length=8) 35 | 36 | # Uncomment for Polygon Room 37 | #room = room_types.Polygon(n=6, r=2, x_center=5, y_center=5) 38 | 39 | source_folders_dict = {'../sounds/phone/': 1, 40 | '../sounds/siren/': 1} 41 | 42 | # Set up the gym environment 43 | env = gym.make( 44 | "audio-room-v0", 45 | room_config=room.generate(), 46 | source_folders_dict=source_folders_dict, 47 | corners=room.corners, 48 | max_order=10, 49 | step_size=1.0, 50 | acceptable_radius=1.0, 51 | absorption=1.0, 52 | reset_sources=False, 53 | same_config=True 54 | ) 55 | env.seed(7) 56 | 57 | # create buffer data folders 58 | utils.create_buffer_data_folders() 59 | 60 | # fixing lengths 61 | tfm = transforms.Compose([ 62 | transforms.GetAudio(mix_key=['prev_state', 'new_state']), 63 | transforms.ToSeparationModel(), 64 | transforms.GetExcerpt(excerpt_length=32000, 
65 | tf_keys=['mix_audio_prev_state'], time_dim=1), 66 | transforms.GetExcerpt(excerpt_length=32000, 67 | tf_keys=['mix_audio_new_state'], time_dim=1) 68 | ]) 69 | 70 | # create dataset object (subclass of nussl.datasets.BaseDataset) 71 | dataset = BufferData( 72 | folder=constants.DIR_DATASET_ITEMS, 73 | to_disk=True, 74 | transform=tfm 75 | ) 76 | 77 | # define tensorboard writer, name the experiment! 78 | exp_name = 'evaluate' 79 | exp_id = '{}_{}'.format(exp_name, datetime.now().strftime('%d_%m_%Y-%H_%M_%S')) 80 | writer = SummaryWriter('runs/{}'.format(exp_id)) 81 | 82 | # Define the relevant dictionaries 83 | env_config = { 84 | 'env': env, 85 | 'dataset': dataset, 86 | 'episodes': 200, 87 | 'max_steps': 1000, 88 | 'stable_update_freq': 200, 89 | 'save_freq': 2, 90 | 'play_audio': False, 91 | 'show_room': False, 92 | 'writer': writer, 93 | 'dense': True, 94 | 'decay_rate': 0.01, # trial and error 95 | 'decay_per_ep': True, 96 | 'validation_freq': 5 97 | } 98 | 99 | if env_config['decay_per_ep']: 100 | end_epsilon = constants.MIN_EPSILON + (constants.MAX_EPSILON - constants.MIN_EPSILON) * np.exp(-env_config['decay_rate'] * env_config['episodes']) 101 | print('\nEpsilon value at last episode ({}): {}'.format(env_config['episodes'], end_epsilon)) 102 | 103 | save_path = os.path.join(constants.MODEL_SAVE_PATH, exp_name) 104 | dataset_config = { 105 | 'batch_size': 50, 106 | 'num_updates': 2, 107 | 'save_path': save_path 108 | } 109 | 110 | # clear save_path folder for each experiment 111 | utils.clear_models_folder(save_path) 112 | 113 | rnn_config = { 114 | 'bidirectional': True, 115 | 'dropout': 0.3, 116 | 'filter_length': 256, 117 | 'hidden_size': 50, 118 | 'hop_length': 64, 119 | 'mask_activation': ['softmax'], 120 | 'mask_complex': False, 121 | 'mix_key': 'mix_audio', 122 | 'normalization_class': 'BatchNorm', 123 | 'num_audio_channels': 1, 124 | 'num_filters': 256, 125 | 'num_layers': 1, 126 | 'num_sources': 2, 127 | 'rnn_type': 'lstm', 128 | 'window_type': 'sqrt_hann', 129 | } 130 | 131 | stft_config = { 132 | 'hop_length': 64, 133 | 'num_filters': 256, 134 | 'direction': 'transform', 135 | 'window_type': 'sqrt_hann' 136 | } 137 | 138 | rnn_agent = RnnAgent( 139 | env_config=env_config, 140 | dataset_config=dataset_config, 141 | rnn_config=rnn_config, 142 | stft_config=stft_config, 143 | learning_rate=.01, 144 | ) 145 | torch.autograd.set_detect_anomaly(True) 146 | rnn_agent.fit() 147 | -------------------------------------------------------------------------------- /src/plot_runs.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import requests 3 | import pandas as pd 4 | import seaborn as sns 5 | import re 6 | 7 | 8 | port_number = 6006 9 | path = '../models/evaluation_data/' 10 | sns.set(style="whitegrid") 11 | dpi = 300 12 | 13 | 14 | def create_csv(file_name): 15 | # Form two urls - for mean reward and cumulative reward 16 | 17 | csv_url_cumul = "http://localhost:{}/data/plugin/scalars/scalars?tag=Reward%2Fcumulative&run={}&format=csv".format( 18 | port_number, file_name 19 | ) 20 | csv_url_mean = "http://localhost:{}/data/plugin/scalars/scalars?tag=Reward%2Fmean_per_episode&run={}&format=csv".format( 21 | port_number, file_name 22 | ) 23 | 24 | # Create CSV for mean rewards 25 | req = requests.get(csv_url_mean) 26 | url_content = req.content 27 | csv_file = open('{}.csv'.format(path+file_name+'_mean'), 'wb') 28 | csv_file.write(url_content) 29 | csv_file.close() 30 | 31 | # Create csv for cumulative
rewards 32 | req = requests.get(csv_url_cumul) 33 | url_content = req.content 34 | csv_file = open('{}.csv'.format(path + file_name+'_cumul'), 'wb') 35 | csv_file.write(url_content) 36 | csv_file.close() 37 | 38 | 39 | def create_plots_single(file_name): 40 | 41 | mean_rewards = pd.read_csv(path+file_name+'_mean.csv') 42 | cumul_rewards = pd.read_csv(path+file_name+'_cumul.csv') 43 | 44 | sns_plot = sns.relplot(x="Step", y="Value", data=cumul_rewards, kind="line") 45 | sns_plot.set(xlabel='Step', ylabel='Cumulative Reward') 46 | sns_plot.savefig(path+file_name+'_cumul.png', dpi=dpi) 47 | 48 | sns_plot = sns.relplot(x="Step", y="Value", data=mean_rewards, kind="line") 49 | sns_plot.set(xlabel='Step', ylabel='Mean Reward') 50 | sns_plot.savefig(path+file_name+'_mean.png', dpi=dpi) 51 | 52 | 53 | def create_plots_multiple(file_names): 54 | 55 | combined_data_mean = pd.DataFrame(columns=['Wall time', 'Step', 'Value', 'Number']) 56 | combined_data_cumul = pd.DataFrame(columns=['Wall time', 'Step', 'Value', 'Number']) 57 | for i, file_name in enumerate(file_names): 58 | mean_rewards = pd.read_csv(path+file_name+'_mean.csv') 59 | cumul_rewards = pd.read_csv(path+file_name+'_cumul.csv') 60 | mean_rewards['Number'] = i 61 | cumul_rewards['Number'] = i 62 | combined_data_mean = pd.concat([combined_data_mean, mean_rewards]) 63 | combined_data_cumul = pd.concat([combined_data_cumul, cumul_rewards]) 64 | 65 | sns_plot = sns.relplot(x="Step", y="Value", data=combined_data_mean, kind="line", hue="Number") 66 | sns_plot.set(xlabel='Step', ylabel='Mean Reward') 67 | sns_plot.savefig(path + '_mean_combined.png', dpi=dpi) 68 | 69 | sns_plot = sns.relplot(x="Step", y="Value", data=combined_data_cumul, kind="line", hue="Number") 70 | sns_plot.set(xlabel='Step', ylabel='Cumulative Reward') 71 | sns_plot.savefig(path+'_cumul_combined.png', dpi=dpi) 72 | 73 | 74 | def generate_data_from_log(file_name): 75 | steps_second = [] 76 | finished_steps = [] 77 | with open(file_name, 'r') as f: 78 | lines = f.readlines() 79 | counter = 0 80 | while counter < len(lines): 81 | cur_line = lines[counter] 82 | if 'Episode: ' in cur_line: 83 | ep_num = int(re.findall(r'\d+', cur_line)[0]) 84 | # Skip validation episodes 85 | if ep_num % 5 == 0: 86 | pass 87 | else: 88 | # Grab the data 89 | counter += 2 90 | cur_line = lines[counter] 91 | finished_count = int(re.findall(r'\d+', cur_line)[0]) 92 | finished_steps.append(finished_count) 93 | counter += 2 94 | cur_line = lines[counter] 95 | sps = float(re.findall(r'\d+\.\d+', cur_line)[0]) 96 | steps_second.append(sps) 97 | counter += 1 98 | 99 | print(len(finished_steps), len(steps_second)) 100 | print(finished_steps) 101 | print(steps_second) 102 | 103 | return finished_steps, steps_second 104 | 105 | 106 | if __name__ == '__main__': 107 | file_names = ["test-exp-5-50eps_test_simp_env_validation-2_15_06_2020-02_33_50", 108 | "exp5-200eps-final-run_15_06_2020-06_30_00"] 109 | 110 | # for file_name in file_names: 111 | # create_csv(file_name=file_name) 112 | # create_plots_single(file_name) 113 | # 114 | # create_plots_multiple(file_names) 115 | 116 | generate_data_from_log(file_name=path+'run.txt') -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import shutil 6 | import random 7 | import torch 8 | 9 | import constants 10 | 11 | def autoclip(model,
percentile, grad_norms=None): 12 | if grad_norms is None: 13 | grad_norms = [] 14 | 15 | def _get_grad_norm(model): 16 | total_norm = 0 17 | for p in model.parameters(): 18 | if p.requires_grad and p.grad is not None: 19 | param_norm = p.grad.data.norm(2) 20 | total_norm += param_norm.item() ** 2 21 | total_norm = total_norm ** (1. / 2) 22 | return total_norm 23 | 24 | grad_norms.append(_get_grad_norm(model)) 25 | clip_value = np.percentile(grad_norms, percentile) 26 | 27 | torch.nn.utils.clip_grad_norm_( 28 | model.parameters(), clip_value) 29 | return grad_norms 30 | 31 | 32 | def choose_random_files(source_folders_dict): 33 | """ 34 | Function returns random source files from provided folders. 35 | 36 | Args: 37 | source_folders_dict (Dict[str, int]): specify how many source files to choose from each folder 38 | e.g. 39 | { 40 | 'car_horn_source_folder': 1, 41 | 'phone_ringing_source_folder': 1 42 | } 43 | 44 | This would choose 1 source file from each folder 45 | Returns: 46 | paths (List[str]): the paths to two wav files 47 | """ 48 | paths = [] 49 | 50 | for folder, num_sources in source_folders_dict.items(): 51 | files = os.listdir(folder) 52 | 53 | source_files = [] 54 | random_indices = np.random.permutation(len(files)) 55 | for i in random_indices: 56 | if files[i].endswith(constants.AUDIO_EXTENSION): 57 | source_files.append(os.path.join(folder, files[i])) 58 | 59 | if len(source_files) == num_sources: 60 | break 61 | 62 | paths.extend(source_files) 63 | 64 | return paths 65 | 66 | 67 | def log_dist_and_num_steps(init_dist_to_target, steps_to_completion): 68 | """ 69 | This function logs the initial distance between agent and target source and number of steps 70 | taken to reach target source. The lists are stored in pickle files. The pairs (dist, steps) are in parallel 71 | lists, indexed by the episode number. 72 | 73 | Args: 74 | init_dist_to_target (List[float]): initial distance between agent and target src (size is number of episodes) 75 | steps_to_completion (List[int]): number of steps it took for agent to get to source 76 | """ 77 | # create data folder 78 | if not os.path.exists(constants.DATA_PATH): 79 | os.makedirs(constants.DATA_PATH) 80 | 81 | # write objects 82 | pickle.dump( 83 | init_dist_to_target, open(os.path.join(constants.DATA_PATH, constants.DIST_URL), "wb"), 84 | ) 85 | pickle.dump( 86 | steps_to_completion, open(os.path.join(constants.DATA_PATH, constants.STEPS_URL), "wb"), 87 | ) 88 | 89 | 90 | def log_reward_vs_steps(rewards_per_episode): 91 | """ 92 | This function logs the rewards per episode in order to plot the rewards vs. step for each episode. 93 | The lists are stored in pickle files. The pairs (dist, steps) are in parallel lists, indexed by the 94 | episode number. 95 | 96 | Args: 97 | rewards_per_episode (List[float]): rewards gained per episode 98 | """ 99 | # create data folder 100 | if not os.path.exists(constants.DATA_PATH): 101 | os.makedirs(constants.DATA_PATH) 102 | 103 | # write objects 104 | pickle.dump( 105 | rewards_per_episode, open(os.path.join( 106 | constants.DATA_PATH, constants.REWARD_URL), "wb"), 107 | ) 108 | 109 | 110 | def plot_reward_vs_steps(): 111 | """ 112 | Plots the reward vs step for an episode. 113 | """ 114 | with open(os.path.join(constants.DATA_PATH, constants.REWARD_URL), "rb") as f: 115 | rewards = pickle.load(f) 116 | 117 | reward = rewards[0] 118 | plt.scatter(list(range(len(reward))), reward) 119 | plt.title("Reward vs. 
Number of Steps") 120 | plt.xlabel("Step") 121 | plt.ylabel("Reward") 122 | 123 | plt.show() 124 | 125 | def plot_dist_and_steps(): 126 | """Plots initial distance and number of steps to reach target""" 127 | with open(os.path.join(constants.DATA_PATH, constants.DIST_URL), "rb") as f: 128 | dist = pickle.load(f) 129 | avg_dist = np.mean(dist) 130 | 131 | with open(os.path.join(constants.DATA_PATH, constants.STEPS_URL), "rb") as f: 132 | steps = pickle.load(f) 133 | avg_steps = np.mean(steps) 134 | 135 | plt.scatter(dist, np.log(steps)) 136 | plt.title("Number of Steps and Initial Distance") 137 | plt.xlabel("Euclidean Distance") 138 | plt.ylabel("Log(# of Steps to Reach Target)") 139 | plt.text( 140 | 5, 141 | 600, 142 | "Avg Steps: " + str(int(avg_steps)), 143 | size=15, 144 | rotation=0.0, 145 | ha="right", 146 | va="top", 147 | bbox=dict(boxstyle="square", ec=(1.0, 0.5, 0.5), fc=(1.0, 0.8, 0.8),), 148 | ) 149 | plt.text( 150 | 5, 151 | 500, 152 | "Avg Init Dist: " + str(int(avg_dist)), 153 | size=15, 154 | rotation=0.0, 155 | ha="right", 156 | va="top", 157 | bbox=dict(boxstyle="square", ec=(1.0, 0.5, 0.5), fc=(1.0, 0.8, 0.8),), 158 | ) 159 | 160 | plt.show() 161 | 162 | 163 | def create_buffer_data_folders(): 164 | """Empty and re-create the buffer data folders""" 165 | if os.path.exists(constants.DIR_PREV_STATES): 166 | shutil.rmtree(constants.DIR_PREV_STATES) 167 | os.makedirs(constants.DIR_PREV_STATES) 168 | if os.path.exists(constants.DIR_NEW_STATES): 169 | shutil.rmtree(constants.DIR_NEW_STATES) 170 | os.makedirs(constants.DIR_NEW_STATES) 171 | if os.path.exists(constants.DIR_DATASET_ITEMS): 172 | shutil.rmtree(constants.DIR_DATASET_ITEMS) 173 | os.makedirs(constants.DIR_DATASET_ITEMS) 174 | 175 | 176 | def clear_models_folder(save_path): 177 | """Empty/clear and re-create save_path folder to save models. 178 | 179 | Args: 180 | save_path (str): path to folder where models will be stored 181 | """ 182 | if os.path.exists(save_path): 183 | shutil.rmtree(save_path) 184 | os.makedirs(save_path) 185 | 186 | 187 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import numpy as np 4 | import torch 5 | import logging 6 | import agent 7 | import utils 8 | import constants 9 | import nussl 10 | import audio_processing 11 | import agent 12 | from datasets import BufferData, RLDataset 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import math 17 | 18 | 19 | class RnnAgent(agent.AgentBase): 20 | def __init__(self, env_config, dataset_config, rnn_config=None, stft_config=None, 21 | verbose=False, autoclip_percentile=10, learning_rate=.001, pretrained=False): 22 | """ 23 | Args: 24 | env_config (dict): Dictionary containing the audio environment config 25 | dataset_config (dict): Dictionary consisting of dataset related parameters. List of parameters 26 | 'batch_size' : Denotes the batch size of the samples 27 | 'num_updates': Amount of iterations we run the training for in each pass. 28 | Ex - If num_updates = 5 and batch_size = 25, then we run the update process 5 times where in each run we 29 | sample 25 data points. 30 | 'sampler': The sampler to use. 
Ex - Weighted Sampler, Batch sampler etc 31 | 32 | rnn_config (dict): Dictionary containing the parameters for the model 33 | stft_config (dict): Dictionary containing the parameters for STFT 34 | """ 35 | 36 | # Select device 37 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 38 | print('DEVICE:', self.device) 39 | # Use default config if configs are not provided by user 40 | if rnn_config is None: 41 | self.rnn_config = nussl.ml.networks.builders.build_recurrent_end_to_end( 42 | bidirectional=True, dropout=0.3, filter_length=256, hidden_size=300, 43 | hop_length=64, mask_activation=['sigmoid'], mask_complex=False, mix_key='mix_audio', 44 | normalization_class='BatchNorm', num_audio_channels=1, num_filters=256, 45 | num_layers=2, num_sources=2, rnn_type='lstm', trainable=False, window_type='sqrt_hann' 46 | ) 47 | else: 48 | self.rnn_config = nussl.ml.networks.builders.build_recurrent_end_to_end(**rnn_config) 49 | 50 | if stft_config is None: 51 | self.stft_diff = nussl.ml.networks.modules.STFT( 52 | hop_length=128, filter_length=512, 53 | direction='transform', num_filters=512 54 | ) 55 | else: 56 | self.stft_diff = nussl.ml.networks.modules.STFT(**stft_config) 57 | 58 | # Initialize the Agent Base class 59 | super().__init__(**env_config) 60 | 61 | # Uncomment this to find backprop errors 62 | # torch.autograd.set_detect_anomaly(True) 63 | 64 | # Initialize the rnn model 65 | self.rnn_model = RnnSeparator(self.rnn_config).to(self.device) 66 | self.rnn_model_stable = RnnSeparator(self.rnn_config).to(self.device) 67 | 68 | # Load pretrained model 69 | if pretrained: 70 | model_dict = torch.load(constants.PRETRAIN_PATH) 71 | self.rnn_model.rnn_model.load_state_dict(model_dict) 72 | 73 | # Initialize dataset related parameters 74 | self.bs = dataset_config['batch_size'] 75 | self.num_updates = dataset_config['num_updates'] 76 | self.dynamic_dataset = RLDataset(buffer=self.dataset, sample_size=self.bs) 77 | self.dataloader = torch.utils.data.DataLoader(self.dynamic_dataset, batch_size=self.bs) 78 | 79 | # Initialize network layers for DQN network 80 | filter_length = (stft_config['num_filters'] // 2 + 1) * 2 if stft_config is not None else 514 81 | total_actions = self.env.action_space.n 82 | network_params = {'filter_length': filter_length, 'total_actions': total_actions, 'stft_diff': self.stft_diff} 83 | self.q_net = DQN(network_params).to(self.device) 84 | self.q_net_stable = DQN(network_params).to(self.device) # Fixed Q net 85 | 86 | # Tell optimizer which parameters to learn 87 | if pretrained: 88 | # Freezing the rnn model weights, only optimizing Q-net 89 | params = self.q_net.parameters() 90 | else: 91 | params = list(self.rnn_model.parameters()) + list(self.q_net.parameters()) 92 | self.optimizer = optim.Adam(params, lr=learning_rate) 93 | 94 | # Folder path where the model will be saved 95 | self.SAVE_PATH = dataset_config['save_path'] 96 | self.grad_norms = None 97 | self.percentile = autoclip_percentile 98 | 99 | def update(self): 100 | """ 101 | Runs the main training pipeline. Sends mix to RNN separator, then to the DQN. 102 | Calculates the q-values and the expected q-values, comparing them to get the loss and then 103 | computes the gradient w.r.t to the entire differentiable pipeline. 
104 | """ 105 | # Run the update only if samples >= batch_size 106 | if len(self.dataset.items) < self.bs: 107 | return 108 | 109 | for index, data in enumerate(self.dataloader): 110 | if index >= self.num_updates:  # run exactly num_updates batches 111 | break 112 | 113 | # Get the total number of time steps 114 | total_time_steps = data['mix_audio_prev_state'].shape[-1] 115 | 116 | # Reshape the mixture to pass through the separation model (Convert dual channels into one) 117 | # Also, rename the state to work on to mix_audio so that it can pass through remaining nussl architecture 118 | # Move to GPU 119 | data['mix_audio'] = data['mix_audio_prev_state'].float().view( 120 | -1, 1, total_time_steps).to(self.device) 121 | data['action'] = data['action'].to(self.device) 122 | data['reward'] = data['reward'].to(self.device) 123 | agent_info = data['agent_info'].to(self.device) 124 | 125 | # Get the separated sources by running through RNN separation model 126 | output = self.rnn_model(data) 127 | output['mix_audio'] = data['mix_audio'] 128 | 129 | # Pass them through the DQN model to get q-values 130 | q_values = self.q_net(output, agent_info, total_time_steps) 131 | q_values = q_values.gather(1, data['action']) 132 | 133 | with torch.no_grad(): 134 | # Now, get q-values for the next-state 135 | # Get the total number of time steps 136 | total_time_steps = data['mix_audio_new_state'].shape[-1] 137 | 138 | # Reshape the mixture to pass through the separation model (Convert dual channels into one) 139 | data['mix_audio'] = data['mix_audio_new_state'].float().view( 140 | -1, 1, total_time_steps).to(self.device) 141 | stable_output = self.rnn_model_stable(data) 142 | stable_output['mix_audio'] = data['mix_audio'] 143 | q_values_next = self.q_net_stable(stable_output, agent_info, total_time_steps).max(1)[0].unsqueeze(-1) 144 | 145 | expected_q_values = data['reward'] + self.gamma * q_values_next 146 | 147 | # Calculate loss 148 | loss = F.l1_loss(q_values, expected_q_values) 149 | self.losses.append(loss.item())  # store a float so the autograd graph is not retained 150 | self.writer.add_scalar('Loss/train', loss.item(), len(self.losses)) 151 | 152 | # Optimize the model with backprop 153 | self.optimizer.zero_grad() 154 | loss.backward() 155 | 156 | # Applying AutoClip 157 | self.grad_norms = utils.autoclip(self.rnn_model, self.percentile, self.grad_norms) 158 | 159 | # Stepping optimizer 160 | self.optimizer.step() 161 | 162 | def choose_action(self): 163 | """ 164 | Runs a forward pass through the RNN separator and then the Q-network. An action is chosen 165 | by taking the argmax of the output vector of the network, where the output is a 166 | probability distribution over the action space (via softmax).
167 | 168 | Returns: 169 | action (int): the argmax of the q-values vector 170 | """ 171 | with torch.no_grad(): 172 | # Get the latest state from the buffer 173 | data = self.dataset[self.dataset.last_ptr] 174 | 175 | # Perform the forward pass (RNN separator => DQN) 176 | total_time_steps = data['mix_audio_new_state'].shape[-1] 177 | data['mix_audio'] = data['mix_audio_new_state'].float().view( 178 | -1, 1, total_time_steps).to(self.device) 179 | output = self.rnn_model(data) 180 | output['mix_audio'] = data['mix_audio'] 181 | agent_info = data['agent_info'].to(self.device) 182 | 183 | # action = argmax(q-values) 184 | q_values = self.q_net(output, agent_info, total_time_steps) 185 | action = q_values.max(1)[1].unsqueeze(-1) 186 | action = action[0].item() 187 | 188 | action = int(action) 189 | 190 | return action 191 | 192 | def update_stable_networks(self): 193 | self.rnn_model_stable.load_state_dict(self.rnn_model.state_dict()) 194 | self.q_net_stable.load_state_dict(self.q_net.state_dict()) 195 | 196 | def save_model(self, name): 197 | """ 198 | Args: 199 | name (str): Name contains the episode information (To give saved models unique names) 200 | """ 201 | # Save the parameters for rnn model and q net separately 202 | metadata = { 203 | 'sample_rate': 8000 204 | } 205 | self.rnn_model.rnn_model.save(os.path.join(self.SAVE_PATH, 'sp_' + name), metadata) 206 | torch.save(self.rnn_model.state_dict(), os.path.join(self.SAVE_PATH, 'rnn_' + name)) 207 | torch.save(self.q_net.state_dict(), os.path.join(self.SAVE_PATH, 'qnet_' + name)) 208 | 209 | 210 | class RnnSeparator(nn.Module): 211 | def __init__(self, rnn_config, verbose=False): 212 | super(RnnSeparator, self).__init__() 213 | self.rnn_model = nussl.ml.SeparationModel(rnn_config, verbose=verbose) 214 | 215 | def forward(self, x): 216 | return self.rnn_model(x) 217 | 218 | 219 | class DQN(nn.Module): 220 | def __init__(self, network_params): 221 | """ 222 | The main DQN class, which takes the output of the RNN separator and input and 223 | returns q-values (prob dist of the action space) 224 | 225 | Args: 226 | network_params (dict): Dict of network parameters 227 | 228 | Returns: 229 | q_values (torch.Tensor): q-values 230 | """ 231 | super(DQN, self).__init__() 232 | 233 | self.stft_diff = network_params['stft_diff'] 234 | self.fc = nn.Linear(9, network_params['total_actions']) 235 | self.prelu = nn.PReLU() 236 | 237 | def forward(self, output, agent_info, total_time_steps): 238 | # Reshape the output again to get dual channels 239 | # Perform short time fourier transform of this output 240 | _, _, nt = output['mix_audio'].shape 241 | output['mix_audio'] = output['mix_audio'].reshape(-1, 2, nt) 242 | stft_data = self.stft_diff(output['mix_audio'], direction='transform') 243 | 244 | # Get the IPD and ILD features from the stft data 245 | _, nt, nf, _, ns = output['mask'].shape 246 | output['mask'] = output['mask'].view(-1, 2, nt, nf, ns) 247 | output['mask'] = output['mask'].max(dim=1)[0] 248 | ipd, ild, vol = audio_processing.ipd_ild_features(stft_data) 249 | 250 | ipd = ipd.unsqueeze(-1) 251 | ild = ild.unsqueeze(-1) 252 | vol = vol.unsqueeze(-1) 253 | 254 | ipd_means = (output['mask'] * ipd).mean(dim=[1, 2]).unsqueeze(1) 255 | ild_means = (output['mask'] * ild).mean(dim=[1, 2]).unsqueeze(1) 256 | vol_means = (output['mask'] * vol).mean(dim=[1, 2]).unsqueeze(1) 257 | X = torch.cat([ipd_means, ild_means, vol_means], dim=1) 258 | X = X.reshape(X.shape[0], -1) 259 | agent_info = agent_info.view(-1, 3) 260 | X = torch.cat((X, 
agent_info), dim=1) 261 | X = self.prelu(self.fc(X)) 262 | q_values = F.softmax(X, dim=1) 263 | 264 | return q_values 265 | 266 | def flatten_features(self, x): 267 | # Flatten all dimensions except the batch 268 | size = x.size()[1:] 269 | num_features = 1 270 | for dimension in size: 271 | num_features *= dimension 272 | 273 | return num_features 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | -------------------------------------------------------------------------------- /src/agent.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import warnings 4 | from collections import deque 5 | 6 | import numpy as np 7 | import gym 8 | from scipy.spatial.distance import euclidean 9 | import nussl 10 | 11 | import utils 12 | import constants 13 | from datasets import BufferData 14 | 15 | # setup logging 16 | logging.basicConfig(level=logging.INFO) 17 | logging_str = ( 18 | f"\n\n" 19 | f"------------------------------ \n" 20 | f"- Starting to Fit Agent\n" 21 | f"------------------------------ \n\n" 22 | ) 23 | logging.info(logging_str) 24 | 25 | 26 | class AgentBase: 27 | def __init__( 28 | self, 29 | env, 30 | dataset, 31 | episodes=1, 32 | max_steps=100, 33 | gamma=0.98, 34 | alpha=0.001, 35 | decay_rate=0.0005, 36 | stable_update_freq=-1, 37 | save_freq=1, 38 | play_audio=False, 39 | show_room=False, 40 | writer=None, 41 | dense=True, 42 | decay_per_ep=False, 43 | validation_freq=None 44 | ): 45 | """ 46 | This class is a base agent class which will be inherited when creating various agents. 47 | 48 | Args: 49 | env (gym object): The gym environment object which the agent is going to explore 50 | dataset (nussl dataset): Nussl dataset object for experience replay 51 | episodes (int): # of episodes to simulate 52 | max_steps (int): # of steps the agent can take before stopping an episode 53 | gamma (float): Discount factor 54 | alpha (float): Learning rate alpha 55 | decay_rate (float): decay rate for exploration rate (we want to decrease exploration as time proceeds) 56 | stable_update_freq (int): Update frequency value for stable networks (Target networks) 57 | play_audio (bool): choose to play audio at each iteration 58 | show_room (bool): choose to display the configurations and movements within a room 59 | writer (torch.utils.tensorboard.SummaryWriter): for logging to tensorboard 60 | dense (bool): makes the rewards more dense, less sparse 61 | gives reward for distance to closest source every step 62 | decay_per_ep (bool): If set to true, epsilon is decayed per episode, else decayed per step 63 | validation_freq (None or int): if not None, then do a validation run every validation_freq # of episodes 64 | validation run means epsilon=0 (no random actions) and no model update/training 65 | """ 66 | self.env = env 67 | self.dataset = dataset 68 | self.episodes = episodes 69 | self.max_steps = max_steps 70 | self.gamma = gamma 71 | self.alpha = alpha 72 | self.epsilon = constants.MAX_EPSILON 73 | self.decay_rate = decay_rate 74 | self.stable_update_freq = stable_update_freq 75 | self.save_freq = save_freq 76 | self.play_audio = play_audio 77 | self.show_room = show_room 78 | self.writer = writer 79 | self.dense = dense 80 | self.decay_per_ep = decay_per_ep 81 | self.validation_freq = validation_freq 82 | self.losses = [] 83 | self.cumulative_reward = 0 84 | self.total_experiment_steps = 0 85 | self.mean_episode_reward = [] 86 | self.action_memory = deque(maxlen=4) 87 | 88 | # for saving model at best
validation episode (as determined by mean reward in the episode) 89 | if self.validation_freq is not None: 90 | self.max_validation_reward = -np.inf 91 | 92 | def fit(self): 93 | for episode in range(1, self.episodes + 1): 94 | # Reset the self.environment and any other variables at beginning of each episode 95 | prev_state = None 96 | 97 | episode_rewards = [] 98 | found_sources = [] 99 | 100 | # validation episode? 101 | validation_episode = False 102 | if self.validation_freq is not None: 103 | validation_episode = True if (episode % self.validation_freq == 0) else False 104 | 105 | # Measure time to complete the episode 106 | start = time.time() 107 | for step in range(self.max_steps): 108 | self.total_experiment_steps += 1 109 | 110 | # Perform random actions with prob < epsilon 111 | model_action = False 112 | if (np.random.uniform(0, 1) < self.epsilon): 113 | action = self.env.action_space.sample() 114 | else: 115 | model_action = True 116 | 117 | # validation run: no random actions and no model update/training 118 | if validation_episode: 119 | model_action = True 120 | 121 | if model_action: 122 | # For the first two steps (We don't have prev_state, new_state pair), then perform a random action 123 | if step < 2: 124 | action = self.env.action_space.sample() 125 | else: 126 | # This is where agent will actually do something 127 | action = self.choose_action() 128 | 129 | # if same action 4x in a row (to avoid model infinite loop), choose an action randomly 130 | self.action_memory.append(action) 131 | if all(self.action_memory[0] == x for x in self.action_memory): 132 | action = self.env.action_space.sample() 133 | self.action_memory.append(action) 134 | 135 | # Perform the chosen action (NOTE: reward is a dictionary) 136 | new_state, agent_info, reward, won = self.env.step( 137 | action, play_audio=self.play_audio, show_room=self.show_room 138 | ) 139 | 140 | # dense vs sparse 141 | total_step_reward = 0 142 | if self.dense: 143 | total_step_reward += sum(reward.values()) 144 | else: 145 | total_step_reward += (reward['step_penalty'] + reward['turn_off_reward']) 146 | 147 | # record reward stats 148 | self.cumulative_reward += total_step_reward 149 | episode_rewards.append(total_step_reward) 150 | 151 | if reward['turn_off_reward'] == constants.TURN_OFF_REWARD: 152 | print('In FIT. Received reward: {} at step {}\n'.format(total_step_reward, step)) 153 | logging.info(f"In FIT. 
Received reward {total_step_reward} at step: {step}\n") 154 | 155 | # Perform Update 156 | if not validation_episode: 157 | self.update() 158 | 159 | # store SARS in buffer 160 | if prev_state is not None and new_state is not None and not won: 161 | self.dataset.write_buffer_data( 162 | prev_state, action, total_step_reward, new_state, agent_info, episode, step 163 | ) 164 | 165 | # Decay epsilon based on total steps (across all episodes, not within an episode) 166 | if not self.decay_per_ep: 167 | self.epsilon = constants.MIN_EPSILON + ( 168 | constants.MAX_EPSILON - constants.MIN_EPSILON 169 | ) * np.exp(-self.decay_rate * self.total_experiment_steps) 170 | if self.total_experiment_steps % 200 == 0: 171 | print("Epsilon decayed to {} at step {} ".format(self.epsilon, self.total_experiment_steps)) 172 | 173 | # Update stable networks based on number of steps 174 | if step % self.stable_update_freq == 0: 175 | self.update_stable_networks() 176 | 177 | # Terminate the episode if episode is won or at max steps 178 | if won or (step == self.max_steps - 1): 179 | # terminal state is silence 180 | silence_array = np.zeros_like(prev_state.audio_data) 181 | terminal_silent_state = prev_state.make_copy_with_audio_data(audio_data=silence_array) 182 | self.dataset.write_buffer_data( 183 | prev_state, action, total_step_reward, terminal_silent_state, agent_info, episode, step 184 | ) 185 | 186 | # record mean reward for this episode 187 | self.mean_episode_reward = np.mean(episode_rewards) 188 | 189 | if validation_episode: 190 | # new best validation reward 191 | if self.mean_episode_reward > self.max_validation_reward: 192 | self.max_validation_reward = self.mean_episode_reward 193 | 194 | # save best validation model 195 | self.save_model('best_valid_reward.pt') 196 | 197 | if self.writer is not None: 198 | self.writer.add_scalar('Reward/validation_mean_per_episode', self.mean_episode_reward, episode) 199 | elif self.writer is not None: 200 | self.writer.add_scalar('Reward/mean_per_episode', self.mean_episode_reward, episode) 201 | self.writer.add_scalar('Reward/cumulative', self.cumulative_reward, self.total_experiment_steps) 202 | 203 | end = time.time() 204 | total_time = end - start 205 | 206 | # log episode summary 207 | logging_str = ( 208 | f"\n\n" 209 | f"Episode Summary \n" 210 | f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n" 211 | f"- Episode: {episode}\n" 212 | f"- Won?: {won}\n" 213 | f"- Finished at step: {step+1}\n" 214 | f"- Time taken: {total_time:04f} \n" 215 | f"- Steps/Second: {float(step+1)/total_time:04f} \n" 216 | f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n" 217 | ) 218 | print(logging_str) 219 | logging.info(logging_str) 220 | 221 | # break and go to new episode 222 | break 223 | 224 | prev_state = new_state 225 | 226 | if episode % self.save_freq == 0: 227 | name = 'ep{}.pt'.format(episode) 228 | self.save_model(name) 229 | 230 | # Decay epsilon per episode 231 | if self.decay_per_ep: 232 | self.epsilon = constants.MIN_EPSILON + ( 233 | constants.MAX_EPSILON - constants.MIN_EPSILON 234 | ) * np.exp(-self.decay_rate * episode) 235 | print("Decayed epsilon value: {}".format(self.epsilon)) 236 | 237 | # Reset the environment 238 | self.env.reset() 239 | 240 | def choose_action(self): 241 | """ 242 | This function must be implemented by subclass. 243 | It will choose the action to take at any time step. 
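        A minimal subclass sketch (the class name is illustrative; RandomAgent
        below is the simplest concrete case):

        .. code-block:: python

            class AlwaysForwardAgent(AgentBase):
                def choose_action(self):
                    # action 0 is "Forward" in AudioEnv
                    return 0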
244 | 245 | Returns: 246 | action (int): the action to take 247 | """ 248 | raise NotImplementedError() 249 | 250 | def update(self): 251 | """ 252 | This function must be implemented by subclass. 253 | It will perform an update (e.g. updating Q table or Q network) 254 | """ 255 | raise NotImplementedError() 256 | 257 | def update_stable_networks(self): 258 | """ 259 | This function must be implemented by subclass. 260 | It will perform an update to the stable networks. I.E Copy values from current network to target network 261 | """ 262 | raise NotImplementedError() 263 | 264 | def save_model(self, name): 265 | raise NotImplementedError() 266 | 267 | 268 | class RandomAgent(AgentBase): 269 | def choose_action(self): 270 | """ 271 | Since this is a random agent, we just randomly sample our action 272 | every time. 273 | """ 274 | return self.env.action_space.sample() 275 | 276 | def update(self): 277 | """ 278 | No update for a random agent 279 | """ 280 | pass 281 | 282 | def update_stable_networks(self): 283 | pass 284 | 285 | def save_model(self, name): 286 | pass 287 | 288 | 289 | # Create a perfect agent that steps to each of the closest sources one at a time. 290 | class OracleAgent(AgentBase): 291 | """ 292 | This agent is a perfect agent. It knows where all of the sources are and 293 | will iteratively go to the closest source at each time step. 294 | """ 295 | 296 | def choose_action(self): 297 | """ 298 | Since we know all information about the environment, the agent moves the following way. 299 | 1. Find the closest audio source using euclidean distance 300 | 2. Rotate to face the audio source if necessary 301 | 3. Step in the direction of the audio source 302 | """ 303 | agent = self.env.agent_loc 304 | radius = self.env.acceptable_radius 305 | action = None 306 | 307 | # find the closest audio source 308 | source, minimum_distance = self.env.source_locs[0], np.linalg.norm( 309 | agent - self.env.source_locs[0]) 310 | for s in self.env.source_locs[1:]: 311 | dist = np.linalg.norm(agent - s) 312 | if dist < minimum_distance: 313 | minimum_distance = dist 314 | source = s 315 | 316 | # Determine our current angle 317 | angle = abs(int(np.degrees(self.env.cur_angle))) % 360 318 | # The agent is too far left or right of the source 319 | if np.abs(source[0] - agent[0]) >= radius: 320 | # First check if we need to turn 321 | if angle not in [0, 1, 179, 180, 181]: 322 | action = 2 323 | # We are facing correct way need to move forward or backward 324 | else: 325 | if source[0] < agent[0]: 326 | if angle == 0: 327 | action = 1 328 | else: 329 | action = 0 330 | else: 331 | if angle == 0: 332 | action = 0 333 | else: 334 | action = 1 335 | # Agent is to the right of the source 336 | elif np.abs(source[1] - agent[1]) >= radius: 337 | # First check if we need to turn 338 | if angle not in [89, 90, 91, 269, 270, 271]: 339 | action = 2 340 | # We are facing correct way need to move forward or backward 341 | else: 342 | if source[1] < agent[1]: 343 | if angle == 90: 344 | action = 1 345 | else: 346 | action = 0 347 | else: 348 | if angle == 90: 349 | action = 0 350 | else: 351 | action = 1 352 | else: 353 | action = np.random.randint(0, 4) 354 | return action 355 | 356 | def update(self): 357 | """ 358 | Since this is a perfect agent, we do not update any network nor save any models 359 | """ 360 | pass 361 | 362 | def update_stable_networks(self): 363 | """ 364 | Since this is a perfect agent, we do not update any network nor save any models 365 | """ 366 | pass 367 | 368 | def 
save_model(self, name): 369 | """ 370 | Since this is a perfect agent, we do not update any network nor save any models 371 | """ 372 | pass 373 | -------------------------------------------------------------------------------- /src/audio_room/envs/audio_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from pyroomacoustics import MicrophoneArray, ShoeBox, Room, linear_2D_array, Constants 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from gym import spaces 6 | from scipy.spatial.distance import euclidean 7 | from sklearn.metrics.pairwise import euclidean_distances 8 | from copy import deepcopy 9 | import nussl 10 | import sys 11 | sys.path.append("../../") 12 | 13 | import constants 14 | from utils import choose_random_files 15 | 16 | # Suppress 'c' argument errors caused by room.plot() 17 | from matplotlib.axes._axes import _log as matplotlib_axes_logger 18 | matplotlib_axes_logger.setLevel('ERROR') 19 | 20 | 21 | class AudioEnv(gym.Env): 22 | def __init__( 23 | self, 24 | room_config, 25 | source_folders_dict, 26 | resample_rate=8000, 27 | num_channels=2, 28 | bytes_per_sample=2, 29 | corners=False, 30 | absorption=0.0, 31 | max_order=2, 32 | step_size=1, 33 | acceptable_radius=.5, 34 | num_sources=2, 35 | degrees=np.deg2rad(30), 36 | reset_sources=True, 37 | same_config=False 38 | ): 39 | """ 40 | This class inherits from OpenAI Gym Env and is used to simulate the agent moving in PyRoom. 41 | 42 | Args: 43 | room_config (List or np.array): dimensions of the room. For Shoebox, in the form of [10,10]. Otherwise, 44 | in the form of [[1,1], [1, 4], [4, 4], [4, 1]] specifying the corners of the room 45 | resample_rate (int): sample rate in Hz 46 | num_channels (int): number of channels (used in playing what the mic hears) 47 | bytes_per_sample (int): used in playing what the mic hears 48 | corners (bool): False if using Shoebox config, otherwise True 49 | absorption (float): Absorption param of the room (how walls absorb sound) 50 | max_order (int): another room parameter 51 | step_size (float): specified step size else we programmatically assign it 52 | acceptable_radius (float): source is considered found/turned off if agent is within this distance of src 53 | num_sources (int): the number of audio sources the agent will listen to 54 | degrees (float): value of degrees to rotate in radians (.2618 radians = 15 degrees) 55 | reset_sources (bool): True if you want to choose different sources when resetting env 56 | source_folders_dict (Dict[str, int]): specify how many source files to choose from each folder 57 | e.g. 
58 | { 59 | 'car_horn_source_folder': 1, 60 | 'phone_ringing_source_folder': 1 61 | } 62 | same_config (bool): If set to true, difficulty of env becomes easier - agent initial loc, source placement, 63 | and audio files don't change over episodes 64 | """ 65 | self.resample_rate = resample_rate 66 | self.absorption = absorption 67 | self.max_order = max_order 68 | self.audio = [] 69 | self.num_channels = num_channels 70 | self.bytes_per_sample = bytes_per_sample 71 | self.num_actions = 4 72 | self.action_space = spaces.Discrete(self.num_actions) 73 | self.action_to_string = { 74 | 0: "Forward", 75 | 1: "Backward", 76 | 2: "Rotate Right", 77 | 3: "Rotate Left", 78 | } 79 | self.corners = corners 80 | self.room_config = room_config 81 | self.acceptable_radius = acceptable_radius 82 | self.step_size = step_size 83 | self.num_sources = num_sources 84 | self.source_locs = None 85 | self.min_size_audio = np.inf 86 | self.degrees = degrees 87 | self.cur_angle = 0 88 | self.reset_sources = reset_sources 89 | self.source_folders_dict = source_folders_dict 90 | self.same_config = same_config 91 | 92 | # choose audio files as sources 93 | self.direct_sources = choose_random_files(self.source_folders_dict) 94 | self.direct_sources_copy = deepcopy(self.direct_sources) 95 | 96 | # create the room and add sources 97 | self._create_room() 98 | self._add_sources() 99 | self._move_agent(new_agent_loc=None, initial_placing=True) 100 | 101 | # If same_config = True, keep track of the generated locations 102 | self.fixed_source_locs = self.source_locs.copy() 103 | self.fixed_agent_locs = self.agent_loc.copy() 104 | 105 | # The step size must be smaller than radius in order to make sure we don't 106 | # overstep a audio source 107 | if self.acceptable_radius < self.step_size / 2: 108 | raise ValueError( 109 | """The threshold radius (acceptable_radius) must be at least step_size / 2. Else, the agent may overstep 110 | an audio source.""" 111 | ) 112 | 113 | def _create_room(self): 114 | """ 115 | This function creates the Pyroomacoustics room with our environment class variables. 116 | """ 117 | # non-Shoebox config (corners of room are given) 118 | if self.corners: 119 | self.room = Room.from_corners( 120 | self.room_config, fs=self.resample_rate, 121 | absorption=self.absorption, max_order=self.max_order 122 | ) 123 | 124 | # The x_max and y_max in this case would be used to generate agent's location randomly 125 | self.x_min = min(self.room_config[0]) 126 | self.y_min = min(self.room_config[1]) 127 | self.x_max = max(self.room_config[0]) 128 | self.y_max = max(self.room_config[1]) 129 | 130 | # ShoeBox config 131 | else: 132 | self.room = ShoeBox( 133 | self.room_config, fs=self.resample_rate, 134 | absorption=self.absorption, max_order=self.max_order 135 | ) 136 | self.x_max = self.room_config[0] 137 | self.y_max = self.room_config[1] 138 | self.x_min, self.y_min = 0, 0 139 | 140 | def _move_agent(self, new_agent_loc, initial_placing=False): 141 | """ 142 | This function moves the agent to a new location (given by new_agent_loc). It effectively removes the 143 | agent (mic array) from the room and then adds it back in the new location. 144 | 145 | If initial_placing == True, the agent is placed in the room for the first time. 146 | 147 | Args: 148 | new_agent_loc (List[int] or np.array or None): [x,y] coordinates of the agent's new location. 
149 | initial_placing (bool): True if initially placing the agent in the room at the beginning of the episode 150 | """ 151 | # Placing agent in room for the first time (likely at the beginning of a new episode, after a reset) 152 | if initial_placing: 153 | if new_agent_loc is None: 154 | loc = self._sample_points(1, sources=False, agent=True) 155 | print("Placing agent at {}".format(loc)) 156 | self.agent_loc = loc 157 | self.cur_angle = 0  # Reset the orientation of agent back to zero at start of an ep 158 | else: 159 | self.agent_loc = new_agent_loc.copy() 160 | print("Placing agent at {}".format(self.agent_loc)) 161 | self.cur_angle = 0 162 | else: 163 | # Set the new agent location (where to move) 164 | self.agent_loc = new_agent_loc 165 | 166 | # Setup microphone in agent location, delete the array at previous time step 167 | self.room.mic_array = None 168 | 169 | if self.num_channels == 2: 170 | # Create the array at current time step (2 mics, angle IN RADIANS, 0.2m apart) 171 | mic = MicrophoneArray( 172 | linear_2D_array(self.agent_loc, 2, self.cur_angle, 173 | constants.DIST_BTWN_EARS), self.room.fs 174 | ) 175 | self.room.add_microphone_array(mic) 176 | else: 177 | mic = MicrophoneArray(self.agent_loc.reshape(-1, 1), self.room.fs) 178 | self.room.add_microphone_array(mic) 179 | 180 | def _sample_points(self, num_points, sources=True, agent=False): 181 | """ 182 | This function generates randomly sampled points for the sources (or agent) to be placed 183 | 184 | Args: 185 | num_points (int): Number of [x, y] random points to generate 186 | sources (bool): True if generating points for sources (agent must be False) 187 | agent (bool): True if generating points for agent (sources must be False) 188 | 189 | Returns: 190 | sampled_points (List[List[float]]): A list of [x, y] points for source locations 191 | or 192 | point (List[float]): An [x, y] point for the agent location 193 | """ 194 | assert(sources != agent) 195 | sampled_points = [] 196 | 197 | if sources: 198 | angles = np.arange(0, 2 * np.pi, self.degrees).tolist() 199 | while len(sampled_points) < num_points: 200 | chosen_angles = np.random.choice(angles, num_points) 201 | for angle in chosen_angles: 202 | direction = np.random.choice([-1, 1]) 203 | distance = np.random.uniform(2 * self.step_size, 5 * self.step_size) 204 | x = (self.x_min + self.x_max) / 2 205 | y = (self.y_min + self.y_max) / 2 206 | x = x + direction * np.cos(angle) * distance 207 | y = y + direction * np.sin(angle) * distance 208 | point = [x, y] 209 | if self.room.is_inside(point, include_borders=False): 210 | accepted = True 211 | if len(sampled_points) > 0: 212 | dist_to_existing = euclidean_distances( 213 | np.array(point).reshape(1, -1), sampled_points) 214 | accepted = dist_to_existing.min() > 2 * self.step_size 215 | if accepted and len(sampled_points) < num_points: 216 | sampled_points.append(point) 217 | return sampled_points 218 | elif agent: 219 | accepted = False 220 | while not accepted: 221 | accepted = True 222 | point = [ 223 | np.random.uniform(self.x_min, self.x_max), 224 | np.random.uniform(self.y_min, self.y_max), 225 | ] 226 | 227 | # ensure agent doesn't spawn too close to sources 228 | for source_loc in self.source_locs: 229 | if euclidean(point, source_loc) < 2.0 * self.acceptable_radius: 230 | accepted = False 231 | 232 | return point 233 | 234 | def _add_sources(self, new_source_locs=None, reset_env=False, removing_source=None): 235 | """ 236 | This function adds the sources to the environment.
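        Each source is loudness-normalized toward a reference level before being
        added to the room. The gain math used below, as a standalone example
        (the -40 dB target matches the code; the measured loudness is
        illustrative):

        .. code-block:: python

            ref_db = -40
            loudness = -31.0                         # e.g. what a.loudness() returned
            gain = 10 ** ((ref_db - loudness) / 20)  # ~0.355, i.e. ~9 dB of attenuation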
237 | 238 | Args: 239 | new_source_locs (List[List[int]]): A list consisting of [x, y] coordinates if the programmer wants 240 | to manually set the new source locations 241 | reset_env (bool): Bool indicating whether we reset_env the agents position to be the mean 242 | of all the sources 243 | removing_source (None or int): Value that will tell us if we are removing a source 244 | from sources 245 | """ 246 | # Can reset with NEW, randomly sampled sources (typically at the start of a new episode) 247 | if self.reset_sources: 248 | self.direct_sources = choose_random_files(self.source_folders_dict) 249 | else: 250 | self.direct_sources = deepcopy(self.direct_sources_copy) 251 | 252 | if new_source_locs is None: 253 | self.source_locs = self._sample_points(num_points=self.num_sources) 254 | else: 255 | self.source_locs = new_source_locs.copy() 256 | 257 | print("Source locs {}".format(self.source_locs)) 258 | 259 | self.audio = [] 260 | self.min_size_audio = np.inf 261 | for idx, audio_file in enumerate(self.direct_sources): 262 | # Audio will be automatically re-sampled to the given rate (default sr=8000). 263 | a = nussl.AudioSignal(audio_file, sample_rate=self.resample_rate) 264 | a.to_mono() 265 | 266 | # normalize audio so both sources have similar volume at beginning before mixing 267 | loudness = a.loudness() 268 | 269 | # # mix to reference db 270 | ref_db = -40 271 | db_diff = ref_db - loudness 272 | gain = 10 ** (db_diff / 20) 273 | a = a * gain 274 | 275 | # Find min sized source to ensure something is playing at all times 276 | if len(a) < self.min_size_audio: 277 | self.min_size_audio = len(a) 278 | self.audio.append(a.audio_data.squeeze()) 279 | 280 | # add sources using audio data 281 | for idx, audio in enumerate(self.audio): 282 | self.room.add_source( 283 | self.source_locs[idx], signal=audio[: self.min_size_audio]) 284 | 285 | def _remove_source(self, index): 286 | """ 287 | This function removes a source from the environment 288 | 289 | Args: 290 | index (int): index of the source to remove 291 | """ 292 | if index < len(self.source_locs): 293 | src = self.source_locs.pop(index) 294 | src2 = self.direct_sources.pop(index) 295 | 296 | # actually remove source from the room 297 | room_src = self.room.sources.pop(index) 298 | 299 | def step(self, action, play_audio=False, show_room=False): 300 | """ 301 | This function simulates the agent taking one step in the environment (and room) given an action: 302 | 0 = Move forward 303 | 1 = Move backward 304 | 2 = Turn right x degrees 305 | 3 = Turn left x degrees 306 | 307 | It calls _move_agent, checks to see if the agent has reached a source, and if not, computes the RIR. 
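        The reward is returned as a dictionary; a runnable sketch of how a caller
        can flatten it (keys match the code below, values are illustrative):

        .. code-block:: python

            reward = {
                'step_penalty': -0.5,
                'turn_off_reward': 0,
                'closest_reward': 0.2,
                'orient_penalty': 0,
            }
            dense_reward = sum(reward.values())  # how AgentBase.fit combines it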
308 | 309 | Args: 310 | action (int): action to take - 0 (move forward), 1 (move backward), 2 (rotate right), 3 (rotate left) 311 | play_audio (bool): whether to play the mic audio (stored in "data") 312 | show_room (bool): Controls whether room is visually plotted at each step 313 | 314 | Returns: 315 | Tuple of the format (data, [agent_loc, cur_angle], reward, done); data is None once the last source is found 316 | """ 317 | # return reward dictionary for each step 318 | reward = { 319 | 'step_penalty': constants.STEP_PENALTY, 320 | 'turn_off_reward': 0, 321 | 'closest_reward': 0, 322 | 'orient_penalty': 0 323 | } 324 | 325 | #print('Action:', self.action_to_string[action]) 326 | 327 | # movement 328 | x, y = self.agent_loc[0], self.agent_loc[1] 329 | done = False 330 | 331 | if action in [0, 1]: 332 | if action == 0: 333 | sign = 1 334 | if action == 1: 335 | sign = -1 336 | x = x + sign * np.cos(self.cur_angle) * self.step_size 337 | y = y + sign * np.sin(self.cur_angle) * self.step_size 338 | elif action == 2: 339 | self.cur_angle = round((self.cur_angle + self.degrees) % (2 * np.pi), 4) 340 | reward['orient_penalty'] = constants.ORIENT_PENALTY 341 | elif action == 3: 342 | self.cur_angle = round((self.cur_angle - self.degrees) % (2 * np.pi), 4) 343 | reward['orient_penalty'] = constants.ORIENT_PENALTY 344 | # Check if the new points lie within the room 345 | try: 346 | if self.room.is_inside([x, y], include_borders=False): 347 | points = np.array([x, y]) 348 | else: 349 | points = self.agent_loc 350 | except Exception: 351 | # in case the is_inside func fails 352 | points = self.agent_loc 353 | 354 | # Move agent in the direction of action 355 | self._move_agent(new_agent_loc=points) 356 | 357 | # Check if goal state is reached 358 | for index, source in enumerate(self.source_locs): 359 | # Agent has found the source 360 | if euclidean(self.agent_loc, source) <= self.acceptable_radius: 361 | print(f'Agent has found source {self.direct_sources[index]}. \nAgent loc: {self.agent_loc}, Source loc: {source}') 362 | reward['turn_off_reward'] = constants.TURN_OFF_REWARD 363 | # If there is more than one source, then we want to remove this source 364 | if len(self.source_locs) > 1: 365 | # remove the source (will take effect in the next step) 366 | self._remove_source(index=index) 367 | 368 | # Calculate the impulse response 369 | self.room.compute_rir() 370 | self.room.simulate() 371 | data = self.room.mic_array.signals 372 | 373 | # Convert the data back to Nussl Audio object 374 | data = nussl.AudioSignal( 375 | audio_data_array=data, sample_rate=self.resample_rate) 376 | 377 | if play_audio or show_room: 378 | self.render(data, play_audio, show_room) 379 | 380 | done = False 381 | return data, [self.agent_loc, self.cur_angle], reward, done 382 | 383 | # This was the last source hence we can assume we are done 384 | else: 385 | done = True 386 | self.reset() 387 | return None, [self.agent_loc, self.cur_angle], reward, done 388 | 389 | if not done: 390 | # Calculate the impulse response 391 | self.room.compute_rir() 392 | self.room.simulate() 393 | data = self.room.mic_array.signals 394 | 395 | # Convert data to nussl audio signal 396 | data = nussl.AudioSignal( 397 | audio_data_array=data, sample_rate=self.resample_rate) 398 | 399 | if play_audio or show_room: 400 | self.render(data, play_audio, show_room) 401 | 402 | # be careful not to give too much reward here (i.e.
1/min_dist could be very large if min_dist is quite small) 403 | # related to acceptable radius size because this reward is only given when NOT turning off a source 404 | min_dist = euclidean_distances( 405 | np.array(self.agent_loc).reshape(1, -1), self.source_locs).min() 406 | reward['closest_reward'] = (1 / (min_dist + 1e-4)) 407 | #print('agent_loc:', self.agent_loc, 'source_locs:', self.source_locs) 408 | #print('cur angle:', self.cur_angle) 409 | #print('reward:', reward) 410 | 411 | # Return the room rir and convolved signals as the new state 412 | return data, [self.agent_loc, self.cur_angle], reward, done 413 | 414 | def reset(self, removing_source=None): 415 | """ 416 | This function re-creates the room, then places sources and agent randomly (but separated) in the room. 417 | To be used after each episode. 418 | 419 | Args: 420 | removing_source (int): Integer that tells us the index of sources that we will be removing 421 | """ 422 | # re-create room 423 | self._create_room() 424 | 425 | if not self.same_config: 426 | # randomly add sources to the room 427 | self._add_sources() 428 | # randomly place agent in room at beginning of next episode 429 | self._move_agent(new_agent_loc=None, initial_placing=True) 430 | else: 431 | # place sources and agent at same locations to start every episode 432 | self._add_sources(new_source_locs=self.fixed_source_locs) 433 | self._move_agent(new_agent_loc=self.fixed_agent_locs, initial_placing=True) 434 | 435 | def render(self, data, play_audio, show_room): 436 | """ 437 | Play the convolved sound using SimpleAudio. 438 | 439 | Args: 440 | data (AudioSignal): if 2 mics, should be of shape (x, 2) 441 | play_audio (bool): If true, audio will play 442 | show_room (bool): If true, room will be displayed to user 443 | """ 444 | if play_audio: 445 | data.embed_audio(display=True) 446 | 447 | # Show the room while the audio is playing 448 | if show_room: 449 | fig, ax = self.room.plot(img_order=0) 450 | plt.pause(1) 451 | 452 | plt.close() 453 | 454 | elif show_room: 455 | fig, ax = self.room.plot(img_order=0) 456 | plt.pause(1) 457 | plt.close() 458 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import nussl 2 | import json 3 | import os 4 | import constants 5 | import numpy as np 6 | from torch.utils.data import IterableDataset 7 | 8 | import warnings 9 | 10 | from torch.utils.data import Dataset 11 | 12 | from nussl.core import AudioSignal 13 | import transforms as tfm 14 | import tqdm 15 | 16 | 17 | class BaseDataset(Dataset): 18 | """ 19 | The BaseDataset class is the starting point for all dataset hooks 20 | in nussl. To subclass BaseDataset, you only have to implement two 21 | functions: 22 | 23 | - ``get_items``: a function that is passed the folder and generates a 24 | list of items that will be processed by the next function. The 25 | number of items in the list will dictate len(dataset). Must return 26 | a list. 27 | - ``process_item``: this function processes a single item in the list 28 | generated by get_items. Must return a dictionary. 29 | 30 | After process_item is called, a set of Transforms can be applied to the 31 | output of process_item. If no transforms are defined (``self.transforms = None``), 32 | then the output of process_item is returned by self[i]. For implemented 33 | Transforms, see nussl.datasets.transforms. 
For example, 34 | PhaseSpectrumApproximation will add three new keys to the output dictionary 35 | of process_item: 36 | 37 | - mix_magnitude: the magnitude spectrogram of the mixture 38 | - source_magnitudes: the magnitude spectrogram of each source 39 | - ideal_binary_mask: the ideal binary mask for each source 40 | 41 | The transforms are applied in sequence using transforms.Compose. 42 | Not all sequences of transforms will be valid (e.g. if you pop a key in 43 | one transform but a later transform operates on that key, you will get 44 | an error). 45 | 46 | For examples of subclassing, see ``nussl.datasets.hooks``. 47 | 48 | Args: 49 | folder (str): location that should be processed to produce the list of files 50 | 51 | transform (transforms.* object, optional): A transforms to apply to the output of 52 | ``self.process_item``. If using transforms.Compose, each transform will be 53 | applied in sequence. Defaults to None. 54 | 55 | sample_rate (int, optional): Sample rate to use for each audio files. If 56 | audio file sample rate doesn't match, it will be resampled on the fly. 57 | If None, uses the default sample rate. Defaults to None. 58 | 59 | stft_params (STFTParams, optional): STFTParams object defining window_length, 60 | hop_length, and window_type that will be set for each AudioSignal object. 61 | Defaults to None (32ms window length, 8ms hop, 'hann' window). 62 | 63 | num_channels (int, optional): Number of channels to make each AudioSignal 64 | object conform to. If an audio signal in your dataset has fewer channels 65 | than ``num_channels``, a warning is raised, as the behavior in this case 66 | is undefined. Defaults to None. 67 | 68 | strict_sample_rate (bool, optional): Whether to raise an error if 69 | 70 | Raises: 71 | DataSetException: Exceptions are raised if the output of the implemented 72 | functions by the subclass don't match the specification. 73 | """ 74 | 75 | def __init__(self, folder, transform=None, sample_rate=None, stft_params=None, 76 | num_channels=None, strict_sample_rate=True, cache_populated=False): 77 | self.folder = folder 78 | self.items = self.get_items(self.folder) 79 | self.transform = transform 80 | 81 | self.cache_populated = cache_populated 82 | 83 | self.stft_params = stft_params 84 | self.sample_rate = sample_rate 85 | self.num_channels = num_channels 86 | self.strict_sample_rate = strict_sample_rate 87 | 88 | if not isinstance(self.items, list): 89 | raise DataSetException("Output of self.get_items must be a list!") 90 | 91 | # getting one item in order to set up parameters for audio 92 | # signals if necessary, if there are any items 93 | if self.items: 94 | self.process_item(self.items[0]) 95 | 96 | def filter_items_by_condition(self, func): 97 | """ 98 | Filter the items in the list according to a function that takes 99 | in both the dataset as well as the item currently be processed. 100 | If the item in the list passes the condition, then it is kept 101 | in the list. Otherwise it is taken out of the list. For example, 102 | a function that would get rid of an item if it is below some 103 | minimum number of seconds would look like this: 104 | 105 | .. 
code-block:: python 106 | 107 | min_length = 1 # in seconds 108 | 109 | # self here refers to the dataset 110 | def remove_short_audio(self, item): 111 | processed_item = self.process_item(item) 112 | mix_length = processed_item['mix'].signal_duration 113 | if mix_length < min_length: 114 | return False 115 | return True 116 | 117 | dataset.items # contains all items 118 | dataset.filter_items_by_condition(remove_short_audio) 119 | dataset.items # contains only items longer than min length 120 | 121 | Args: 122 | func (function): A function that takes in two arguments: the dataset and 123 | this dataset object (self). The function must return a bool. 124 | """ 125 | filtered_items = [] 126 | n_removed = 0 127 | desc = f"Filtered {n_removed} items out of dataset" 128 | pbar = tqdm.tqdm(self.items, desc=desc) 129 | for item in pbar: 130 | check = func(self, item) 131 | if not isinstance(check, bool): 132 | raise DataSetException( 133 | "Output of filter function must be True or False!" 134 | ) 135 | if check: 136 | filtered_items.append(item) 137 | else: 138 | n_removed += 1 139 | pbar.set_description(f"Filtered {n_removed} items out of dataset") 140 | self.items = filtered_items 141 | 142 | @property 143 | def cache_populated(self): 144 | return self._cache_populated 145 | 146 | @cache_populated.setter 147 | def cache_populated(self, value): 148 | self.post_cache_transforms = [] 149 | cache_transform = None 150 | 151 | transforms = ( 152 | self.transform.transforms 153 | if isinstance(self.transform, tfm.Compose) 154 | else [self.transform]) 155 | 156 | found_cache_transform = False 157 | for t in transforms: 158 | if isinstance(t, tfm.Cache): 159 | found_cache_transform = True 160 | cache_transform = t 161 | if found_cache_transform: 162 | self.post_cache_transforms.append(t) 163 | 164 | if not found_cache_transform: 165 | # there is no cache transform 166 | self._cache_populated = False 167 | else: 168 | self._cache_populated = value 169 | cache_transform.cache_size = len(self) 170 | cache_transform.overwrite = not value 171 | 172 | self.post_cache_transforms = tfm.Compose( 173 | self.post_cache_transforms) 174 | 175 | def get_items(self, folder): 176 | """ 177 | This function must be implemented by whatever class inherits BaseDataset. 178 | It should return a list of items in the given folder, each of which is 179 | processed by process_items in some way to produce mixes, sources, class 180 | labels, etc. 181 | 182 | Args: 183 | folder (str): location that should be processed to produce the list of files. 184 | 185 | Returns: 186 | list: list of items that should be processed 187 | """ 188 | raise NotImplementedError() 189 | 190 | def __len__(self): 191 | """ 192 | Gets the length of the dataset (the number of items that will be processed). 193 | 194 | Returns: 195 | int: Length of the dataset (``len(self.items)``). 196 | """ 197 | return len(self.items) 198 | 199 | def __getitem__(self, i): 200 | """ 201 | Processes a single item in ``self.items`` using ``self.process_item``. 202 | The output of ``self.process_item`` is further passed through bunch of 203 | of transforms if they are defined in parallel. If you want to have 204 | a set of transforms that depend on each other, then you should compose them 205 | into a single transforms and then pass it into here. The output of each 206 | transform is added to an output dictionary which is returned by this 207 | function. 208 | 209 | Args: 210 | i (int): Index of the dataset to return. Indexes ``self.items``. 
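        A conceptual (non-cached) equivalent of this method, assuming a stand-in
        dataset object:

        .. code-block:: python

            def getitem(dataset, i):
                data = dataset.process_item(dataset.items[i])
                if dataset.transform:
                    data['index'] = i
                    data = dataset.transform(data)
                return data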
211 | 212 | Returns: 213 | dict: Dictionary with keys and values corresponding to the processed 214 | item after being put through the set of transforms (if any are 215 | defined). 216 | """ 217 | if self.cache_populated: 218 | data = {'index': i} 219 | data = self.post_cache_transforms(data) 220 | else: 221 | data = self.process_item(self.items[i]) 222 | 223 | if not isinstance(data, dict): 224 | raise DataSetException( 225 | "The output of process_item must be a dictionary!") 226 | 227 | if self.transform: 228 | data['index'] = i 229 | data = self.transform(data) 230 | 231 | if not isinstance(data, dict): 232 | raise tfm.TransformException( 233 | "The output of transform must be a dictionary!") 234 | 235 | return data 236 | 237 | def process_item(self, item): 238 | """Each file returned by get_items is processed by this function. For example, 239 | if each file is a json file containing the paths to the mixture and sources, 240 | then this function should parse the json file and load the mixture and sources 241 | and return them. 242 | 243 | Exact behavior of this functionality is determined by implementation by subclass. 244 | 245 | Args: 246 | item (object): the item that will be processed by this function. Input depends 247 | on implementation of ``self.get_items``. 248 | 249 | Returns: 250 | This should return a dictionary that gets processed by the transforms. 251 | """ 252 | raise NotImplementedError() 253 | 254 | def _load_audio_file(self, path_to_audio_file): 255 | """ 256 | Loads audio file at given path. Uses AudioSignal to load the audio data 257 | from disk. 258 | 259 | Args: 260 | path_to_audio_file: relative or absolute path to file to load 261 | 262 | Returns: 263 | AudioSignal: loaded AudioSignal object of path_to_audio_file 264 | """ 265 | audio_signal = AudioSignal(path_to_audio_file) 266 | self._setup_audio_signal(audio_signal) 267 | return audio_signal 268 | 269 | def _load_audio_from_array(self, audio_data, sample_rate=None): 270 | """ 271 | Loads the audio data into an AudioSignal object with the appropriate 272 | sample rate. 273 | 274 | Args: 275 | audio_data (np.ndarray): numpy array containing the samples containing 276 | the audio data. 277 | 278 | sample_rate (int): the sample rate at which to load the audio file. 279 | If None, self.sample_rate or the sample rate of the actual file is used. 280 | Defaults to None. 281 | 282 | Returns: 283 | AudioSignal: loaded AudioSignal object of audio_data 284 | """ 285 | sample_rate = sample_rate if sample_rate else self.sample_rate 286 | audio_signal = AudioSignal( 287 | audio_data_array=audio_data, sample_rate=sample_rate) 288 | self._setup_audio_signal(audio_signal) 289 | return audio_signal 290 | 291 | def _setup_audio_signal(self, audio_signal): 292 | """ 293 | You will want every item from a dataset to be uniform in sample rate, STFT 294 | parameters, and number of channels. This function takes an audio signal 295 | object loaded by the dataset and uses it to set the sample rate, STFT parameters, 296 | and the number of channels. If ``self.sample_rate``, ``self.stft_params``, and 297 | ``self.num_channels`` are set at construction time of the dataset, then the 298 | opposite happens - attributes of the AudioSignal object are set to the desired 299 | values. 300 | 301 | Args: 302 | audio_signal (AudioSignal): AudioSignal object to query to set the parameters 303 | of this dataset or to set the parameters of, according to what is in the 304 | dataset. 
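        For example, with a dataset constructed with sample_rate=None, the first
        signal seen sets the dataset's parameters, and later signals are conformed
        to them. A sketch using only APIs that appear in this file (the array
        contents are illustrative):

        .. code-block:: python

            import numpy as np
            from nussl.core import AudioSignal

            # one channel, one second of silence at 8 kHz
            sig = AudioSignal(audio_data_array=np.zeros((1, 8000)), sample_rate=8000)
            # first call: the dataset adopts sample_rate=8000, num_channels=1;
            # a later 16 kHz signal would then be resampled via sig.resample(8000)
            # (or raise DataSetException if strict_sample_rate is True)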
305 | """ 306 | if self.sample_rate and self.sample_rate != audio_signal.sample_rate: 307 | if self.strict_sample_rate: 308 | raise DataSetException( 309 | f"All audio files should have been the same sample rate already " 310 | f"because self.strict_sample_rate = True. Please resample or " 311 | f"turn set self.strict_sample_rate = False" 312 | ) 313 | audio_signal.resample(self.sample_rate) 314 | else: 315 | self.sample_rate = audio_signal.sample_rate 316 | 317 | # set audio signal attributes to requested values, if they exist 318 | if self.stft_params: 319 | audio_signal.stft_params = self.stft_params 320 | else: 321 | self.stft_params = audio_signal.stft_params 322 | 323 | if self.num_channels: 324 | if audio_signal.num_channels > self.num_channels: 325 | # pick the first ``self.num_channels`` channels 326 | audio_signal.audio_data = audio_signal.audio_data[:self.num_channels] 327 | elif audio_signal.num_channels < self.num_channels: 328 | warnings.warn( 329 | f"AudioSignal had {audio_signal.num_channels} channels " 330 | f"but self.num_channels = {self.num_channels}. Unsure " 331 | f"of what to do, so warning. You might want to make sure " 332 | f"your dataset is uniform!" 333 | ) 334 | else: 335 | self.num_channels = audio_signal.num_channels 336 | 337 | 338 | class DataSetException(Exception): 339 | """ 340 | Exception class for errors when working with data sets in nussl. 341 | """ 342 | pass 343 | 344 | 345 | class BufferData(BaseDataset): 346 | def __init__(self, folder, to_disk=False, transform=None): 347 | """ 348 | 349 | Args: 350 | folder (string): File path to store the data. Put any string when not saving to disk. 351 | to_disk (bool): When true, data will be saved to disk for inspection, data will also be stored in memory 352 | regardless of whether this is True or False 353 | transform (transforms.* object): A transforms to apply to the output of 354 | ``self.process_item``. If using transforms.Compose, each transform will be 355 | applied in sequence. Defaults to None. 356 | """ 357 | # Circular buffer parameters 358 | self.MAX_BUFFER_ITEMS = constants.MAX_BUFFER_ITEMS 359 | self.ptr = 0 360 | self.items = [] 361 | self.metadata = {} 362 | self.full_buffer = False 363 | self.to_disk = to_disk 364 | self.last_ptr = -1 # To keep track of the latest item in the buffer 365 | 366 | # Make sure the relevant directories exist 367 | if self.to_disk: 368 | if not os.path.exists(constants.DIR_PREV_STATES): 369 | os.mkdir(constants.DIR_PREV_STATES) 370 | if not os.path.exists(constants.DIR_NEW_STATES): 371 | os.mkdir(constants.DIR_NEW_STATES) 372 | if not os.path.exists(constants.DIR_DATASET_ITEMS): 373 | os.mkdir(constants.DIR_DATASET_ITEMS) 374 | 375 | super().__init__(folder=folder, transform=transform) 376 | 377 | def get_items(self, folder): 378 | """ 379 | Superclass: "This function must be implemented by whatever class inherits BaseDataset. 380 | It should return a list of items in the given folder, each of which is 381 | processed by process_items in some way to produce mixes, sources, class 382 | labels, etc." 383 | 384 | Implementation: Adds file paths to items list. Keeps list under MAX_ITEMS. 385 | 386 | Args: 387 | folder (str): location that should be processed to produce the list of files 388 | Returns: 389 | list: list of items (path to json files in our case) that should be processed 390 | """ 391 | return self.items 392 | 393 | def process_item(self, item): 394 | """ 395 | Superclass: Each file returned by get_items is processed by this function. 
For example, 396 | if each file is a json file containing the paths to the mixture and sources, 397 | then this function should parse the json file and load the mixture and sources 398 | and return them. 399 | Exact behavior of this functionality is determined by implementation by subclass." 400 | 401 | Implementation: read json of format: 402 | {'prev_state': '../data/prev_states/prev8-224.wav', 403 | 'action': 0, 404 | 'reward': -0.1, 405 | 'new_state': '../data/new_states/new8-224.wav'} 406 | 407 | convert the wav files to AudioSignals and return and output dict: 408 | { 409 | 'observations': { 410 | 'prev_state': AudioSignal, 411 | 'new_state': AudioSignal, 412 | } 413 | 'reward': -0.1 414 | 'action': 0 415 | } 416 | 417 | Args: 418 | item (object): the item that will be processed by this function. Input depends 419 | on implementation of ``self.get_items``. 420 | Returns: 421 | This should return a dictionary that gets processed by the transforms. 422 | """ 423 | 424 | # load data from memory 425 | output = item.copy() 426 | prev_state, new_state = output['prev_state'], output['new_state'] 427 | 428 | # convert to output dict format 429 | del output['prev_state'], output['new_state'] 430 | output['prev_state'], output['new_state'] = prev_state, new_state 431 | output['reward'] = np.array([output['reward']], dtype='float32') 432 | output['action'] = np.array([output['action']], dtype='int64') 433 | output['agent_info'] = np.array(output['agent_info'], dtype='float32') 434 | return output 435 | 436 | def random_sample(self, bs): 437 | indices = np.random.choice(len(self.items), bs-1, replace=False) 438 | indices = np.append(indices, self.last_ptr) 439 | return indices 440 | 441 | def append(self, item): 442 | """ 443 | Override the default append function to work as circular buffer 444 | Args: 445 | item (object): Item to append to the list 446 | 447 | Returns: Nothing (Item is appended to the circular buffer in place) 448 | """ 449 | if self.full_buffer: 450 | self.items[self.ptr] = item 451 | self.last_ptr = self.ptr 452 | self.ptr = (self.ptr + 1) % self.MAX_BUFFER_ITEMS 453 | else: 454 | self.items.append(item) 455 | self.last_ptr += 1 456 | if len(self.items) == self.MAX_BUFFER_ITEMS: 457 | self.ptr = 0 458 | self.last_ptr = 0 459 | self.full_buffer = True 460 | 461 | def write_buffer_data(self, prev_state, action, reward, new_state, agent_info, episode, step): 462 | """ 463 | Writes states (AudioSignal objects) to .wav files and stores this buffer data 464 | in json files with the states keys pointing to the .wav files. The json files 465 | are to be read by nussl.datasets.BaseDataset subclass as items. 466 | 467 | E.g. 
{ 468 | 'prev_state': '/path/to/previous/mix.wav', 469 | 'reward': [the reward obtained for reaching current state], 470 | 'action': [the action taken to reach current state from previous state] 471 | 'current_state': '/path/to/current/mix.wav', 472 | } 473 | 474 | The unique file names are structured as path/[prev or new]-[episode #]-[step #] 475 | 476 | Args: 477 | prev_state (nussl.AudioSignal): previous state to be converted and saved as .wav file 478 | action (int): action 479 | reward (int): reward 480 | new_state (nussl.AudioSignal): new state to be converted and saved as wav file 481 | agent_info (list): Consists of agent location and current orientation of the agent 482 | episode (int): which episode we're on, used to create unique file name for state 483 | step (int): which step we're on within episode, used to create unique file name for state 484 | 485 | """ 486 | if episode not in self.metadata: 487 | self.metadata[episode] = 1 488 | else: 489 | self.metadata[episode] += 1 490 | 491 | agent_loc, cur_angle = agent_info 492 | agent_info = np.append(agent_loc, cur_angle) # Jut make it a single list to keep things simple 493 | # create buffer dictionary 494 | buffer_dict = { 495 | 'prev_state': prev_state, 496 | 'action': action, 497 | 'reward': reward, 498 | 'new_state': new_state, 499 | 'agent_info': agent_info 500 | } 501 | self.append(buffer_dict) 502 | 503 | # write data for inspection 504 | if self.to_disk and step == 1: 505 | # Unique file names for each state 506 | cur_file = str(episode) + '-' + str(step) 507 | prev_state_file_path = os.path.join( 508 | constants.DIR_PREV_STATES, 'prev' + cur_file + '.wav' 509 | ) 510 | new_state_file_path = os.path.join( 511 | constants.DIR_NEW_STATES, 'new' + cur_file + '.wav' 512 | ) 513 | dataset_json_file_path = os.path.join( 514 | constants.DIR_DATASET_ITEMS, cur_file + '.json' 515 | ) 516 | 517 | prev_state.write_audio_to_file(prev_state_file_path) 518 | new_state.write_audio_to_file(new_state_file_path) 519 | 520 | # write to json 521 | buffer_dict_json = { 522 | 'prev_state': prev_state_file_path, 523 | 'action': action, 524 | 'reward': reward, 525 | 'new_state': new_state_file_path, 526 | } 527 | 528 | with open(dataset_json_file_path, 'w') as json_file: 529 | json.dump(buffer_dict_json, json_file) 530 | 531 | 532 | class RLDataset(IterableDataset): 533 | """ 534 | Dataset which gets updated as buffer gets filled 535 | """ 536 | def __init__(self, buffer, sample_size): 537 | self.buffer = buffer 538 | self.sample_size = sample_size 539 | 540 | def __iter__(self): 541 | batch_indices = self.buffer.random_sample(self.sample_size) 542 | for index in batch_indices: 543 | yield self.buffer[index] 544 | -------------------------------------------------------------------------------- /src/transforms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import logging 4 | import random 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import zarr 9 | import numcodecs 10 | import numpy as np 11 | from sklearn.preprocessing import OneHotEncoder 12 | 13 | from nussl.core import utils 14 | 15 | # This is for when you're running multiple 16 | # training threads 17 | if hasattr(numcodecs, 'blosc'): 18 | numcodecs.blosc.use_threads = False 19 | 20 | 21 | def compute_ideal_binary_mask(source_magnitudes): 22 | ibm = ( 23 | source_magnitudes == np.max(source_magnitudes, axis=-1, keepdims=True) 24 | ).astype(float) 25 | 26 | ibm = ibm / np.sum(ibm, axis=-1, 
keepdims=True) 27 | ibm[ibm <= .5] = 0 28 | return ibm 29 | 30 | 31 | # Keys that correspond to the time-frequency representations after being passed through 32 | # the transforms here. 33 | time_frequency_keys = ['mix_magnitude', 'source_magnitudes', 'ideal_binary_mask', 'weights'] 34 | 35 | 36 | class SumSources(object): 37 | """ 38 | Sums sources together. Looks for sources in ``data[self.source_key]``. If 39 | a source belongs to a group, it is popped from the ``data[self.source_key]`` and 40 | summed with the other sources in the group. If there is a corresponding 41 | group_name in group_names, it is named that in ``data[self.source_key]``. If 42 | group_names are not given, then the names are constructed using the keys 43 | in each group (e.g. `drums+bass+other`). 44 | 45 | If using Scaper datasets, then there may be multiple sources with the same 46 | label but different counts. The Scaper dataset hook organizes the source 47 | dictionary as follows: 48 | 49 | .. code-block:: none 50 | 51 | data['sources] = { 52 | '{label}::{count}': AudioSignal, 53 | '{label}::{count}': AudioSignal, 54 | ... 55 | } 56 | 57 | SumSources sums by source label, so the ``::count`` will be ignored and only the 58 | label part will be used when grouping sources. 59 | 60 | Example: 61 | >>> tfm = transforms.SumSources( 62 | >>> groupings=[['drums', 'bass', 'other]], 63 | >>> group_names=['accompaniment], 64 | >>> ) 65 | >>> # data['sources'] is a dict containing keys: 66 | >>> # ['vocals', 'drums', 'bass', 'other] 67 | >>> data = tfm(data) 68 | >>> # data['sources'] is now a dict containing keys: 69 | >>> # ['vocals', 'accompaniment'] 70 | 71 | Args: 72 | groupings (list): a list of lists telling how to group each sources. 73 | group_names (list, optional): A list containing the names of each group, or None. 74 | Defaults to None. 75 | source_key (str, optional): The key to look for in the data containing the list of 76 | source AudioSignals. Defaults to 'sources'. 77 | 78 | Raises: 79 | TransformException: if groupings is not a list 80 | TransformException: if group_names is not None but 81 | len(groupings) != len(group_names) 82 | 83 | Returns: 84 | data: modified dictionary with summed sources 85 | """ 86 | 87 | def __init__(self, groupings, group_names=None, source_key='sources'): 88 | if not isinstance(groupings, list): 89 | raise TransformException( 90 | f"groupings must be a list, got {type(groupings)}!") 91 | 92 | if group_names: 93 | if len(group_names) != len(groupings): 94 | raise TransformException( 95 | f"group_names and groupings must be same length or " 96 | f"group_names can be None! Got {len(group_names)} for " 97 | f"len(group_names) and {len(groupings)} for len(groupings)." 98 | ) 99 | 100 | self.groupings = groupings 101 | self.source_key = source_key 102 | if group_names is None: 103 | group_names = ['+'.join(groupings[i]) for i in range(len(groupings))] 104 | self.group_names = group_names 105 | 106 | def __call__(self, data): 107 | if self.source_key not in data: 108 | raise TransformException( 109 | f"Expected {self.source_key} in dictionary " 110 | f"passed to this Transform!" 
29 | 
30 | 
31 | # Keys that correspond to the time-frequency representations after being passed through
32 | # the transforms here.
33 | time_frequency_keys = ['mix_magnitude', 'source_magnitudes', 'ideal_binary_mask', 'weights']
34 | 
35 | 
36 | class SumSources(object):
37 |     """
38 |     Sums sources together. Looks for sources in ``data[self.source_key]``. If
39 |     a source belongs to a group, it is popped from the ``data[self.source_key]`` and
40 |     summed with the other sources in the group. If there is a corresponding
41 |     group_name in group_names, it is named that in ``data[self.source_key]``. If
42 |     group_names are not given, then the names are constructed using the keys
43 |     in each group (e.g. `drums+bass+other`).
44 | 
45 |     If using Scaper datasets, then there may be multiple sources with the same
46 |     label but different counts. The Scaper dataset hook organizes the source
47 |     dictionary as follows:
48 | 
49 |     .. code-block:: none
50 | 
51 |         data['sources'] = {
52 |             '{label}::{count}': AudioSignal,
53 |             '{label}::{count}': AudioSignal,
54 |             ...
55 |         }
56 | 
57 |     SumSources sums by source label, so the ``::count`` will be ignored and only the
58 |     label part will be used when grouping sources.
59 | 
60 |     Example:
61 |         >>> tfm = transforms.SumSources(
62 |         >>>     groupings=[['drums', 'bass', 'other']],
63 |         >>>     group_names=['accompaniment'],
64 |         >>> )
65 |         >>> # data['sources'] is a dict containing keys:
66 |         >>> # ['vocals', 'drums', 'bass', 'other']
67 |         >>> data = tfm(data)
68 |         >>> # data['sources'] is now a dict containing keys:
69 |         >>> # ['vocals', 'accompaniment']
70 | 
71 |     Args:
72 |         groupings (list): a list of lists telling how to group the sources.
73 |         group_names (list, optional): A list containing the names of each group, or None.
74 |             Defaults to None.
75 |         source_key (str, optional): The key to look for in the data containing the dict of
76 |             source AudioSignals. Defaults to 'sources'.
77 | 
78 |     Raises:
79 |         TransformException: if groupings is not a list
80 |         TransformException: if group_names is not None but
81 |             len(groupings) != len(group_names)
82 | 
83 |     Returns:
84 |         data: modified dictionary with summed sources
85 |     """
86 | 
87 |     def __init__(self, groupings, group_names=None, source_key='sources'):
88 |         if not isinstance(groupings, list):
89 |             raise TransformException(
90 |                 f"groupings must be a list, got {type(groupings)}!")
91 | 
92 |         if group_names:
93 |             if len(group_names) != len(groupings):
94 |                 raise TransformException(
95 |                     f"group_names and groupings must be same length or "
96 |                     f"group_names can be None! Got {len(group_names)} for "
97 |                     f"len(group_names) and {len(groupings)} for len(groupings)."
98 |                 )
99 | 
100 |         self.groupings = groupings
101 |         self.source_key = source_key
102 |         if group_names is None:
103 |             group_names = ['+'.join(groupings[i]) for i in range(len(groupings))]
104 |         self.group_names = group_names
105 | 
106 |     def __call__(self, data):
107 |         if self.source_key not in data:
108 |             raise TransformException(
109 |                 f"Expected {self.source_key} in dictionary "
110 |                 f"passed to this Transform!"
111 |             )
112 |         sources = data[self.source_key]
113 |         source_keys = [(k.split('::')[0], k) for k in list(sources.keys())]
114 | 
115 |         for i, group in enumerate(self.groupings):
116 |             combined = []
117 |             group_name = self.group_names[i]
118 |             for key1 in group:
119 |                 for key2 in source_keys:
120 |                     if key2[0] == key1:
121 |                         combined.append(sources[key2[1]])
122 |                         sources.pop(key2[1])
123 |             sources[group_name] = sum(combined)
124 |             sources[group_name].path_to_input_file = group_name
125 | 
126 |         data[self.source_key] = sources
127 |         if 'metadata' in data:
128 |             if 'labels' in data['metadata']:
129 |                 data['metadata']['labels'].extend(self.group_names)
130 | 
131 |         return data
132 | 
133 |     def __repr__(self):
134 |         return (
135 |             f"{self.__class__.__name__}("
136 |             f"groupings = {self.groupings}, "
137 |             f"group_names = {self.group_names}, "
138 |             f"source_key = {self.source_key}"
139 |             f")"
140 |         )
141 | 
142 | 
143 | class LabelsToOneHot(object):
144 |     """
145 |     Takes a data dictionary with sources and their keys and converts the keys to
146 |     a one-hot numpy array, using the list in data['metadata']['labels'] to figure
147 |     out which index goes where.
148 |     """
149 | 
150 |     def __init__(self, source_key='sources'):
151 |         self.source_key = source_key
152 | 
153 |     def __call__(self, data):
154 |         if 'metadata' not in data:
155 |             raise TransformException(
156 |                 f"Expected metadata in data, got {list(data.keys())}")
157 |         if 'labels' not in data['metadata']:
158 |             raise TransformException(
159 |                 f"Expected labels in data['metadata'], got "
160 |                 f"{list(data['metadata'].keys())}")
161 | 
162 |         enc = OneHotEncoder(categories=[data['metadata']['labels']])
163 | 
164 |         sources = data[self.source_key]
165 |         source_keys = [k.split('::')[0] for k in list(sources.keys())]
166 |         source_labels = [[l] for l in sorted(source_keys)]
167 | 
168 |         one_hot_labels = enc.fit_transform(source_labels)
169 |         data['one_hot_labels'] = one_hot_labels.toarray()
170 | 
171 |         return data
172 | 
173 | 
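# What LabelsToOneHot produces, in isolation: the fixed list in
# data['metadata']['labels'] determines the column order, and each source
# (sorted by label) becomes one row of the one-hot matrix. The labels here
# are illustrative MUSDB-style names.
from sklearn.preprocessing import OneHotEncoder

labels = ['bass', 'drums', 'other', 'vocals']
source_labels = [['drums'], ['vocals']]  # e.g. from keys 'drums::0', 'vocals::0'

enc = OneHotEncoder(categories=[labels])
one_hot = enc.fit_transform(source_labels).toarray()
print(one_hot)
# [[0. 1. 0. 0.]
#  [0. 0. 0. 1.]]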
174 | class MagnitudeSpectrumApproximation(object):
175 |     """
176 |     Takes a dictionary and looks for two special keys, defined by the
177 |     arguments ``mix_key`` and ``source_key``. These default to `mix` and `sources`.
178 |     The values of these keys are used to calculate the magnitude spectrum
179 |     approximation [1]. The input dictionary is modified to have additional
180 |     keys:
181 | 
182 |     - mix_magnitude: The magnitude spectrogram of the mixture audio signal.
183 |     - source_magnitudes: The magnitude spectrogram of each source.
184 |     - ideal_binary_mask: The ideal binary assignments for each time-frequency bin.
185 | 
186 |     ``data[self.source_key]`` points to a dictionary containing the source names in
187 |     the keys and the corresponding AudioSignal in the values. The keys are sorted
188 |     in alphabetical order and the magnitudes are stacked in that order. ``data[self.source_key]``
189 |     then points to an OrderedDict instead, where the keys are in the same order
190 |     as in ``data['source_magnitudes']`` and ``data['ideal_binary_mask']``.
191 | 
192 |     This transform uses the STFTParams that are attached to the AudioSignal objects
193 |     contained in ``data[mix_key]`` and ``data[source_key]``.
194 | 
195 |     [1] Erdogan, Hakan, John R. Hershey, Shinji Watanabe, and Jonathan Le Roux.
196 |         "Phase-sensitive and recognition-boosted speech separation using
197 |         deep recurrent neural networks." In 2015 IEEE International Conference
198 |         on Acoustics, Speech and Signal Processing (ICASSP), pp. 708-712. IEEE,
199 |         2015.
200 | 
201 |     Args:
202 |         mix_key (str, optional): The key to look for in data for the mixture AudioSignal.
203 |             Defaults to 'mix'.
204 |         source_key (str, optional): The key to look for in the data containing the dict of
205 |             source AudioSignals. Defaults to 'sources'.
206 | 
207 |     Raises:
208 |         TransformException: if the expected keys are not in the
209 |             dictionary.
210 | 
211 |     Returns:
212 |         data: Modified version of the input dictionary.
213 |     """
214 | 
215 |     def __init__(self, mix_key='mix', source_key='sources'):
216 |         self.mix_key = mix_key
217 |         self.source_key = source_key
218 | 
219 |     def __call__(self, data):
220 |         if self.mix_key not in data:
221 |             raise TransformException(
222 |                 f"Expected {self.mix_key} in dictionary "
223 |                 f"passed to this Transform! Got {list(data.keys())}."
224 |             )
225 | 
226 |         mixture = data[self.mix_key]
227 |         mixture.stft()
228 |         mix_magnitude = mixture.magnitude_spectrogram_data
229 | 
230 |         data['mix_magnitude'] = mix_magnitude
231 | 
232 |         if self.source_key not in data:
233 |             return data
234 | 
235 |         _sources = data[self.source_key]
236 |         source_names = sorted(list(_sources.keys()))
237 | 
238 |         sources = OrderedDict()
239 |         for key in source_names:
240 |             sources[key] = _sources[key]
241 |         data[self.source_key] = sources
242 | 
243 |         source_magnitudes = []
244 |         for key in source_names:
245 |             s = sources[key]
246 |             s.stft()
247 |             source_magnitudes.append(s.magnitude_spectrogram_data)
248 | 
249 |         source_magnitudes = np.stack(source_magnitudes, axis=-1)
250 | 
251 |         data['ideal_binary_mask'] = compute_ideal_binary_mask(source_magnitudes)
252 |         data['source_magnitudes'] = source_magnitudes
253 | 
254 |         return data
255 | 
256 |     def __repr__(self):
257 |         return (
258 |             f"{self.__class__.__name__}("
259 |             f"mix_key = {self.mix_key}, "
260 |             f"source_key = {self.source_key}"
261 |             f")"
262 |         )
263 | 
264 | 
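# Sketch of MagnitudeSpectrumApproximation on a two-source mix. The .wav
# paths are placeholders; any pair of same-length files will do.
import nussl

s0 = nussl.AudioSignal('path/to/source0.wav')
s1 = nussl.AudioSignal('path/to/source1.wav')
data = {'mix': s0 + s1, 'sources': {'source0': s0, 'source1': s1}}

data = MagnitudeSpectrumApproximation()(data)
print(data['mix_magnitude'].shape)       # (freq, time, channels)
print(data['source_magnitudes'].shape)   # (freq, time, channels, 2)
print(data['ideal_binary_mask'].shape)   # same shape as source_magnitudes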
" 292 | ) 293 | elif self.mix_magnitude_key not in data: 294 | data[self.mix_magnitude_key] = np.abs(data[self.mix_key].stft()) 295 | 296 | magnitude_spectrogram = data[self.mix_magnitude_key] 297 | weights = magnitude_spectrogram / (np.sum(magnitude_spectrogram) + 1e-6) 298 | weights *= ( 299 | magnitude_spectrogram.shape[0] * magnitude_spectrogram.shape[1] 300 | ) 301 | data['weights'] = np.sqrt(weights) 302 | return data 303 | 304 | 305 | class PhaseSensitiveSpectrumApproximation(object): 306 | """ 307 | Takes a dictionary and looks for two special keys, defined by the 308 | arguments ``mix_key`` and ``source_key``. These default to `mix` and `sources`. 309 | These values of these keys are used to calculate the phase sensitive spectrum 310 | approximation [1]. The input dictionary is modified to have additional 311 | keys: 312 | 313 | - mix_magnitude: The magnitude spectrogram of the mixture audio signal. 314 | - source_magnitudes: The magnitude spectrograms of each source spectrogram. 315 | - assignments: The ideal binary assignments for each time-frequency bin. 316 | 317 | ``data[self.source_key]`` points to a dictionary containing the source names in 318 | the keys and the corresponding AudioSignal in the values. The keys are sorted 319 | in alphabetical order and then appended to the mask. ``data[self.source_key]`` 320 | then points to an OrderedDict instead, where the keys are in the same order 321 | as in ``data['source_magnitudes']`` and ``data['assignments']``. 322 | 323 | This transform uses the STFTParams that are attached to the AudioSignal objects 324 | contained in ``data[mix_key]`` and ``data[source_key]``. 325 | 326 | [1] Erdogan, Hakan, John R. Hershey, Shinji Watanabe, and Jonathan Le Roux. 327 | "Phase-sensitive and recognition-boosted speech separation using 328 | deep recurrent neural networks." In 2015 IEEE International Conference 329 | on Acoustics, Speech and Signal Processing (ICASSP), pp. 708-712. IEEE, 330 | 2015. 331 | 332 | Args: 333 | mix_key (str, optional): The key to look for in data for the mixture AudioSignal. 334 | Defaults to 'mix'. 335 | source_key (str, optional): The key to look for in the data containing the list of 336 | source AudioSignals. Defaults to 'sources'. 337 | range_min (float, optional): The lower end to use when truncating the source 338 | magnitudes in the phase sensitive spectrum approximation. Defaults to 0.0 (construct 339 | non-negative masks). Use -np.inf for untruncated source magnitudes. 340 | range_max (float, optional): The higher end of the truncated spectrum. This gets 341 | multiplied by the magnitude of the mixture. Use 1.0 to truncate the source 342 | magnitudes to `max(source_magnitudes, mix_magnitude)`. Use np.inf for untruncated 343 | source magnitudes (best performance for an oracle mask but may be beyond what a 344 | neural network is capable of masking). Defaults to 1.0. 345 | 346 | Raises: 347 | TransformException: if the expected keys are not in the dictionary, an 348 | Exception is raised. 349 | 350 | Returns: 351 | data: Modified version of the input dictionary. 352 | """ 353 | 354 | def __init__(self, mix_key='mix', source_key='sources', 355 | range_min=0.0, range_max=1.0): 356 | self.mix_key = mix_key 357 | self.source_key = source_key 358 | self.range_min = range_min 359 | self.range_max = range_max 360 | 361 | def __call__(self, data): 362 | if self.mix_key not in data: 363 | raise TransformException( 364 | f"Expected {self.mix_key} in dictionary " 365 | f"passed to this Transform! 
305 | class PhaseSensitiveSpectrumApproximation(object):
306 |     """
307 |     Takes a dictionary and looks for two special keys, defined by the
308 |     arguments ``mix_key`` and ``source_key``. These default to `mix` and `sources`.
309 |     The values of these keys are used to calculate the phase sensitive spectrum
310 |     approximation [1]. The input dictionary is modified to have additional
311 |     keys:
312 | 
313 |     - mix_magnitude: The magnitude spectrogram of the mixture audio signal.
314 |     - source_magnitudes: The phase-sensitive magnitude spectrogram of each source.
315 |     - ideal_binary_mask: The ideal binary assignments for each time-frequency bin.
316 | 
317 |     ``data[self.source_key]`` points to a dictionary containing the source names in
318 |     the keys and the corresponding AudioSignal in the values. The keys are sorted
319 |     in alphabetical order and the targets are stacked in that order. ``data[self.source_key]``
320 |     then points to an OrderedDict instead, where the keys are in the same order
321 |     as in ``data['source_magnitudes']`` and ``data['ideal_binary_mask']``.
322 | 
323 |     This transform uses the STFTParams that are attached to the AudioSignal objects
324 |     contained in ``data[mix_key]`` and ``data[source_key]``.
325 | 
326 |     [1] Erdogan, Hakan, John R. Hershey, Shinji Watanabe, and Jonathan Le Roux.
327 |         "Phase-sensitive and recognition-boosted speech separation using
328 |         deep recurrent neural networks." In 2015 IEEE International Conference
329 |         on Acoustics, Speech and Signal Processing (ICASSP), pp. 708-712. IEEE,
330 |         2015.
331 | 
332 |     Args:
333 |         mix_key (str, optional): The key to look for in data for the mixture AudioSignal.
334 |             Defaults to 'mix'.
335 |         source_key (str, optional): The key to look for in the data containing the dict of
336 |             source AudioSignals. Defaults to 'sources'.
337 |         range_min (float, optional): The lower end to use when truncating the source
338 |             magnitudes in the phase sensitive spectrum approximation. Defaults to 0.0 (construct
339 |             non-negative masks). Use -np.inf for untruncated source magnitudes.
340 |         range_max (float, optional): The higher end of the truncated spectrum. This gets
341 |             multiplied by the magnitude of the mixture. Use 1.0 to truncate the source
342 |             magnitudes to `max(source_magnitudes, mix_magnitude)`. Use np.inf for untruncated
343 |             source magnitudes (best performance for an oracle mask but may be beyond what a
344 |             neural network is capable of masking). Defaults to 1.0.
345 | 
346 |     Raises:
347 |         TransformException: if the expected keys are not in the
348 |             dictionary.
349 | 
350 |     Returns:
351 |         data: Modified version of the input dictionary.
352 |     """
353 | 
354 |     def __init__(self, mix_key='mix', source_key='sources',
355 |                  range_min=0.0, range_max=1.0):
356 |         self.mix_key = mix_key
357 |         self.source_key = source_key
358 |         self.range_min = range_min
359 |         self.range_max = range_max
360 | 
361 |     def __call__(self, data):
362 |         if self.mix_key not in data:
363 |             raise TransformException(
364 |                 f"Expected {self.mix_key} in dictionary "
365 |                 f"passed to this Transform! Got {list(data.keys())}."
366 |             )
367 | 
368 |         mixture = data[self.mix_key]
369 | 
370 |         mix_stft = mixture.stft()
371 |         mix_magnitude = np.abs(mix_stft)
372 |         mix_angle = np.angle(mix_stft)
373 |         data['mix_magnitude'] = mix_magnitude
374 | 
375 |         if self.source_key not in data:
376 |             return data
377 | 
378 |         _sources = data[self.source_key]
379 |         source_names = sorted(list(_sources.keys()))
380 | 
381 |         sources = OrderedDict()
382 |         for key in source_names:
383 |             sources[key] = _sources[key]
384 |         data[self.source_key] = sources
385 | 
386 |         source_angles = []
387 |         source_magnitudes = []
388 |         for key in source_names:
389 |             s = sources[key]
390 |             _stft = s.stft()
391 |             source_magnitudes.append(np.abs(_stft))
392 |             source_angles.append(np.angle(_stft))
393 | 
394 |         source_magnitudes = np.stack(source_magnitudes, axis=-1)
395 |         source_angles = np.stack(source_angles, axis=-1)
396 |         range_min = self.range_min
397 |         range_max = self.range_max * mix_magnitude[..., None]
398 | 
399 |         # Section 3.1: https://arxiv.org/pdf/1909.08494.pdf
400 |         source_magnitudes = np.minimum(
401 |             np.maximum(
402 |                 source_magnitudes * np.cos(source_angles - mix_angle[..., None]),
403 |                 range_min
404 |             ),
405 |             range_max
406 |         )
407 | 
408 |         data['ideal_binary_mask'] = compute_ideal_binary_mask(source_magnitudes)
409 |         data['source_magnitudes'] = source_magnitudes
410 | 
411 |         return data
412 | 
413 |     def __repr__(self):
414 |         return (
415 |             f"{self.__class__.__name__}("
416 |             f"mix_key = {self.mix_key}, "
417 |             f"source_key = {self.source_key}"
418 |             f")"
419 |         )
420 | 
421 | 
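# The core of the phase-sensitive target, reduced to scalars: the source
# magnitude is projected onto the mixture's phase (the cosine term), then
# truncated to [range_min, range_max * |mix|]. A source in opposite phase
# with the mix gets a target of 0.0 rather than -0.7.
import numpy as np

mix_mag, mix_angle = 1.0, 0.0
src_mag, src_angle = 0.7, np.pi   # opposite phase from the mix
range_min, range_max = 0.0, 1.0

target = np.minimum(
    np.maximum(src_mag * np.cos(src_angle - mix_angle), range_min),
    range_max * mix_mag
)
print(target)  # 0.0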
422 | class IndexSources(object):
423 |     """
424 |     Takes in a dictionary containing Torch tensors or numpy arrays and extracts the
425 |     indexed sources from the set key (usually either `source_magnitudes` or
426 |     `ideal_binary_mask`). Can be used to train single-source separation models
427 |     (e.g. mix goes in, vocals come out).
428 | 
429 |     You need to know which slice of the source magnitudes or ideal binary mask arrays
430 |     to extract. The order of the sources in the source magnitudes array will be in
431 |     alphabetical order according to their source labels.
432 | 
433 |     For example, if source magnitudes has shape `(257, 400, 1, 4)`, and the data is
434 |     from MUSDB, then the four possible source labels are bass, drums, other, and vocals.
435 |     The data in source magnitudes is in alphabetical order, so:
436 | 
437 |     .. code-block:: python
438 | 
439 |         # source_magnitudes is an array returned by either MagnitudeSpectrumApproximation
440 |         # or PhaseSensitiveSpectrumApproximation
441 |         source_magnitudes[..., 0]  # bass spectrogram
442 |         source_magnitudes[..., 1]  # drums spectrogram
443 |         source_magnitudes[..., 2]  # other spectrogram
444 |         source_magnitudes[..., 3]  # vocals spectrogram
445 | 
446 |         # ideal_binary_mask is an array returned by either MagnitudeSpectrumApproximation
447 |         # or PhaseSensitiveSpectrumApproximation
448 |         ideal_binary_mask[..., 0]  # bass ibm mask
449 |         ideal_binary_mask[..., 1]  # drums ibm mask
450 |         ideal_binary_mask[..., 2]  # other ibm mask
451 |         ideal_binary_mask[..., 3]  # vocals ibm mask
452 | 
453 |     You can apply this transform to either the `source_magnitudes` or the
454 |     `ideal_binary_mask` or both.
455 | 
456 |     Args:
457 |         target_key (str): Key in the data dictionary to index (e.g. 'source_magnitudes').
458 |         index (int): Which source slice along the last axis to keep.
459 |     """
460 | 
461 |     def __init__(self, target_key, index):
462 |         self.target_key = target_key
463 |         self.index = index
464 | 
465 |     def __call__(self, data):
466 |         if self.target_key not in data:
467 |             raise TransformException(
468 |                 f"Expected {self.target_key} in dictionary, got {list(data.keys())}")
469 |         if self.index >= data[self.target_key].shape[-1]:
470 |             raise TransformException(
471 |                 f"Shape of data[{self.target_key}] is {data[self.target_key].shape} "
472 |                 f"but index = {self.index} is out of bounds of the last dim.")
473 |         data[self.target_key] = data[self.target_key][..., self.index, None]
474 |         return data
475 | 
476 | 
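# IndexSources in isolation: keep only the 'vocals' slice (index 3 of an
# alphabetically ordered MUSDB-style array) while preserving the trailing
# singleton source axis.
import numpy as np

data = {'source_magnitudes': np.random.rand(257, 400, 1, 4)}
data = IndexSources('source_magnitudes', 3)(data)
print(data['source_magnitudes'].shape)  # (257, 400, 1, 1)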
477 | class GetExcerpt(object):
478 |     """
479 |     Takes in a dictionary containing Torch tensors or numpy arrays and extracts an
480 |     excerpt from each tensor corresponding to a spectral representation of a specified
481 |     length in frames. Can be used to get L-length spectrograms from mixture and source
482 |     spectrograms. If the data is shorter than the specified length, it
483 |     is padded to the specified length. If it is longer, a random offset between
484 |     ``(0, data_length - specified_length)`` is chosen. This transform assumes it is
485 |     being passed data AFTER ToSeparationModel, which swaps time onto axis=0 --
486 |     hence the default ``time_dim`` of 0.
487 | 
488 |     Args:
489 |         excerpt_length (int): Specified length of transformed data in frames.
490 | 
491 |         time_dim (int): Which dimension time is on (excerpts are taken along this axis).
492 |             Defaults to 0.
493 | 
494 |         tf_keys (list): Which keys in the data dictionary to take excerpts from.
495 |             Defaults to the module-level ``time_frequency_keys``.
496 |     """
497 | 
498 |     def __init__(self, excerpt_length, time_dim=0,
499 |                  tf_keys=None):
500 |         self.excerpt_length = excerpt_length
501 |         self.time_dim = time_dim
502 |         self.time_frequency_keys = tf_keys if tf_keys else time_frequency_keys
504 | 
505 |     @staticmethod
506 |     def _validate(data, key):
507 |         is_tensor = torch.is_tensor(data[key])
508 |         is_array = isinstance(data[key], np.ndarray)
509 |         if not is_tensor and not is_array:
510 |             raise TransformException(
511 |                 f"data[{key}] was not a torch Tensor or a numpy array!")
512 |         return is_tensor, is_array
513 | 
514 |     def _get_offset(self, data, key):
515 |         self._validate(data, key)
516 |         data_length = data[key].shape[self.time_dim]
517 | 
518 |         if data_length >= self.excerpt_length:
519 |             offset = random.randint(0, data_length - self.excerpt_length)
520 |         else:
521 |             offset = 0
522 | 
523 |         pad_amount = max(0, self.excerpt_length - data_length)
527 |         return offset, pad_amount
528 | 
529 |     def _construct_pad_func_tuple(self, shape, pad_amount, is_tensor):
530 |         if is_tensor:
531 |             pad_func = torch.nn.functional.pad
532 |             pad_tuple = [0 for _ in range(2 * len(shape))]
533 |             pad_tuple[2 * self.time_dim] = pad_amount
534 |             pad_tuple = pad_tuple[::-1]
535 |         else:
536 |             pad_func = np.pad
537 |             pad_tuple = [(0, 0) for _ in range(len(shape))]
538 |             pad_tuple[self.time_dim] = (0, pad_amount)
539 |         return pad_func, pad_tuple
540 | 
541 |     def __call__(self, data):
542 |         offset, pad_amount = self._get_offset(
543 |             data, self.time_frequency_keys[0])
544 | 
545 |         for key in data:
546 |             if key in self.time_frequency_keys:
547 |                 is_tensor, is_array = self._validate(data, key)
548 | 
549 |                 if pad_amount > 0:
550 |                     pad_func, pad_tuple = self._construct_pad_func_tuple(
551 |                         data[key].shape, pad_amount, is_tensor)
552 |                     data[key] = pad_func(data[key], pad_tuple)
553 | 
554 |                 data[key] = utils._slice_along_dim(
555 |                     data[key], self.time_dim, offset, offset + self.excerpt_length)
556 | 
560 |         return data
561 | 
562 | 
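# GetExcerpt in isolation on numpy arrays: a short clip is zero-padded up to
# excerpt_length, a long one gets a random crop; every listed key shares the
# same offset, so mixture and source representations stay aligned.
import numpy as np

tfm = GetExcerpt(excerpt_length=400, time_dim=0,
                 tf_keys=['mix_magnitude', 'source_magnitudes'])

short_item = {
    'mix_magnitude': np.random.rand(300, 257),
    'source_magnitudes': np.random.rand(300, 257, 2),
}
print(tfm(short_item)['mix_magnitude'].shape)      # (400, 257) -- padded

long_item = {
    'mix_magnitude': np.random.rand(1000, 257),
    'source_magnitudes': np.random.rand(1000, 257, 2),
}
print(tfm(long_item)['source_magnitudes'].shape)   # (400, 257, 2) -- cropped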
563 | class Cache(object):
564 |     """
565 |     The Cache transform can be placed within a Compose transform. The data
566 |     dictionary coming into this transform will be saved to the specified
567 |     location using ``zarr``. Then, instead of computing all of the transforms
568 |     before the cache, one can simply read from the cache. The transforms after
569 |     this will then be applied to the data dictionary that is read from the
570 |     cache. A typical pipeline might look like this:
571 | 
572 |     .. code-block:: python
573 | 
574 |         dataset = datasets.Scaper('path/to/scaper/folder')
575 |         tfm = transforms.Compose([
576 |             transforms.PhaseSensitiveSpectrumApproximation(),
577 |             transforms.ToSeparationModel(),
578 |             transforms.Cache('~/.nussl/cache/tag', overwrite=True),
579 |             transforms.GetExcerpt(400)
580 |         ])
581 |         dataset[0]  # first time will write to cache then apply GetExcerpt
582 |         dataset.cache_populated = True  # switches to reading from cache
583 |         dataset[0]  # second time will read from cache then apply GetExcerpt
584 |         dataset[1]  # will error out as it wasn't written to the cache!
585 | 
586 |         dataset.cache_populated = False
587 |         for i in range(len(dataset)):
588 |             dataset[i]  # every item will get written to cache
589 |         dataset.cache_populated = True
590 |         dataset[1]  # now it exists
591 | 
592 |         dataset = datasets.Scaper('path/to/scaper/folder')  # next time around
593 |         tfm = transforms.Compose([
594 |             transforms.PhaseSensitiveSpectrumApproximation(),
595 |             transforms.ToSeparationModel(),
596 |             transforms.Cache('~/.nussl/cache/tag', overwrite=False),
597 |             transforms.GetExcerpt(400)
598 |         ])
599 |         dataset.cache_populated = True
600 |         dataset[0]  # will read from cache, which still exists from last time
601 | 
602 |     Args:
603 |         location (str): Path of the zarr cache on disk.
604 |         cache_size (int): Number of items the cache can hold (defaults to 1). overwrite (bool): Whether to rebuild an existing cache at ``location`` (defaults to False).
605 |     """
606 |     def __init__(self, location, cache_size=1, overwrite=False):
607 |         self.location = location
608 |         self.cache_size = cache_size
609 |         self.cache = None
610 |         self.overwrite = overwrite
611 | 
612 |     @property
613 |     def info(self):
614 |         return self.cache.info
615 | 
616 |     @property
617 |     def overwrite(self):
618 |         return self._overwrite
619 | 
620 |     @overwrite.setter
621 |     def overwrite(self, value):
622 |         self._overwrite = value
623 |         self._clear_cache(self.location)
624 |         self._open_cache(self.location)
625 | 
626 |     def _clear_cache(self, location):
627 |         if os.path.exists(location):
628 |             if self.overwrite:
629 |                 logging.info(
630 |                     f"Cache {location} exists and overwrite = True, clearing cache.")
631 |                 shutil.rmtree(location, ignore_errors=True)
632 | 
633 |     def _open_cache(self, location):
634 |         if self.overwrite:
635 |             self.cache = zarr.open(location, mode='w', shape=(self.cache_size,),
636 |                                    chunks=(1,), dtype=object,
637 |                                    object_codec=numcodecs.Pickle(),
638 |                                    synchronizer=zarr.ThreadSynchronizer())
639 |         else:
640 |             if os.path.exists(location):
641 |                 self.cache = zarr.open(location, mode='r',
642 |                                        object_codec=numcodecs.Pickle(),
643 |                                        synchronizer=zarr.ThreadSynchronizer())
644 | 
645 |     def __call__(self, data):
646 |         if 'index' not in data:
647 |             raise TransformException(
648 |                 f"Expected 'index' in dictionary, got {list(data.keys())}")
649 |         index = data['index']
650 |         if self.overwrite:
651 |             self.cache[index] = data
652 |         data = self.cache[index]
653 | 
654 |         if not isinstance(data, dict):
655 |             raise TransformException(
656 |                 f"Reading from cache resulted in not a dictionary! "
657 |                 f"Maybe you haven't written to index {index} yet in "
658 |                 f"the cache?")
659 | 
660 |         return data
661 | 
662 | 
663 | class GetAudio(object):
664 |     """
665 |     Extracts the audio from each signal in `mix_key` and `source_key`.
666 |     These will be at new keys: `mix_audio_[key]` for each mix key, and `source_audio`.
667 |     Can be used for training end-to-end models.
668 | 
669 |     Args:
670 |         mix_key (list, optional): The keys to look for in data for the mixture
671 |             AudioSignals. Defaults to ['mix'].
672 |         source_key (str, optional): The key to look for in the data containing the dict of
673 |             source AudioSignals. Defaults to 'sources'.
674 |     """
675 | 
676 |     def __init__(self, mix_key=['mix'], source_key='sources'):
677 |         self.mix_key = mix_key
678 |         self.source_key = source_key
679 | 
680 |         if not isinstance(mix_key, list):
681 |             raise TransformException(
682 |                 "Expected a list of mix keys"
683 |             )
684 | 
685 |     def __call__(self, data):
686 | 
687 |         for key in self.mix_key:
688 |             if key not in data:
689 |                 raise TransformException(
690 |                     f"Expected {key} in dictionary "
691 |                     f"passed to this Transform! Got {list(data.keys())}."
692 |                 )
693 | 
694 |         # extract the raw audio for each mix key
695 |         for key in self.mix_key:
696 |             new_key = 'mix_audio_' + key
697 |             data[new_key] = data[key].audio_data
698 | 
699 |         if self.source_key not in data:
700 |             return data
701 | 
702 |         _sources = data[self.source_key]
703 |         source_names = sorted(list(_sources.keys()))
704 | 
705 |         source_audio = []
706 |         for key in source_names:
707 |             source_audio.append(_sources[key].audio_data)
708 |         # sources on last axis
709 |         source_audio = np.stack(source_audio, axis=-1)
710 | 
711 |         data['source_audio'] = source_audio
712 |         return data
713 | 
714 | 
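# Sketch of GetAudio for an end-to-end model: it pulls the raw samples out of
# each AudioSignal. The .wav paths are placeholders. Note the mix audio lands
# at 'mix_audio_mix' ('mix_audio_' + key) with the default mix_key of ['mix'].
import nussl

s0 = nussl.AudioSignal('path/to/source0.wav')
s1 = nussl.AudioSignal('path/to/source1.wav')
data = {'mix': s0 + s1, 'sources': {'s0': s0, 's1': s1}}

data = GetAudio()(data)
print(data['mix_audio_mix'].shape)   # (channels, samples)
print(data['source_audio'].shape)    # (channels, samples, 2)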
750 | """ 751 | 752 | def __init__(self, swap_tf_dims=None): 753 | self.swap_tf_dims = swap_tf_dims if swap_tf_dims else time_frequency_keys 754 | 755 | def __call__(self, data): 756 | keys = list(data.keys()) 757 | for key in keys: 758 | if key != 'index': 759 | is_array = isinstance(data[key], np.ndarray) 760 | if is_array: 761 | data[key] = torch.from_numpy(data[key]) 762 | if not torch.is_tensor(data[key]): 763 | data.pop(key) 764 | if key in self.swap_tf_dims: 765 | data[key] = data[key].transpose(1, 0) 766 | return data 767 | 768 | def __repr__(self): 769 | return f"{self.__class__.__name__}()" 770 | 771 | 772 | class Compose(object): 773 | """Composes several transforms together. Inspired by torchvision implementation. 774 | 775 | Args: 776 | transforms (list of ``Transform`` objects): list of transforms to compose. 777 | 778 | Example: 779 | >>> transforms.Compose([ 780 | >>> transforms.MagnitudeSpectrumApproximation(), 781 | >>> transforms.ToSeparationModel(), 782 | >>> ]) 783 | """ 784 | 785 | def __init__(self, transforms): 786 | self.transforms = transforms 787 | 788 | def __call__(self, data): 789 | for t in self.transforms: 790 | data = t(data) 791 | if not isinstance(data, dict): 792 | raise TransformException( 793 | "The output of every transform must be a dictionary!") 794 | return data 795 | 796 | def __repr__(self): 797 | format_string = self.__class__.__name__ + '(' 798 | for t in self.transforms: 799 | format_string += '\n' 800 | format_string += ' {0}'.format(t) 801 | format_string += '\n)' 802 | return format_string 803 | 804 | 805 | class TransformException(Exception): 806 | """ 807 | Exception class for errors when working with transforms in nussl. 808 | """ 809 | pass 810 | 811 | --------------------------------------------------------------------------------