├── experiment
├── __init__.py
├── allocate.py
├── experiment1.py
└── experiment2.py
├── tests
├── test_models.py
├── test_datasets.py
├── test_utils.py
├── test_audio_env.py
├── test_room_types.py
├── test_agent.py
└── test_transforms.py
├── src
├── __init__.py
├── audio_room
│ ├── envs
│ │ ├── __init__.py
│ │ └── audio_env.py
│ └── __init__.py
├── constants.py
├── audio_processing.py
├── room_types.py
├── plot_runs.py
├── utils.py
├── models.py
├── agent.py
├── datasets.py
└── transforms.py
├── otoworld.png
├── notebooks
├── df_sdr
└── df_sir
├── sounds
├── car
│ └── car_horn.wav
├── siren
│ └── siren.wav
├── samples
│ ├── male
│ │ ├── 051a050a.wav
│ │ ├── 051a050b.wav
│ │ └── 051a050c.wav
│ └── female
│ │ ├── 050a050a.wav
│ │ ├── 050a050b.wav
│ │ └── 050a050c.wav
└── phone
│ └── cellphone_ringing.wav
├── requirements.txt
├── LICENSE
├── .gitignore
└── README.md
/experiment/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.0"
--------------------------------------------------------------------------------
/src/audio_room/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .audio_env import AudioEnv
2 |
--------------------------------------------------------------------------------
/otoworld.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/otoworld.png
--------------------------------------------------------------------------------
/notebooks/df_sdr:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/notebooks/df_sdr
--------------------------------------------------------------------------------
/notebooks/df_sir:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/notebooks/df_sir
--------------------------------------------------------------------------------
/sounds/car/car_horn.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/car/car_horn.wav
--------------------------------------------------------------------------------
/sounds/siren/siren.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/siren/siren.wav
--------------------------------------------------------------------------------
/sounds/samples/male/051a050a.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/male/051a050a.wav
--------------------------------------------------------------------------------
/sounds/samples/male/051a050b.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/male/051a050b.wav
--------------------------------------------------------------------------------
/sounds/samples/male/051a050c.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/male/051a050c.wav
--------------------------------------------------------------------------------
/sounds/phone/cellphone_ringing.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/phone/cellphone_ringing.wav
--------------------------------------------------------------------------------
/sounds/samples/female/050a050a.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/female/050a050a.wav
--------------------------------------------------------------------------------
/sounds/samples/female/050a050b.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/female/050a050b.wav
--------------------------------------------------------------------------------
/sounds/samples/female/050a050c.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pseeth/otoworld/HEAD/sounds/samples/female/050a050c.wav
--------------------------------------------------------------------------------
/src/audio_room/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 | register(
4 | id="audio-room-v0", entry_point="audio_room.envs:AudioEnv",
5 | )
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyroomacoustics==0.4.1
2 | gym
3 | ffprobe
4 | jupyter
5 | git+git://github.com/nussl/nussl
6 | soxbindings
7 | numba==0.48.0
8 | jupyter
9 | seaborn^XX^
10 |
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | # NOTE: to run, need to cd into tests/, then run pytest
2 | import sys
3 | sys.path.append("../src")
4 | sys.path.append('../experiment')
5 |
6 | import room_types
7 | import agent
8 | import constants
9 | import datasets
10 | import experiment1
11 |
12 | def test_buffer_data():
13 | # TODO
14 | pass
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | # NOTE: to run, need to cd into tests/, then run pytest
2 | import sys
3 | sys.path.append("../src")
4 |
5 | import numpy as np
6 |
7 | import utils
8 | import constants
9 |
10 |
11 | def test_choose_random_files():
12 | # test a range of different numbers of files
13 | for num_sources in range(5):
14 | num_sources = np.random.randint(1, 5)
15 | paths = utils.choose_random_files(num_sources=num_sources)
16 |
17 | assert(len(paths) == num_sources)
18 |
19 | # ensure we collect from correct folder
20 | random_file = np.random.choice(paths)
21 | assert(
22 | random_file.startswith(constants.DIR_FEMALE) \
23 | or random_file.startswith(constants.DIR_MALE)
24 | )
25 |
26 | # .wav or .mp3 or.. (need to be audio files)
27 | assert(random_file.endswith(constants.AUDIO_EXTENSION))
28 |
--------------------------------------------------------------------------------
/src/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # source data
4 | DIR_MALE = "../sounds/dry_recordings/dev/051/"
5 | DIR_FEMALE = "../sounds/dry_recordings/dev/050/"
6 | DIR_CAR = '../sounds/car/'
7 | DIR_PHONE = '../sounds/siren/'
8 | AUDIO_EXTENSION = ".wav"
9 |
10 | # saved data during experiment
11 | DATA_PATH = "../data"
12 | DIR_PREV_STATES = os.path.join(DATA_PATH, 'prev_states/')
13 | DIR_NEW_STATES = os.path.join(DATA_PATH, 'new_states/')
14 | DIR_DATASET_ITEMS = os.path.join(DATA_PATH, 'dataset_items/')
15 | MODEL_SAVE_PATH = '../models/'
16 | DIST_URL = "init_dist_to_target.p"
17 | STEPS_URL = "steps_to_completion.p"
18 | REWARD_URL = "rewards_per_episode.p"
19 | PRETRAIN_PATH = '../models/pretrained.pth'
20 |
21 | # audio stuff
22 | RESAMPLE_RATE = 8000
23 |
24 | # env stuff
25 | DIST_BTWN_EARS = 0.15
26 |
27 | # max and min values of exploration rate
28 | MAX_EPSILON = 0.9
29 | MIN_EPSILON = 0.1
30 |
31 | # reward structure (keep as floats)
32 | STEP_PENALTY = -0.5
33 | TURN_OFF_REWARD = 100.0
34 | ORIENT_PENALTY = -0.1
35 |
36 | # dataset
37 | MAX_BUFFER_ITEMS = 10000
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 pseeth
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tests/test_audio_env.py:
--------------------------------------------------------------------------------
1 | # NOTE: to run, need to cd into tests/, then run pytest
2 | import sys
3 | sys.path.append("../src")
4 |
5 | import gym
6 | import numpy as np
7 |
8 | import audio_room
9 | import utils
10 | import room_types
11 |
12 | def test_audio_env():
13 | # Shoebox Room
14 | room = room_types.ShoeBox(x_length=10, y_length=10)
15 |
16 | agent_loc = np.array([3, 8])
17 |
18 | # Set up the audio/gym environment
19 | env = gym.make(
20 | "audio-room-v0",
21 | room_config=room.generate(),
22 | agent_loc=agent_loc,
23 | corners=room.corners,
24 | max_order=10,
25 | step_size=1.0,
26 | acceptable_radius=0.5,
27 | )
28 |
29 | # store initial room obj
30 | init_room = env.room
31 |
32 | # test step (taking actions)
33 | # remember: 0,0 is at the bottom left
34 | env.step(action=0) # step left
35 | assert(np.allclose(env.agent_loc, np.array([2, 8])))
36 | env.step(action=1) # step right
37 | assert(np.allclose(env.agent_loc, np.array([3, 8])))
38 | env.step(action=2) # step up
39 | assert(np.allclose(env.agent_loc, np.array([3, 9])))
40 | env.step(action=3) # step down
41 | assert(np.allclose(env.agent_loc, np.array([3, 8])))
42 |
43 | # test move function
44 | env._move_agent([5, 5])
45 | assert(env.agent_loc == [5, 5])
46 |
47 | # ensure the room is the same dimensions
48 | # even though its a different q object
49 | new_room = env.room
50 | for idx, wall in enumerate(init_room.walls):
51 | assert(np.allclose(wall.corners, new_room.walls[idx].corners))
52 |
53 | # test reset
54 |
55 |
56 |
--------------------------------------------------------------------------------
/src/audio_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def ipd_ild_features(stft_data, sampling_rate=8000):
6 | """
7 | Computes interphase difference (IPD) and interlevel difference (ILD) for a
8 | stereo spectrogram.
9 | Args:
10 | stft_data (Torch tensor): Tensor of shape batch_size, time_frames, mag+phase, channels, sources
11 | sampling_rate (int): The rate at which data is sampled. Default = 8000 Hz
12 | Returns:
13 | ipd (``Torch tensor`): Interphase difference between two channels
14 | ild (``Torch tensor``): Interlevel difference between two channels
15 | """
16 |
17 | # The mag and phase are concatenated together, so each of them will have half the dimensionality
18 | mag_phase_dim = stft_data.shape[2]//2
19 |
20 | # Separate the data by channels
21 | stft_ch_one = stft_data[:, :, :, 0]
22 | stft_ch_two = stft_data[:, :, :, 1]
23 |
24 | # Calculate ILD over the magnitudes
25 |
26 | # Extract the magnitudes from the stft data
27 | stft_ch_one_mag = stft_ch_one[:, :, 0:mag_phase_dim]
28 | stft_ch_two_mag = stft_ch_two[:, :, 0:mag_phase_dim]
29 | vol = torch.abs(stft_ch_one_mag) + torch.abs(stft_ch_two_mag)
30 | ild = torch.abs(stft_ch_one_mag) / (torch.abs(stft_ch_two_mag) + 1e-4)
31 | ild = 20 * torch.log10(ild + 1e-8)
32 |
33 | # Extract the phase from the stft data
34 | phase_ch_two = stft_ch_one[..., -mag_phase_dim:]
35 | phase_ch_one = stft_ch_two[..., -mag_phase_dim:]
36 |
37 | ipd = torch.fmod(phase_ch_two - phase_ch_one, np.pi)
38 |
39 | # Output shape of ILD and IPD = [batch_size, time_frames, mag_phase_dim, sources]
40 | return ipd, ild, vol
41 |
42 |
--------------------------------------------------------------------------------
/tests/test_room_types.py:
--------------------------------------------------------------------------------
1 | # NOTE: to run, need to cd into tests/, then run pytest
2 | import sys
3 | sys.path.append("../src")
4 |
5 | from pyroomacoustics import ShoeBox, Room
6 |
7 | import room_types
8 | import constants
9 |
10 |
11 | def test_polygon_num_sides():
12 | """
13 | Test Polygon room class with different number of sides
14 | """
15 | for num_sides in range(3, 11):
16 | room = room_types.Polygon(n=num_sides, r=2)
17 | points = room.generate()
18 |
19 | # only x and y-coordinates for 2d polygon
20 | assert(len(points) == 2)
21 |
22 | # create pra room
23 | pra_room = Room.from_corners(points, fs=constants.RESAMPLE_RATE)
24 |
25 | # assert center is inside of room
26 | assert(pra_room.is_inside([room.x_center, room.y_center]))
27 |
28 | # apparently polygons don't need to be convex
29 | # pra_room.convex_hull()
30 | # assert(len(pra_room.obstructing_walls) == 0)
31 |
32 |
33 | def test_shoebox():
34 | """
35 | Testing (our) ShoeBox room class
36 | """
37 | # 2d shoebox room is a rectangle (4 walls)
38 | room = room_types.ShoeBox()
39 | points = room.generate()
40 |
41 | # sanity, rectangles only have two different lengths
42 | assert(len(points) == 2)
43 |
44 | # ensure class variables equal points returned by generate
45 | assert(room.x_length == points[0])
46 | assert(room.y_length == points[1])
47 |
48 | # not sure how this would fail
49 | assert((room.x_length * room.y_length) == (points[0] * points[1]))
50 |
51 | # create pra room
52 | pra_room = ShoeBox(points)
53 |
54 | # test whether it is a convex hull (this should be ensured by pra)
55 | pra_room.convex_hull()
56 | assert(len(pra_room.obstructing_walls) == 0)
57 |
--------------------------------------------------------------------------------
/tests/test_agent.py:
--------------------------------------------------------------------------------
1 | # NOTE: to run, need to cd into tests/, then run pytest
2 | import sys
3 | sys.path.append("../src")
4 |
5 | import gym
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import torch
9 |
10 | import room_types
11 | import agent
12 | import audio_room
13 | import utils
14 | import constants
15 | import nussl
16 | from datasets import BufferData
17 | from agent import RandomAgent
18 |
19 |
20 | def test_experiment_shoebox():
21 | """
22 | Testing a run with ShoeBox room
23 |
24 | TODO
25 | """
26 | # Shoebox Room
27 | room = room_types.ShoeBox(x_length=10, y_length=10)
28 |
29 | agent_loc = np.array([3, 8])
30 |
31 | # Set up the gym environment
32 | env = gym.make(
33 | "audio-room-v0",
34 | room_config=room.generate(),
35 | agent_loc=agent_loc,
36 | corners=room.corners,
37 | max_order=10,
38 | step_size=1.0,
39 | acceptable_radius=0.8,
40 | )
41 |
42 | # create buffer data folders
43 | utils.create_buffer_data_folders()
44 |
45 | tfm = nussl.datasets.transforms.Compose([
46 | nussl.datasets.transforms.GetAudio(mix_key='new_state'),
47 | nussl.datasets.transforms.ToSeparationModel(),
48 | nussl.datasets.transforms.GetExcerpt(excerpt_length=32000, tf_keys=['mix_audio'], time_dim=1),
49 | ])
50 |
51 | # create dataset object (subclass of nussl.datasets.BaseDataset)
52 | dataset = BufferData(folder=constants.DIR_DATASET_ITEMS, to_disk=True, transform=tfm)
53 |
54 | # Load the agent class
55 | a = agent.RandomAgent(env=env, dataset=dataset, episodes=2, max_steps=10, plot_reward_vs_steps=False)
56 | a.fit()
57 |
58 | # what should we assert?
59 | #assert()
60 |
61 |
62 | def test_experiment_polygon():
63 | """
64 | Testing a run with Polygon room
65 |
66 | TODO
67 | """
68 | # Uncomment for Polygon Room
69 | room = room_types.Polygon(n=6, r=2, x_center=5, y_center=5)
70 |
--------------------------------------------------------------------------------
/src/room_types.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | """
4 | Currently, we are only working with rooms that are convex and 2D.
5 | """
6 |
7 | class Polygon:
8 | def __init__(self, n, r, x_center=0, y_center=0, theta=0):
9 | """
10 | This class represents a polygon room.
11 | pyroomacoustics.Room.from_corners() is called to create a polygon room.
12 |
13 | Args:
14 | n (int): Number of vertices of the polygon
15 | r (int): Radius of the polygon
16 | x_center (int): x coordinate of center of polygon
17 | y_center (int): y coordinate of center of polygon
18 | theta (int): polygon construction parameter
19 | """
20 | self.n = n
21 | self.r = r
22 | self.x_center = x_center
23 | self.y_center = y_center
24 | assert(theta >= 0.0)
25 | self.theta = theta
26 | self.corners = True
27 |
28 | def generate(self):
29 | """
30 | This function generates a polygon and returns the points
31 |
32 | Returns:
33 | x_points (List[np.array]): x-coordinates of points that makeup the room
34 | y_points (List[np.array]): y-coordinates of points that makeup the room
35 | """
36 | numbers = np.array([i for i in range(self.n)])
37 |
38 | x_points = (
39 | self.r * (2 * np.cos((2 * np.pi * numbers) / self.n + self.theta)) + self.x_center
40 | )
41 | y_points = (
42 | self.r * (2 * np.sin((2 * np.pi * numbers) / self.n) + self.theta) + self.y_center
43 | )
44 |
45 | return [x_points, y_points]
46 |
47 |
48 | class ShoeBox:
49 | def __init__(self, x_length=10, y_length=10):
50 | """
51 | This class represents a shoe box (rectangular) room. It is a wrapper
52 | for the pyroomacoustics.Room.ShoeBox class. We are sticking with
53 | 2D rooms (4 walls) for now, though PRA supports 3D rooms (6 walls)
54 |
55 | Args:
56 | x_length (float): the horizontal length of the room
57 | y_length (float): the vertical length of the room
58 | """
59 | self.x_length = x_length
60 | self.y_length = y_length
61 | self.corners = False
62 |
63 | def generate(self):
64 | """
65 | This function generates a shoebox and returns the points
66 |
67 | Returns:
68 | List[int]: x_length and y_length
69 | """
70 | return [self.x_length, self.y_length]
71 |
--------------------------------------------------------------------------------
/experiment/allocate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """
4 | Script to wait for an open GPU, then to run the job. Use with
5 | `tsp` (task-spooler) for a common workflow, by queuing a job
6 | that looks for a GPU and then runs the actual job.
7 |
8 | Make this executable: chmod +x allocate.py
9 |
10 | This stacks jobs on a single GPU if memory is available and the
11 | other process belongs to you. Otherwise, it finds a completely
12 | unused GPU to run your command on.
13 |
14 | Usage:
15 |
16 | # allocates 1 gpu for train script
17 | ./allocate.py 1 python train.py ...
18 | # allocates 2 gpus for train script
19 | ./allocate.py 2 python train.py ...
20 |
21 | Requirements:
22 |
23 | pip install nvgpu
24 | """
25 |
26 | import subprocess
27 | import argparse
28 | import time
29 | import logging
30 | import sys
31 | import os, pwd
32 | import nvgpu
33 | from nvgpu.list_gpus import device_statuses
34 |
35 | logging.basicConfig(level=logging.INFO)
36 | mem_threshold = 50
37 |
38 | def run(cmd):
39 | print(cmd)
40 | subprocess.run([cmd], shell=True)
41 |
42 | def _allocate_gpu(num_gpus):
43 | current_user = pwd.getpwuid(os.getuid()).pw_name
44 | gpu_info = nvgpu.gpu_info()
45 | device_info = device_statuses()
46 |
47 | # assume nothing is available
48 | completely_available = [False for _ in gpu_info]
49 | same_user_available = [False for _ in gpu_info]
50 |
51 | for i, (_info, _device) in enumerate(zip(gpu_info, device_info)):
52 | completely_available[i] = _device['is_available']
53 | if _info['mem_used_percent'] < mem_threshold and current_user in _device['users']:
54 | same_user_available[i] = True
55 |
56 | available_gpus = same_user_available
57 | if sum(same_user_available) == 0:
58 | available_gpus = completely_available
59 |
60 | available_gpus = [i for i, val in enumerate(available_gpus) if val]
61 |
62 | return available_gpus[:num_gpus]
63 |
64 | if __name__ == "__main__":
65 | args = sys.argv
66 |
67 | num_gpus = int(sys.argv[1])
68 | cmd = sys.argv[2:]
69 |
70 | available_gpus = _allocate_gpu(num_gpus)
71 |
72 | while len(available_gpus) < num_gpus:
73 | logging.info("Waiting for available GPUs. Checking again in 30 seconds.")
74 | available_gpus = _allocate_gpu(num_gpus)
75 | time.sleep(30)
76 |
77 | available_gpus = ','.join(map(str, available_gpus))
78 | CUDA_VISIBLE_DEVICES = f'CUDA_VISIBLE_DEVICES={available_gpus}'
79 | cmd = ' '.join(cmd)
80 | cmd = f"{CUDA_VISIBLE_DEVICES} {cmd}"
81 | run(cmd)
82 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # vscode stuff
132 | .vscode/
133 |
134 | # dstore
135 | *.DS_Store
136 |
137 | # ide stuff
138 | .idea/
139 |
140 | # nussl tutorial
141 | nussl-tutorial.ipynb
142 | refactor-test.ipynb
143 |
144 | # data
145 | data/
146 | experiment/runs/
147 | models/
148 | sounds/dry_recordings
149 |
150 | *tfevents*
151 |
--------------------------------------------------------------------------------
/tests/test_transforms.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../src")
3 |
4 | from pyroomacoustics import ShoeBox, Room
5 |
6 |
7 | import gym
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | import torch
11 | import room_types
12 | import agent
13 | import audio_room
14 | import utils
15 | import constants
16 | import nussl
17 | from datasets import BufferData
18 | import time
19 | import audio_processing
20 | from models import RnnAgent
21 | import transforms
22 | from transforms import GetExcerpt
23 |
24 | """
25 | Takes too long to run and isn't a great test. Assertion will be added in GetExcerpt to ensure
26 | the mixes for the prev and new state are the same length
27 | """
28 |
29 | def test_mix_lengths():
30 | pass
31 | # # Shoebox Room
32 | # room = room_types.ShoeBox(x_length=10, y_length=10)
33 |
34 | # # Uncomment for Polygon Room
35 | # # room = room_types.Polygon(n=6, r=2, x_center=5, y_center=5)
36 |
37 | # agent_loc = np.array([3, 8])
38 |
39 | # # Set up the gym environment
40 | # env = gym.make(
41 | # "audio-room-v0",
42 | # room_config=room.generate(),
43 | # agent_loc=agent_loc,
44 | # corners=room.corners,
45 | # max_order=10,
46 | # step_size=1.0,
47 | # acceptable_radius=0.8,
48 | # )
49 |
50 | # # create buffer data folders
51 | # utils.create_buffer_data_folders()
52 |
53 | # tfm = transforms.Compose([
54 | # transforms.GetAudio(mix_key=['prev_state', 'new_state']),
55 | # transforms.ToSeparationModel(),
56 | # transforms.GetExcerpt(excerpt_length=32000,
57 | # tf_keys=['mix_audio_prev_state', 'mix_audio_new_state'], time_dim=1),
58 | # ])
59 |
60 | # dataset = BufferData(folder=constants.DIR_DATASET_ITEMS, to_disk=False, transform=tfm)
61 |
62 | # # run really short experiment to generate data
63 | # # Define the relevant dictionaries
64 | # env_config = {'env': env, 'dataset': dataset, 'episodes': 3, 'max_steps': 3, 'plot_reward_vs_steps': False,
65 | # 'stable_update_freq': 3, 'epsilon': 0.7, 'save_freq': 1}
66 | # dataset_config = {'batch_size': 3, 'num_updates': 2, 'save_path': '../models/'}
67 | # rnn_agent = RnnAgent(env_config=env_config, dataset_config=dataset_config)
68 |
69 | # rnn_agent.fit()
70 |
71 | # # test mix lengths, want to be equal (THIS MAY TAKE A WHILE)
72 | # # this test is not perfect, may get lucky and have all dataset items be same length, but it's already expensive to compute
73 | # for t in tfm.transforms:
74 | # lengths = []
75 | # print(len(dataset))
76 | # for i in range(len(dataset)):
77 | # if isinstance(t, GetExcerpt):
78 | # data = t(dataset[i])
79 | # for k, v in data.items():
80 | # if k in t.time_frequency_keys:
81 | # lengths.append(v.size())
82 | # print(lengths)
83 | # assert(lengths[0] == lengths[1])
84 |
85 | # test_mix_lengths()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
OtoWorld
2 |
3 | OtoWorld is an interactive environment in which agents must learn to listen in order to solve navigational tasks. The purpose of OtoWorld is to facilitate reinforcement learning research in computer audition, where agents must learn to listen to the world around them to navigate.
4 |
5 | **Note:** Currently the focus is on audio source separation.
6 |
7 | OtoWorld is built on three open source libraries: OpenAI [`gym`](https://gym.openai.com/) for environment and agent interaction, [`pyroomacoustics`](https://github.com/LCAV/pyroomacoustics) for ray-tracing and acoustics simulation, and [`nussl`](https://github.com/nussl/nussl) for training deep computer audition models. OtoWorld is the audio analogue of GridWorld, a simple navigation game. OtoWorld can be easily extended to more complex environments and games.
8 |
9 | To solve one episode of OtoWorld, an agent must move towards each sounding source in the auditory scene and "turn it off". The agent receives no other input than the current sound of the room. The sources are placed randomly within the room and can vary in number. The agent receives a reward for turning off a source.
10 |
11 | [Read the OtoWorld Paper here](https://arxiv.org/abs/2007.06123)
12 |
13 |
14 |
15 | 
16 |
17 |
18 | ## Installation
19 | Clone the repository
20 | ```
21 | git clone https://github.com/pseeth/otoworld.git
22 | ```
23 | Create a conda environment:
24 | ```
25 | conda create -n otoworld python==3.7
26 | ```
27 | Activate the environment:
28 | ```
29 | conda activate otoworld
30 | ```
31 | Install requirements:
32 | ```
33 | pip install -r requirements.txt
34 | ```
35 | Install ffmpeg from conda distribution (Note: Pypi distribution of ffmpeg is outdated):
36 | ```
37 | conda install ffmpeg
38 | ```
39 | If using a **CUDA-enabled GPU (highly recommended)**, install Pytorch `1.4` from official source:
40 | ```
41 | pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
42 | ```
43 | otherwise:
44 | ```
45 | pip install torch==1.4.0 torchvision==0.5.0
46 | ```
47 |
48 | You may ignore warnings if about certain dependencies for now.
49 |
50 |
51 | ## Additional Installation Notes - Linux
52 | * Linux users may need to install the sound file library if it is not present in the system. It can be done using the following command:
53 | ```
54 | sudo apt-get install libsndfile1
55 | ```
56 |
57 | This should take care of a common `musdb` error.
58 |
59 | ## Demo and Tutorial
60 | You can get familiar with OtoWorld using our tutorial notebook: [Tutorial Notebook](https://github.com/pseeth/otoworld/blob/master/notebooks/tutorial.ipynb).
61 |
62 | Run
63 | ```
64 | jupyter notebook
65 | ```
66 | and navigate to `notebooks/tutorial.ipynb`.
67 |
68 | ## Experiments
69 | You can view (and run) examples of experiments:
70 | ```
71 | cd experiments/
72 |
73 | python experiment1.py
74 | ```
75 |
76 | Please create your own experiments and see if you can win OtoWorld! You will need a GPU running CUDA to be able to perform any meaningful experiments.
77 |
78 | ## Is It Running Properly?
79 | You should a message indicating the experiment is running, such as this:
80 | ```
81 | ------------------------------
82 | - Starting to Fit Agent
83 | -------------------------------
84 | ```
85 |
86 | You may get a warning about `SoX`. Ignore this for now. You're good to go!
87 |
88 | ## Citing
89 | ```
90 | @inproceedings {otoworld
91 | author = {Omkar Ranadive and Grant Gasser and David Terpay and Prem Seetharaman},
92 | title = "OtoWorld: Towards Learning to Separate by Learning to Move",
93 | journal = "Self Supervision in Audio and Speech Workshop, 37th International Conference on Machine Learning ({ICML} 2020), Vienna, Austria",
94 | year = 2020
95 | }
96 | ```
97 |
--------------------------------------------------------------------------------
/experiment/experiment1.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../src/")
3 | from datetime import datetime
4 | import os
5 |
6 | import gym
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import torch
10 | import logging
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | import room_types
14 | import agent
15 | import audio_room
16 | import utils
17 | import constants
18 | import nussl
19 | from datasets import BufferData
20 | import time
21 | import audio_processing
22 | from models import RnnAgent
23 | import transforms
24 |
25 | import warnings
26 | warnings.filterwarnings("ignore")
27 |
28 | """
29 | One of our main experiments for OtoWorld introductory paper
30 | """
31 |
32 | # Shoebox Room
33 | nussl.utils.seed(0)
34 | room = room_types.ShoeBox(x_length=8, y_length=8)
35 |
36 | # Uncomment for Polygon Room
37 | #room = room_types.Polygon(n=6, r=2, x_center=5, y_center=5)
38 |
39 |
40 | source_folders_dict = {'../sounds/phone/': 1,
41 | '../sounds/siren/': 1}
42 |
43 | # Set up the gym environment
44 | env = gym.make(
45 | "audio-room-v0",
46 | room_config=room.generate(),
47 | source_folders_dict=source_folders_dict,
48 | corners=room.corners,
49 | max_order=10,
50 | step_size=.5,
51 | acceptable_radius=1.0,
52 | absorption=1.0,
53 | )
54 | env.seed(0)
55 |
56 | # create buffer data folders
57 | utils.create_buffer_data_folders()
58 |
59 | # fixing lengths
60 | tfm = transforms.Compose([
61 | transforms.GetAudio(mix_key=['prev_state', 'new_state']),
62 | transforms.ToSeparationModel(),
63 | transforms.GetExcerpt(excerpt_length=32000,
64 | tf_keys=['mix_audio_prev_state'], time_dim=1),
65 | transforms.GetExcerpt(excerpt_length=32000,
66 | tf_keys=['mix_audio_new_state'], time_dim=1)
67 | ])
68 |
69 | # create dataset object (subclass of nussl.datasets.BaseDataset)
70 | dataset = BufferData(
71 | folder=constants.DIR_DATASET_ITEMS,
72 | to_disk=True,
73 | transform=tfm
74 | )
75 |
76 | # define tensorboard writer, name the experiment!
77 | exp_name = 'pretrain-150eps'
78 | exp_id = '{}_{}'.format(exp_name, datetime.now().strftime('%d_%m_%Y-%H_%M_%S'))
79 | writer = SummaryWriter('runs/{}'.format(exp_id))
80 |
81 |
82 | # Define the relevant dictionaries
83 | env_config = {
84 | 'env': env,
85 | 'dataset': dataset,
86 | 'episodes': 150,
87 | 'max_steps': 1000,
88 | 'stable_update_freq': 150,
89 | 'save_freq': 1,
90 | 'play_audio': False,
91 | 'show_room': False,
92 | 'writer': writer,
93 | 'dense': True,
94 | 'decay_rate': 0.0002, # trial and error
95 | 'decay_per_ep': True
96 | }
97 |
98 | save_path = os.path.join(constants.MODEL_SAVE_PATH, exp_name)
99 | dataset_config = {
100 | 'batch_size': 10,
101 | 'num_updates': 2,
102 | 'save_path': save_path
103 | }
104 |
105 | # clear save_path folder for each experiment
106 | utils.clear_models_folder(save_path)
107 |
108 | rnn_config = {
109 | 'bidirectional': True,
110 | 'dropout': 0.3,
111 | 'filter_length': 256,
112 | 'hidden_size': 50,
113 | 'hop_length': 64,
114 | 'mask_activation': ['softmax'],
115 | 'mask_complex': False,
116 | 'mix_key': 'mix_audio',
117 | 'normalization_class': 'BatchNorm',
118 | 'num_audio_channels': 1,
119 | 'num_filters': 256,
120 | 'num_layers': 1,
121 | 'num_sources': 2,
122 | 'rnn_type': 'lstm',
123 | 'window_type': 'sqrt_hann',
124 | }
125 |
126 | stft_config = {
127 | 'hop_length': 64,
128 | 'num_filters': 256,
129 | 'direction': 'transform',
130 | 'window_type': 'sqrt_hann'
131 | }
132 |
133 | rnn_agent = RnnAgent(
134 | env_config=env_config,
135 | dataset_config=dataset_config,
136 | rnn_config=rnn_config,
137 | stft_config=stft_config,
138 | learning_rate=.001,
139 | pretrained=True
140 | )
141 | torch.autograd.set_detect_anomaly(True)
142 | rnn_agent.fit()
143 |
144 |
--------------------------------------------------------------------------------
/experiment/experiment2.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../src/")
3 | from datetime import datetime
4 | import os
5 |
6 | import gym
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import torch
10 | import logging
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | import room_types
14 | import agent
15 | import audio_room
16 | import utils
17 | import constants
18 | import nussl
19 | from datasets import BufferData
20 | import time
21 | import audio_processing
22 | from models import RnnAgent
23 | import transforms
24 |
25 | import warnings
26 | warnings.filterwarnings("ignore")
27 |
28 | """
29 | One of our main experiments for OtoWorld introductory paper
30 | """
31 |
32 | # Shoebox Room
33 | nussl.utils.seed(7)
34 | room = room_types.ShoeBox(x_length=8, y_length=8)
35 |
36 | # Uncomment for Polygon Room
37 | #room = room_types.Polygon(n=6, r=2, x_center=5, y_center=5)
38 |
39 | source_folders_dict = {'../sounds/phone/': 1,
40 | '../sounds/siren/': 1}
41 |
42 | # Set up the gym environment
43 | env = gym.make(
44 | "audio-room-v0",
45 | room_config=room.generate(),
46 | source_folders_dict=source_folders_dict,
47 | corners=room.corners,
48 | max_order=10,
49 | step_size=1.0,
50 | acceptable_radius=1.0,
51 | absorption=1.0,
52 | reset_sources=False,
53 | same_config=True
54 | )
55 | env.seed(7)
56 |
57 | # create buffer data folders
58 | utils.create_buffer_data_folders()
59 |
60 | # fixing lengths
61 | tfm = transforms.Compose([
62 | transforms.GetAudio(mix_key=['prev_state', 'new_state']),
63 | transforms.ToSeparationModel(),
64 | transforms.GetExcerpt(excerpt_length=32000,
65 | tf_keys=['mix_audio_prev_state'], time_dim=1),
66 | transforms.GetExcerpt(excerpt_length=32000,
67 | tf_keys=['mix_audio_new_state'], time_dim=1)
68 | ])
69 |
70 | # create dataset object (subclass of nussl.datasets.BaseDataset)
71 | dataset = BufferData(
72 | folder=constants.DIR_DATASET_ITEMS,
73 | to_disk=True,
74 | transform=tfm
75 | )
76 |
77 | # define tensorboard writer, name the experiment!
78 | exp_name = 'evaluate'
79 | exp_id = '{}_{}'.format(exp_name, datetime.now().strftime('%d_%m_%Y-%H_%M_%S'))
80 | writer = SummaryWriter('runs/{}'.format(exp_id))
81 |
82 | # Define the relevant dictionaries
83 | env_config = {
84 | 'env': env,
85 | 'dataset': dataset,
86 | 'episodes': 200,
87 | 'max_steps': 1000,
88 | 'stable_update_freq': 200,
89 | 'save_freq': 2,
90 | 'play_audio': False,
91 | 'show_room': False,
92 | 'writer': writer,
93 | 'dense': True,
94 | 'decay_rate': 0.01, # trial and error
95 | 'decay_per_ep': True,
96 | 'validation_freq': 5
97 | }
98 |
99 | if env_config['decay_per_ep']:
100 | end_epsilon = constants.MIN_EPSILON + (constants.MAX_EPSILON - constants.MIN_EPSILON) * np.exp(-env_config['decay_rate'] * env_config['episodes'])
101 | print('\nEpsilon value at last episode ({}): {}'.format(env_config['episodes'], end_epsilon))
102 |
103 | save_path = os.path.join(constants.MODEL_SAVE_PATH, exp_name)
104 | dataset_config = {
105 | 'batch_size': 50,
106 | 'num_updates': 2,
107 | 'save_path': save_path
108 | }
109 |
110 | # clear save_path folder for each experiment
111 | utils.clear_models_folder(save_path)
112 |
113 | rnn_config = {
114 | 'bidirectional': True,
115 | 'dropout': 0.3,
116 | 'filter_length': 256,
117 | 'hidden_size': 50,
118 | 'hop_length': 64,
119 | 'mask_activation': ['softmax'],
120 | 'mask_complex': False,
121 | 'mix_key': 'mix_audio',
122 | 'normalization_class': 'BatchNorm',
123 | 'num_audio_channels': 1,
124 | 'num_filters': 256,
125 | 'num_layers': 1,
126 | 'num_sources': 2,
127 | 'rnn_type': 'lstm',
128 | 'window_type': 'sqrt_hann',
129 | }
130 |
131 | stft_config = {
132 | 'hop_length': 64,
133 | 'num_filters': 256,
134 | 'direction': 'transform',
135 | 'window_type': 'sqrt_hann'
136 | }
137 |
138 | rnn_agent = RnnAgent(
139 | env_config=env_config,
140 | dataset_config=dataset_config,
141 | rnn_config=rnn_config,
142 | stft_config=stft_config,
143 | learning_rate=.01,
144 | )
145 | torch.autograd.set_detect_anomaly(True)
146 | rnn_agent.fit()
147 |
--------------------------------------------------------------------------------
/src/plot_runs.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import requests
3 | import pandas as pd
4 | import seaborn as sns
5 | import re
6 |
7 |
8 | port_number = 6006
9 | path = '../models/evaluation_data/'
10 | sns.set(style="whitegrid")
11 | dpi = 300
12 |
13 |
14 | def create_csv(file_name):
15 | # Form two urls - for mean reward and cumulative reward
16 |
17 | csv_url_cumul = "http://localhost:{}/data/plugin/scalars/scalars?tag=Reward%2Fcumulative&run={}&format=csv".format(
18 | port_number, file_name
19 | )
20 | csv_url_mean = "http://localhost:{}/data/plugin/scalars/scalars?tag=Reward%2Fmean_per_episode&run={}&format=csv".format(
21 | port_number, file_name
22 | )
23 |
24 | # Create CSV for mean rewards
25 | req = requests.get(csv_url_mean)
26 | url_content = req.content
27 | csv_file = open('{}.csv.'.format(path+file_name+'_mean'), 'wb')
28 | csv_file.write(url_content)
29 | csv_file.close()
30 |
31 | # Create csv for cumulative rewards
32 | req = requests.get(csv_url_cumul)
33 | url_content = req.content
34 | csv_file = open('{}.csv.'.format(path + file_name+'_cumul'), 'wb')
35 | csv_file.write(url_content)
36 | csv_file.close()
37 |
38 |
39 | def create_plots_single(file_name):
40 |
41 | mean_rewards = pd.read_csv(path+file_name+'_mean.csv')
42 | cumul_rewards = pd.read_csv(path+file_name+'_cumul.csv')
43 |
44 | sns_plot = sns.relplot(x="Step", y="Value", data=cumul_rewards, kind="line")
45 | sns_plot.set(xlabel='Step', ylabel='Cumulative Reward')
46 | sns_plot.savefig(path+file_name+'_cumul.png', dpi=dpi)
47 |
48 | sns_plot = sns.relplot(x="Step", y="Value", data=mean_rewards, kind="line")
49 | sns_plot.set(xlabel='Step', ylabel='Mean Reward')
50 | sns_plot.savefig(path+file_name+'_mean.png', dpi=dpi)
51 |
52 |
53 | def create_plots_multiple(file_names):
54 |
55 | combined_data_mean = pd.DataFrame(columns=['Wall time', 'Step', 'Value', 'Number'])
56 | combined_data_cumul = pd.DataFrame(columns=['Wall time', 'Step', 'Value', 'Number'])
57 | for i, file_name in enumerate(file_names):
58 | mean_rewards = pd.read_csv(path+file_name+'_mean.csv')
59 | cumul_rewards = pd.read_csv(path+file_name+'_cumul.csv')
60 | mean_rewards['Number'] = i
61 | cumul_rewards['Number'] = i
62 | combined_data_mean = pd.concat([combined_data_mean, mean_rewards])
63 | combined_data_cumul = pd.concat([combined_data_cumul, cumul_rewards])
64 |
65 | sns_plot = sns.relplot(x="Step", y="Value", data=combined_data_mean, kind="line", hue="Number")
66 | sns_plot.set(xlabel='Step', ylabel='Mean Reward')
67 |
68 |
69 | sns_plot.savefig(path + '_mean_combined.png', dpi=dpi)
70 |
71 | sns_plot = sns.relplot(x="Step", y="Value", data=combined_data_cumul, kind="line", hue="Number")
72 | sns_plot.set(xlabel='Step', ylabel='Cumulative Reward')
73 | sns_plot.savefig(path+'_cumul-combined.png', dpi=dpi)
74 |
75 |
76 | def generate_data_from_log(file_name):
77 | steps_second = []
78 | finished_steps = []
79 | with open(file_name, 'r') as f:
80 | lines = f.readlines()
81 | counter = 0
82 | while counter < len(lines):
83 | cur_line = lines[counter]
84 | if 'Episode: ' in cur_line:
85 | ep_num = int(re.findall(r'\d+', cur_line)[0])
86 | # Skip validation episodes
87 | if ep_num % 5 == 0:
88 | pass
89 | else:
90 | # Grab the data
91 | counter += 2
92 | cur_line = lines[counter]
93 | finished_count = int(re.findall(r'\d+', cur_line)[0])
94 | finished_steps.append(finished_count)
95 | counter += 2
96 | cur_line = lines[counter]
97 | sps = float(re.findall(r'\d+\.\d+', cur_line)[0])
98 | steps_second.append(sps)
99 | counter += 1
100 |
101 | print(len(finished_steps), len(steps_second))
102 | print(finished_steps)
103 | print(steps_second)
104 |
105 | return finished_steps, steps_second
106 |
107 |
108 | if __name__ == '__main__':
109 | file_names = ["test-exp-5-50eps_test_simp_env_validation-2_15_06_2020-02_33_50",
110 | "exp5-200eps-final-run_15_06_2020-06_30_00"]
111 |
112 | # for file_name in file_names:
113 | # create_csv(file_name=file_name)
114 | # create_plots_single(file_name)
115 | #
116 | # create_plots_multiple(file_names)
117 |
118 | generate_data_from_log(file_name=path+'run.txt')
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import shutil
6 | import random
7 | import torch
8 |
9 | import constants
10 |
11 | def autoclip(model, percentile, grad_norms=None):
12 | if grad_norms is None:
13 | grad_norms = []
14 |
15 | def _get_grad_norm(model):
16 | total_norm = 0
17 | for p in model.parameters():
18 | if p.requires_grad and p.grad is not None:
19 | param_norm = p.grad.data.norm(2)
20 | total_norm += param_norm.item() ** 2
21 | total_norm = total_norm ** (1. / 2)
22 | return total_norm
23 |
24 | grad_norms.append(_get_grad_norm(model))
25 | clip_value = np.percentile(grad_norms, percentile)
26 |
27 | torch.nn.utils.clip_grad_norm_(
28 | model.parameters(), clip_value)
29 | return grad_norms
30 |
31 |
32 | def choose_random_files(source_folders_dict):
33 | """
34 | Function returns random source files from provided folders.
35 |
36 | Args:
37 | source_folders_dict (Dict[str, int]): specify how many source files to choose from each folder
38 | e.g.
39 | {
40 | 'car_horn_source_folder': 1,
41 | 'phone_ringing_source_folder': 1
42 | }
43 |
44 | This would choose 1 source file from each folder
45 | Returns:
46 | paths (List[str]): the paths to two wav files
47 | """
48 | paths = []
49 |
50 | for folder, num_sources in source_folders_dict.items():
51 | files = os.listdir(folder)
52 |
53 | source_files = []
54 | random_indices = np.random.permutation(len(files))
55 | for i in random_indices:
56 | if files[i].endswith(constants.AUDIO_EXTENSION):
57 | source_files.append(os.path.join(folder, files[i]))
58 |
59 | if len(source_files) == num_sources:
60 | break
61 |
62 | paths.extend(source_files)
63 |
64 | return paths
65 |
66 |
67 | def log_dist_and_num_steps(init_dist_to_target, steps_to_completion):
68 | """
69 | This function logs the initial distance between agent and target source and number of steps
70 | taken to reach target source. The lists are stored in pickle files. The pairs (dist, steps) are in parallel
71 | lists, indexed by the episode number.
72 |
73 | Args:
74 | init_dist_to_target (List[float]): initial distance between agent and target src (size is number of episodes)
75 | steps_to_completion (List[int]): number of steps it took for agent to get to source
76 | """
77 | # create data folder
78 | if not os.path.exists(constants.DATA_PATH):
79 | os.makedirs(constants.DATA_PATH)
80 |
81 | # write objects
82 | pickle.dump(
83 | init_dist_to_target, open(os.path.join(constants.DATA_PATH, constants.DIST_URL), "wb"),
84 | )
85 | pickle.dump(
86 | steps_to_completion, open(os.path.join(constants.DATA_PATH, constants.STEPS_URL), "wb"),
87 | )
88 |
89 |
90 | def log_reward_vs_steps(rewards_per_episode):
91 | """
92 | This function logs the rewards per episode in order to plot the rewards vs. step for each episode.
93 | The lists are stored in pickle files. The pairs (dist, steps) are in parallel lists, indexed by the
94 | episode number.
95 |
96 | Args:
97 | rewards_per_episode (List[float]): rewards gained per episode
98 | """
99 | # create data folder
100 | if not os.path.exists(constants.DATA_PATH):
101 | os.makedirs(constants.DATA_PATH)
102 |
103 | # write objects
104 | pickle.dump(
105 | rewards_per_episode, open(os.path.join(
106 | constants.DATA_PATH, constants.REWARD_URL), "wb"),
107 | )
108 |
109 |
110 | def plot_reward_vs_steps():
111 | """
112 | Plots the reward vs step for an episode.
113 | """
114 | with open(os.path.join(constants.DATA_PATH, constants.REWARD_URL), "rb") as f:
115 | rewards = pickle.load(f)
116 |
117 | reward = rewards[0]
118 | plt.scatter(list(range(len(reward))), reward)
119 | plt.title("Reward vs. Number of Steps")
120 | plt.xlabel("Step")
121 | plt.ylabel("Reward")
122 |
123 | plt.show()
124 |
125 | def plot_dist_and_steps():
126 | """Plots initial distance and number of steps to reach target"""
127 | with open(os.path.join(constants.DATA_PATH, constants.DIST_URL), "rb") as f:
128 | dist = pickle.load(f)
129 | avg_dist = np.mean(dist)
130 |
131 | with open(os.path.join(constants.DATA_PATH, constants.STEPS_URL), "rb") as f:
132 | steps = pickle.load(f)
133 | avg_steps = np.mean(steps)
134 |
135 | plt.scatter(dist, np.log(steps))
136 | plt.title("Number of Steps and Initial Distance")
137 | plt.xlabel("Euclidean Distance")
138 | plt.ylabel("Log(# of Steps to Reach Target)")
139 | plt.text(
140 | 5,
141 | 600,
142 | "Avg Steps: " + str(int(avg_steps)),
143 | size=15,
144 | rotation=0.0,
145 | ha="right",
146 | va="top",
147 | bbox=dict(boxstyle="square", ec=(1.0, 0.5, 0.5), fc=(1.0, 0.8, 0.8),),
148 | )
149 | plt.text(
150 | 5,
151 | 500,
152 | "Avg Init Dist: " + str(int(avg_dist)),
153 | size=15,
154 | rotation=0.0,
155 | ha="right",
156 | va="top",
157 | bbox=dict(boxstyle="square", ec=(1.0, 0.5, 0.5), fc=(1.0, 0.8, 0.8),),
158 | )
159 |
160 | plt.show()
161 |
162 |
163 | def create_buffer_data_folders():
164 | """Empty and re-create the buffer data folders"""
165 | if os.path.exists(constants.DIR_PREV_STATES):
166 | shutil.rmtree(constants.DIR_PREV_STATES)
167 | os.makedirs(constants.DIR_PREV_STATES)
168 | if os.path.exists(constants.DIR_NEW_STATES):
169 | shutil.rmtree(constants.DIR_NEW_STATES)
170 | os.makedirs(constants.DIR_NEW_STATES)
171 | if os.path.exists(constants.DIR_DATASET_ITEMS):
172 | shutil.rmtree(constants.DIR_DATASET_ITEMS)
173 | os.makedirs(constants.DIR_DATASET_ITEMS)
174 |
175 |
176 | def clear_models_folder(save_path):
177 | """Empty/clear and re-create save_path folder to save models.
178 |
179 | Args:
180 | save_path (str): path to folder where models will be stored
181 | """
182 | if os.path.exists(save_path):
183 | shutil.rmtree(save_path)
184 | os.makedirs(save_path)
185 |
186 |
187 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gym
3 | import numpy as np
4 | import torch
5 | import logging
6 | import agent
7 | import utils
8 | import constants
9 | import nussl
10 | import audio_processing
11 | import agent
12 | from datasets import BufferData, RLDataset
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | import math
17 |
18 |
19 | class RnnAgent(agent.AgentBase):
20 | def __init__(self, env_config, dataset_config, rnn_config=None, stft_config=None,
21 | verbose=False, autoclip_percentile=10, learning_rate=.001, pretrained=False):
22 | """
23 | Args:
24 | env_config (dict): Dictionary containing the audio environment config
25 | dataset_config (dict): Dictionary consisting of dataset related parameters. List of parameters
26 | 'batch_size' : Denotes the batch size of the samples
27 | 'num_updates': Amount of iterations we run the training for in each pass.
28 | Ex - If num_updates = 5 and batch_size = 25, then we run the update process 5 times where in each run we
29 | sample 25 data points.
30 | 'sampler': The sampler to use. Ex - Weighted Sampler, Batch sampler etc
31 |
32 | rnn_config (dict): Dictionary containing the parameters for the model
33 | stft_config (dict): Dictionary containing the parameters for STFT
34 | """
35 |
36 | # Select device
37 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38 | print('DEVICE:', self.device)
39 | # Use default config if configs are not provided by user
40 | if rnn_config is None:
41 | self.rnn_config = nussl.ml.networks.builders.build_recurrent_end_to_end(
42 | bidirectional=True, dropout=0.3, filter_length=256, hidden_size=300,
43 | hop_length=64, mask_activation=['sigmoid'], mask_complex=False, mix_key='mix_audio',
44 | normalization_class='BatchNorm', num_audio_channels=1, num_filters=256,
45 | num_layers=2, num_sources=2, rnn_type='lstm', trainable=False, window_type='sqrt_hann'
46 | )
47 | else:
48 | self.rnn_config = nussl.ml.networks.builders.build_recurrent_end_to_end(**rnn_config)
49 |
50 | if stft_config is None:
51 | self.stft_diff = nussl.ml.networks.modules.STFT(
52 | hop_length=128, filter_length=512,
53 | direction='transform', num_filters=512
54 | )
55 | else:
56 | self.stft_diff = nussl.ml.networks.modules.STFT(**stft_config)
57 |
58 | # Initialize the Agent Base class
59 | super().__init__(**env_config)
60 |
61 | # Uncomment this to find backprop errors
62 | # torch.autograd.set_detect_anomaly(True)
63 |
64 | # Initialize the rnn model
65 | self.rnn_model = RnnSeparator(self.rnn_config).to(self.device)
66 | self.rnn_model_stable = RnnSeparator(self.rnn_config).to(self.device)
67 |
68 | # Load pretrained model
69 | if pretrained:
70 | model_dict = torch.load(constants.PRETRAIN_PATH)
71 | self.rnn_model.rnn_model.load_state_dict(model_dict)
72 |
73 | # Initialize dataset related parameters
74 | self.bs = dataset_config['batch_size']
75 | self.num_updates = dataset_config['num_updates']
76 | self.dynamic_dataset = RLDataset(buffer=self.dataset, sample_size=self.bs)
77 | self.dataloader = torch.utils.data.DataLoader(self.dynamic_dataset, batch_size=self.bs)
78 |
79 | # Initialize network layers for DQN network
80 | filter_length = (stft_config['num_filters'] // 2 + 1) * 2 if stft_config is not None else 514
81 | total_actions = self.env.action_space.n
82 | network_params = {'filter_length': filter_length, 'total_actions': total_actions, 'stft_diff': self.stft_diff}
83 | self.q_net = DQN(network_params).to(self.device)
84 | self.q_net_stable = DQN(network_params).to(self.device) # Fixed Q net
85 |
86 | # Tell optimizer which parameters to learn
87 | if pretrained:
88 | # Freezing the rnn model weights, only optimizing Q-net
89 | params = self.q_net.parameters()
90 | else:
91 | params = list(self.rnn_model.parameters()) + list(self.q_net.parameters())
92 | self.optimizer = optim.Adam(params, lr=learning_rate)
93 |
94 | # Folder path where the model will be saved
95 | self.SAVE_PATH = dataset_config['save_path']
96 | self.grad_norms = None
97 | self.percentile = autoclip_percentile
98 |
99 | def update(self):
100 | """
101 | Runs the main training pipeline. Sends mix to RNN separator, then to the DQN.
102 | Calculates the q-values and the expected q-values, comparing them to get the loss and then
103 | computes the gradient w.r.t to the entire differentiable pipeline.
104 | """
105 | # Run the update only if samples >= batch_size
106 | if len(self.dataset.items) < self.bs:
107 | return
108 |
109 | for index, data in enumerate(self.dataloader):
110 | if index > self.num_updates:
111 | break
112 |
113 | # Get the total number of time steps
114 | total_time_steps = data['mix_audio_prev_state'].shape[-1]
115 |
116 | # Reshape the mixture to pass through the separation model (Convert dual channels into one)
117 | # Also, rename the state to work on to mix_audio so that it can pass through remaining nussl architecture
118 | # Move to GPU
119 | data['mix_audio'] = data['mix_audio_prev_state'].float().view(
120 | -1, 1, total_time_steps).to(self.device)
121 | data['action'] = data['action'].to(self.device)
122 | data['reward'] = data['reward'].to(self.device)
123 | agent_info = data['agent_info'].to(self.device)
124 |
125 | # Get the separated sources by running through RNN separation model
126 | output = self.rnn_model(data)
127 | output['mix_audio'] = data['mix_audio']
128 |
129 | # Pass then through the DQN model to get q-values
130 | q_values = self.q_net(output, agent_info, total_time_steps)
131 | q_values = q_values.gather(1, data['action'])
132 |
133 | with torch.no_grad():
134 | # Now, get q-values for the next-state
135 | # Get the total number of time steps
136 | total_time_steps = data['mix_audio_new_state'].shape[-1]
137 |
138 | # Reshape the mixture to pass through the separation model (Convert dual channels into one)
139 | data['mix_audio'] = data['mix_audio_new_state'].float().view(
140 | -1, 1, total_time_steps).to(self.device)
141 | stable_output = self.rnn_model_stable(data)
142 | stable_output['mix_audio'] = data['mix_audio']
143 | q_values_next = self.q_net_stable(stable_output, agent_info, total_time_steps).max(1)[0].unsqueeze(-1)
144 |
145 | expected_q_values = data['reward'] + self.gamma * q_values_next
146 |
147 | # Calculate loss
148 | loss = F.l1_loss(q_values, expected_q_values)
149 | self.losses.append(loss)
150 | self.writer.add_scalar('Loss/train', loss, len(self.losses))
151 |
152 | # Optimize the model with backprop
153 | self.optimizer.zero_grad()
154 | loss.backward()
155 |
156 | # Applying AutoClip
157 | self.grad_norms = utils.autoclip(self.rnn_model, self.percentile, self.grad_norms)
158 |
159 | # Stepping optimizer
160 | self.optimizer.step()
161 |
162 | def choose_action(self):
163 | """
164 | Runs a forward pass though the RNN separator and then the Q-network. An action is choosen
165 | by taking the argmax of the output vector of the network, where the output is a
166 | probability distribution over the action space (via softmax).
167 |
168 | Returns:
169 | action (int): the argmax of the q-values vector
170 | """
171 | with torch.no_grad():
172 | # Get the latest state from the buffer
173 | data = self.dataset[self.dataset.last_ptr]
174 |
175 | # Perform the forward pass (RNN separator => DQN)
176 | total_time_steps = data['mix_audio_new_state'].shape[-1]
177 | data['mix_audio'] = data['mix_audio_new_state'].float().view(
178 | -1, 1, total_time_steps).to(self.device)
179 | output = self.rnn_model(data)
180 | output['mix_audio'] = data['mix_audio']
181 | agent_info = data['agent_info'].to(self.device)
182 |
183 | # action = argmax(q-values)
184 | q_values = self.q_net(output, agent_info, total_time_steps)
185 | action = q_values.max(1)[1].unsqueeze(-1)
186 | action = action[0].item()
187 |
188 | action = int(action)
189 |
190 | return action
191 |
192 | def update_stable_networks(self):
193 | self.rnn_model_stable.load_state_dict(self.rnn_model.state_dict())
194 | self.q_net_stable.load_state_dict(self.q_net.state_dict())
195 |
196 | def save_model(self, name):
197 | """
198 | Args:
199 | name (str): Name contains the episode information (To give saved models unique names)
200 | """
201 | # Save the parameters for rnn model and q net separately
202 | metadata = {
203 | 'sample_rate': 8000
204 | }
205 | self.rnn_model.rnn_model.save(os.path.join(self.SAVE_PATH, 'sp_' + name), metadata)
206 | torch.save(self.rnn_model.state_dict(), os.path.join(self.SAVE_PATH, 'rnn_' + name))
207 | torch.save(self.q_net.state_dict(), os.path.join(self.SAVE_PATH, 'qnet_' + name))
208 |
209 |
210 | class RnnSeparator(nn.Module):
211 | def __init__(self, rnn_config, verbose=False):
212 | super(RnnSeparator, self).__init__()
213 | self.rnn_model = nussl.ml.SeparationModel(rnn_config, verbose=verbose)
214 |
215 | def forward(self, x):
216 | return self.rnn_model(x)
217 |
218 |
219 | class DQN(nn.Module):
220 | def __init__(self, network_params):
221 | """
222 | The main DQN class, which takes the output of the RNN separator and input and
223 | returns q-values (prob dist of the action space)
224 |
225 | Args:
226 | network_params (dict): Dict of network parameters
227 |
228 | Returns:
229 | q_values (torch.Tensor): q-values
230 | """
231 | super(DQN, self).__init__()
232 |
233 | self.stft_diff = network_params['stft_diff']
234 | self.fc = nn.Linear(9, network_params['total_actions'])
235 | self.prelu = nn.PReLU()
236 |
237 | def forward(self, output, agent_info, total_time_steps):
238 | # Reshape the output again to get dual channels
239 | # Perform short time fourier transform of this output
240 | _, _, nt = output['mix_audio'].shape
241 | output['mix_audio'] = output['mix_audio'].reshape(-1, 2, nt)
242 | stft_data = self.stft_diff(output['mix_audio'], direction='transform')
243 |
244 | # Get the IPD and ILD features from the stft data
245 | _, nt, nf, _, ns = output['mask'].shape
246 | output['mask'] = output['mask'].view(-1, 2, nt, nf, ns)
247 | output['mask'] = output['mask'].max(dim=1)[0]
248 | ipd, ild, vol = audio_processing.ipd_ild_features(stft_data)
249 |
250 | ipd = ipd.unsqueeze(-1)
251 | ild = ild.unsqueeze(-1)
252 | vol = vol.unsqueeze(-1)
253 |
254 | ipd_means = (output['mask'] * ipd).mean(dim=[1, 2]).unsqueeze(1)
255 | ild_means = (output['mask'] * ild).mean(dim=[1, 2]).unsqueeze(1)
256 | vol_means = (output['mask'] * vol).mean(dim=[1, 2]).unsqueeze(1)
257 | X = torch.cat([ipd_means, ild_means, vol_means], dim=1)
258 | X = X.reshape(X.shape[0], -1)
259 | agent_info = agent_info.view(-1, 3)
260 | X = torch.cat((X, agent_info), dim=1)
261 | X = self.prelu(self.fc(X))
262 | q_values = F.softmax(X, dim=1)
263 |
264 | return q_values
265 |
266 | def flatten_features(self, x):
267 | # Flatten all dimensions except the batch
268 | size = x.size()[1:]
269 | num_features = 1
270 | for dimension in size:
271 | num_features *= dimension
272 |
273 | return num_features
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
--------------------------------------------------------------------------------
/src/agent.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import warnings
4 | from collections import deque
5 |
6 | import numpy as np
7 | import gym
8 | from scipy.spatial.distance import euclidean
9 | import nussl
10 |
11 | import utils
12 | import constants
13 | from datasets import BufferData
14 |
15 | # setup logging
16 | logging.basicConfig(level=logging.INFO)
17 | logging_str = (
18 | f"\n\n"
19 | f"------------------------------ \n"
20 | f"- Starting to Fit Agent\n"
21 | f"------------------------------- \n\n"
22 | )
23 | logging.info(logging_str)
24 |
25 |
26 | class AgentBase:
27 | def __init__(
28 | self,
29 | env,
30 | dataset,
31 | episodes=1,
32 | max_steps=100,
33 | gamma=0.98,
34 | alpha=0.001,
35 | decay_rate=0.0005,
36 | stable_update_freq=-1,
37 | save_freq=1,
38 | play_audio=False,
39 | show_room=False,
40 | writer=None,
41 | dense=True,
42 | decay_per_ep=False,
43 | validation_freq=None
44 | ):
45 | """
46 | This class is a base agent class which will be inherited when creating various agents.
47 |
48 | Args:
49 | self.env (gym object): The gym self.environment object which the agent is going to explore
50 | dataset (nussl dataset): Nussl dataset object for experience replay
51 | episodes (int): # of episodes to simulate
52 | max_steps (int): # of steps the agent can take before stopping an episode
53 | gamma (float): Discount factor
54 | alpha (float): Learning rate alpha
55 | decay_rate (float): decay rate for exploration rate (we want to decrease exploration as time proceeds)
56 | stable_update_freq (int): Update frequency value for stable networks (Target networks)
57 | play_audio (bool): choose to play audio at each iteration
58 | show_room (bool): choose to display the configurations and movements within a room
59 | writer (torch.utils.tensorboard.SummaryWriter): for logging to tensorboard
60 | dense (bool): makes the rewards more dense, less sparse
61 | gives reward for distance to closest source every step
62 | decay_per_ep (bool): If set to true, epsilon is decayed per episode, else decayed per step
63 | validation_freq (None or int): if not None, then do a validation run every validation_freq # of episodes
64 | validation run means epsilon=0 (no random actions) and no model update/training
65 | """
66 | self.env = env
67 | self.dataset = dataset
68 | self.episodes = episodes
69 | self.max_steps = max_steps
70 | self.gamma = gamma
71 | self.alpha = alpha
72 | self.epsilon = constants.MAX_EPSILON
73 | self.decay_rate = decay_rate
74 | self.stable_update_freq = stable_update_freq
75 | self.save_freq = save_freq
76 | self.play_audio = play_audio
77 | self.show_room = show_room
78 | self.writer = writer
79 | self.dense = dense
80 | self.decay_per_ep = decay_per_ep
81 | self.validation_freq = validation_freq
82 | self.losses = []
83 | self.cumulative_reward = 0
84 | self.total_experiment_steps = 0
85 | self.mean_episode_reward = []
86 | self.action_memory = deque(maxlen=4)
87 |
88 | # for saving model at best validation episode (as determined by mean reward in the episode)
89 | if self.validation_freq is not None:
90 | self.max_validation_reward = -np.inf
91 |
92 | def fit(self):
93 | for episode in range(1, self.episodes + 1):
94 | # Reset the self.environment and any other variables at beginning of each episode
95 | prev_state = None
96 |
97 | episode_rewards = []
98 | found_sources = []
99 |
100 | # validation episode?
101 | validation_episode = False
102 | if self.validation_freq is not None:
103 | validation_episode = True if (episode % self.validation_freq == 0) else False
104 |
105 | # Measure time to complete the episode
106 | start = time.time()
107 | for step in range(self.max_steps):
108 | self.total_experiment_steps += 1
109 |
110 | # Perform random actions with prob < epsilon
111 | model_action = False
112 | if (np.random.uniform(0, 1) < self.epsilon):
113 | action = self.env.action_space.sample()
114 | else:
115 | model_action = True
116 |
117 | # validation run: no random actions and no model update/training
118 | if validation_episode:
119 | model_action = True
120 |
121 | if model_action:
122 | # For the first two steps (We don't have prev_state, new_state pair), then perform a random action
123 | if step < 2:
124 | action = self.env.action_space.sample()
125 | else:
126 | # This is where agent will actually do something
127 | action = self.choose_action()
128 |
129 | # if same action 4x in a row (to avoid model infinite loop), choose an action randomly
130 | self.action_memory.append(action)
131 | if all(self.action_memory[0] == x for x in self.action_memory):
132 | action = self.env.action_space.sample()
133 | self.action_memory.append(action)
134 |
135 | # Perform the chosen action (NOTE: reward is a dictionary)
136 | new_state, agent_info, reward, won = self.env.step(
137 | action, play_audio=self.play_audio, show_room=self.show_room
138 | )
139 |
140 | # dense vs sparse
141 | total_step_reward = 0
142 | if self.dense:
143 | total_step_reward += sum(reward.values())
144 | else:
145 | total_step_reward += (reward['step_penalty'] + reward['turn_off_reward'])
146 |
147 | # record reward stats
148 | self.cumulative_reward += total_step_reward
149 | episode_rewards.append(total_step_reward)
150 |
151 | if reward['turn_off_reward'] == constants.TURN_OFF_REWARD:
152 | print('In FIT. Received reward: {} at step {}\n'.format(total_step_reward, step))
153 | logging.info(f"In FIT. Received reward {total_step_reward} at step: {step}\n")
154 |
155 | # Perform Update
156 | if not validation_episode:
157 | self.update()
158 |
159 | # store SARS in buffer
160 | if prev_state is not None and new_state is not None and not won:
161 | self.dataset.write_buffer_data(
162 | prev_state, action, total_step_reward, new_state, agent_info, episode, step
163 | )
164 |
165 | # Decay epsilon based on total steps (across all episodes, not within an episode)
166 | if not self.decay_per_ep:
167 | self.epsilon = constants.MIN_EPSILON + (
168 | constants.MAX_EPSILON - constants.MIN_EPSILON
169 | ) * np.exp(-self.decay_rate * self.total_experiment_steps)
170 | if self.total_experiment_steps % 200 == 0:
171 | print("Epsilon decayed to {} at step {} ".format(self.epsilon, self.total_experiment_steps))
172 |
173 | # Update stable networks based on number of steps
174 | if step % self.stable_update_freq == 0:
175 | self.update_stable_networks()
176 |
177 | # Terminate the episode if episode is won or at max steps
178 | if won or (step == self.max_steps - 1):
179 | # terminal state is silence
180 | silence_array = np.zeros_like(prev_state.audio_data)
181 | terminal_silent_state = prev_state.make_copy_with_audio_data(audio_data=silence_array)
182 | self.dataset.write_buffer_data(
183 | prev_state, action, total_step_reward, terminal_silent_state, agent_info, episode, step
184 | )
185 |
186 | # record mean reward for this episode
187 | self.mean_episode_reward = np.mean(episode_rewards)
188 |
189 | if validation_episode:
190 | # new best validation reward
191 | if self.mean_episode_reward > self.max_validation_reward:
192 | self.max_validation_reward = self.mean_episode_reward
193 |
194 | # save best validation model
195 | self.save_model('best_valid_reward.pt')
196 |
197 | if self.writer is not None:
198 | self.writer.add_scalar('Reward/validation_mean_per_episode', self.mean_episode_reward, episode)
199 | elif self.writer is not None:
200 | self.writer.add_scalar('Reward/mean_per_episode', self.mean_episode_reward, episode)
201 | self.writer.add_scalar('Reward/cumulative', self.cumulative_reward, self.total_experiment_steps)
202 |
203 | end = time.time()
204 | total_time = end - start
205 |
206 | # log episode summary
207 | logging_str = (
208 | f"\n\n"
209 | f"Episode Summary \n"
210 | f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n"
211 | f"- Episode: {episode}\n"
212 | f"- Won?: {won}\n"
213 | f"- Finished at step: {step+1}\n"
214 | f"- Time taken: {total_time:04f} \n"
215 | f"- Steps/Second: {float(step+1)/total_time:04f} \n"
216 | f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n"
217 | )
218 | print(logging_str)
219 | logging.info(logging_str)
220 |
221 | # break and go to new episode
222 | break
223 |
224 | prev_state = new_state
225 |
226 | if episode % self.save_freq == 0:
227 | name = 'ep{}.pt'.format(episode)
228 | self.save_model(name)
229 |
230 | # Decay epsilon per episode
231 | if self.decay_per_ep:
232 | self.epsilon = constants.MIN_EPSILON + (
233 | constants.MAX_EPSILON - constants.MIN_EPSILON
234 | ) * np.exp(-self.decay_rate * episode)
235 | print("Decayed epsilon value: {}".format(self.epsilon))
236 |
237 | # Reset the environment
238 | self.env.reset()
239 |
240 | def choose_action(self):
241 | """
242 | This function must be implemented by subclass.
243 | It will choose the action to take at any time step.
244 |
245 | Returns:
246 | action (int): the action to take
247 | """
248 | raise NotImplementedError()
249 |
250 | def update(self):
251 | """
252 | This function must be implemented by subclass.
253 | It will perform an update (e.g. updating Q table or Q network)
254 | """
255 | raise NotImplementedError()
256 |
257 | def update_stable_networks(self):
258 | """
259 | This function must be implemented by subclass.
260 | It will perform an update to the stable networks. I.E Copy values from current network to target network
261 | """
262 | raise NotImplementedError()
263 |
264 | def save_model(self, name):
265 | raise NotImplementedError()
266 |
267 |
268 | class RandomAgent(AgentBase):
269 | def choose_action(self):
270 | """
271 | Since this is a random agent, we just randomly sample our action
272 | every time.
273 | """
274 | return self.env.action_space.sample()
275 |
276 | def update(self):
277 | """
278 | No update for a random agent
279 | """
280 | pass
281 |
282 | def update_stable_networks(self):
283 | pass
284 |
285 | def save_model(self, name):
286 | pass
287 |
288 |
289 | # Create a perfect agent that steps to each of the closest sources one at a time.
290 | class OracleAgent(AgentBase):
291 | """
292 | This agent is a perfect agent. It knows where all of the sources are and
293 | will iteratively go to the closest source at each time step.
294 | """
295 |
296 | def choose_action(self):
297 | """
298 | Since we know all information about the environment, the agent moves the following way.
299 | 1. Find the closest audio source using euclidean distance
300 | 2. Rotate to face the audio source if necessary
301 | 3. Step in the direction of the audio source
302 | """
303 | agent = self.env.agent_loc
304 | radius = self.env.acceptable_radius
305 | action = None
306 |
307 | # find the closest audio source
308 | source, minimum_distance = self.env.source_locs[0], np.linalg.norm(
309 | agent - self.env.source_locs[0])
310 | for s in self.env.source_locs[1:]:
311 | dist = np.linalg.norm(agent - s)
312 | if dist < minimum_distance:
313 | minimum_distance = dist
314 | source = s
315 |
316 | # Determine our current angle
317 | angle = abs(int(np.degrees(self.env.cur_angle))) % 360
318 | # The agent is too far left or right of the source
319 | if np.abs(source[0] - agent[0]) >= radius:
320 | # First check if we need to turn
321 | if angle not in [0, 1, 179, 180, 181]:
322 | action = 2
323 | # We are facing correct way need to move forward or backward
324 | else:
325 | if source[0] < agent[0]:
326 | if angle == 0:
327 | action = 1
328 | else:
329 | action = 0
330 | else:
331 | if angle == 0:
332 | action = 0
333 | else:
334 | action = 1
335 | # Agent is to the right of the source
336 | elif np.abs(source[1] - agent[1]) >= radius:
337 | # First check if we need to turn
338 | if angle not in [89, 90, 91, 269, 270, 271]:
339 | action = 2
340 | # We are facing correct way need to move forward or backward
341 | else:
342 | if source[1] < agent[1]:
343 | if angle == 90:
344 | action = 1
345 | else:
346 | action = 0
347 | else:
348 | if angle == 90:
349 | action = 0
350 | else:
351 | action = 1
352 | else:
353 | action = np.random.randint(0, 4)
354 | return action
355 |
356 | def update(self):
357 | """
358 | Since this is a perfect agent, we do not update any network nor save any models
359 | """
360 | pass
361 |
362 | def update_stable_networks(self):
363 | """
364 | Since this is a perfect agent, we do not update any network nor save any models
365 | """
366 | pass
367 |
368 | def save_model(self, name):
369 | """
370 | Since this is a perfect agent, we do not update any network nor save any models
371 | """
372 | pass
373 |
--------------------------------------------------------------------------------
/src/audio_room/envs/audio_env.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from pyroomacoustics import MicrophoneArray, ShoeBox, Room, linear_2D_array, Constants
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from gym import spaces
6 | from scipy.spatial.distance import euclidean
7 | from sklearn.metrics.pairwise import euclidean_distances
8 | from copy import deepcopy
9 | import nussl
10 | import sys
11 | sys.path.append("../../")
12 |
13 | import constants
14 | from utils import choose_random_files
15 |
16 | # Suppress 'c' argument errors caused by room.plot()
17 | from matplotlib.axes._axes import _log as matplotlib_axes_logger
18 | matplotlib_axes_logger.setLevel('ERROR')
19 |
20 |
21 | class AudioEnv(gym.Env):
22 | def __init__(
23 | self,
24 | room_config,
25 | source_folders_dict,
26 | resample_rate=8000,
27 | num_channels=2,
28 | bytes_per_sample=2,
29 | corners=False,
30 | absorption=0.0,
31 | max_order=2,
32 | step_size=1,
33 | acceptable_radius=.5,
34 | num_sources=2,
35 | degrees=np.deg2rad(30),
36 | reset_sources=True,
37 | same_config=False
38 | ):
39 | """
40 | This class inherits from OpenAI Gym Env and is used to simulate the agent moving in PyRoom.
41 |
42 | Args:
43 | room_config (List or np.array): dimensions of the room. For Shoebox, in the form of [10,10]. Otherwise,
44 | in the form of [[1,1], [1, 4], [4, 4], [4, 1]] specifying the corners of the room
45 | resample_rate (int): sample rate in Hz
46 | num_channels (int): number of channels (used in playing what the mic hears)
47 | bytes_per_sample (int): used in playing what the mic hears
48 | corners (bool): False if using Shoebox config, otherwise True
49 | absorption (float): Absorption param of the room (how walls absorb sound)
50 | max_order (int): another room parameter
51 | step_size (float): specified step size else we programmatically assign it
52 | acceptable_radius (float): source is considered found/turned off if agent is within this distance of src
53 | num_sources (int): the number of audio sources the agent will listen to
54 | degrees (float): value of degrees to rotate in radians (.2618 radians = 15 degrees)
55 | reset_sources (bool): True if you want to choose different sources when resetting env
56 | source_folders_dict (Dict[str, int]): specify how many source files to choose from each folder
57 | e.g.
58 | {
59 | 'car_horn_source_folder': 1,
60 | 'phone_ringing_source_folder': 1
61 | }
62 | same_config (bool): If set to true, difficulty of env becomes easier - agent initial loc, source placement,
63 | and audio files don't change over episodes
64 | """
65 | self.resample_rate = resample_rate
66 | self.absorption = absorption
67 | self.max_order = max_order
68 | self.audio = []
69 | self.num_channels = num_channels
70 | self.bytes_per_sample = bytes_per_sample
71 | self.num_actions = 4
72 | self.action_space = spaces.Discrete(self.num_actions)
73 | self.action_to_string = {
74 | 0: "Forward",
75 | 1: "Backward",
76 | 2: "Rotate Right",
77 | 3: "Rotate Left",
78 | }
79 | self.corners = corners
80 | self.room_config = room_config
81 | self.acceptable_radius = acceptable_radius
82 | self.step_size = step_size
83 | self.num_sources = num_sources
84 | self.source_locs = None
85 | self.min_size_audio = np.inf
86 | self.degrees = degrees
87 | self.cur_angle = 0
88 | self.reset_sources = reset_sources
89 | self.source_folders_dict = source_folders_dict
90 | self.same_config = same_config
91 |
92 | # choose audio files as sources
93 | self.direct_sources = choose_random_files(self.source_folders_dict)
94 | self.direct_sources_copy = deepcopy(self.direct_sources)
95 |
96 | # create the room and add sources
97 | self._create_room()
98 | self._add_sources()
99 | self._move_agent(new_agent_loc=None, initial_placing=True)
100 |
101 | # If same_config = True, keep track of the generated locations
102 | self.fixed_source_locs = self.source_locs.copy()
103 | self.fixed_agent_locs = self.agent_loc.copy()
104 |
105 | # The step size must be smaller than radius in order to make sure we don't
106 | # overstep a audio source
107 | if self.acceptable_radius < self.step_size / 2:
108 | raise ValueError(
109 | """The threshold radius (acceptable_radius) must be at least step_size / 2. Else, the agent may overstep
110 | an audio source."""
111 | )
112 |
113 | def _create_room(self):
114 | """
115 | This function creates the Pyroomacoustics room with our environment class variables.
116 | """
117 | # non-Shoebox config (corners of room are given)
118 | if self.corners:
119 | self.room = Room.from_corners(
120 | self.room_config, fs=self.resample_rate,
121 | absorption=self.absorption, max_order=self.max_order
122 | )
123 |
124 | # The x_max and y_max in this case would be used to generate agent's location randomly
125 | self.x_min = min(self.room_config[0])
126 | self.y_min = min(self.room_config[1])
127 | self.x_max = max(self.room_config[0])
128 | self.y_max = max(self.room_config[1])
129 |
130 | # ShoeBox config
131 | else:
132 | self.room = ShoeBox(
133 | self.room_config, fs=self.resample_rate,
134 | absorption=self.absorption, max_order=self.max_order
135 | )
136 | self.x_max = self.room_config[0]
137 | self.y_max = self.room_config[1]
138 | self.x_min, self.y_min = 0, 0
139 |
140 | def _move_agent(self, new_agent_loc, initial_placing=False):
141 | """
142 | This function moves the agent to a new location (given by new_agent_loc). It effectively removes the
143 | agent (mic array) from the room and then adds it back in the new location.
144 |
145 | If initial_placing == True, the agent is placed in the room for the first time.
146 |
147 | Args:
148 | new_agent_loc (List[int] or np.array or None): [x,y] coordinates of the agent's new location.
149 | initial_placing (bool): True if initially placing the agent in the room at the beginning of the episode
150 | """
151 | # Placing agent in room for the first time (likely at the beginning of a new episode, after a reset)
152 | if initial_placing:
153 | if new_agent_loc is None:
154 | loc = self._sample_points(1, sources=False, agent=True)
155 | print("Placing agent at {}".format(loc))
156 | self.agent_loc = loc
157 | self.cur_angle = 0 # Reset the orientation of agent back to zero at start of an ep
158 | else:
159 | self.agent_loc = new_agent_loc.copy()
160 | print("Placing agent at {}".format(self.agent_loc))
161 | self.cur_angle = 0
162 | else:
163 | # Set the new agent location (where to move)
164 | self.agent_loc = new_agent_loc
165 |
166 | # Setup microphone in agent location, delete the array at previous time step
167 | self.room.mic_array = None
168 |
169 | if self.num_channels == 2:
170 | # Create the array at current time step (2 mics, angle IN RADIANS, 0.2m apart)
171 | mic = MicrophoneArray(
172 | linear_2D_array(self.agent_loc, 2, self.cur_angle,
173 | constants.DIST_BTWN_EARS), self.room.fs
174 | )
175 | self.room.add_microphone_array(mic)
176 | else:
177 | mic = MicrophoneArray(self.agent_loc.reshape(-1, 1), self.room.fs)
178 | self.room.add_microphone_array(mic)
179 |
180 | def _sample_points(self, num_points, sources=True, agent=False):
181 | """
182 | This function generates randomly sampled points for the sources (or agent) to be placed
183 |
184 | Args:
185 | num_points (int): Number of [x, y] random points to generate
186 | sources (bool): True if generating points for sources (agent must be False)
187 | agent(bool): True if generating points for agent (sources must be False)
188 |
189 | Returns:
190 | sample_points (List[List[int]]): A list of [x,y] points for source location
191 | or
192 | random_point (List[int]): An [x, y] point for agent location
193 | """
194 | assert(sources != agent)
195 | sampled_points = []
196 |
197 | if sources:
198 | angles = np.arange(0, 2 * np.pi, self.degrees).tolist()
199 | while len(sampled_points) < num_points:
200 | chosen_angles = np.random.choice(angles, num_points)
201 | for angle in chosen_angles:
202 | direction = np.random.choice([-1, 1])
203 | distance = np.random.uniform(2 * self.step_size, 5 * self.step_size)
204 | x = (self.x_min + self.x_max) / 2
205 | y = (self.y_min + self.y_max) / 2
206 | x = x + direction * np.cos(angle) * distance
207 | y = y + direction + np.sin(angle) * distance
208 | point = [x, y]
209 | if self.room.is_inside(point, include_borders=False):
210 | accepted = True
211 | if len(sampled_points) > 0:
212 | dist_to_existing = euclidean_distances(
213 | np.array(point).reshape(1, -1), sampled_points)
214 | accepted = dist_to_existing.min() > 2 * self.step_size
215 | if accepted and len(sampled_points) < num_points:
216 | sampled_points.append(point)
217 | return sampled_points
218 | elif agent:
219 | accepted = False
220 | while not accepted:
221 | accepted = True
222 | point = [
223 | np.random.uniform(self.x_min, self.x_max),
224 | np.random.uniform(self.y_min, self.y_max),
225 | ]
226 |
227 | # ensure agent doesn't spawn too close to sources
228 | for source_loc in self.source_locs:
229 | if(euclidean(point, source_loc) < 2.0 * self.acceptable_radius):
230 | accepted = False
231 |
232 | return point
233 |
234 | def _add_sources(self, new_source_locs=None, reset_env=False, removing_source=None):
235 | """
236 | This function adds the sources to the environment.
237 |
238 | Args:
239 | new_source_locs (List[List[int]]): A list consisting of [x, y] coordinates if the programmer wants
240 | to manually set the new source locations
241 | reset_env (bool): Bool indicating whether we reset_env the agents position to be the mean
242 | of all the sources
243 | removing_source (None or int): Value that will tell us if we are removing a source
244 | from sources
245 | """
246 | # Can reset with NEW, randomly sampled sources (typically at the start of a new episode)
247 | if self.reset_sources:
248 | self.direct_sources = choose_random_files(self.source_folders_dict)
249 | else:
250 | self.direct_sources = deepcopy(self.direct_sources_copy)
251 |
252 | if new_source_locs is None:
253 | self.source_locs = self._sample_points(num_points=self.num_sources)
254 | else:
255 | self.source_locs = new_source_locs.copy()
256 |
257 | print("Source locs {}".format(self.source_locs))
258 |
259 | self.audio = []
260 | self.min_size_audio = np.inf
261 | for idx, audio_file in enumerate(self.direct_sources):
262 | # Audio will be automatically re-sampled to the given rate (default sr=8000).
263 | a = nussl.AudioSignal(audio_file, sample_rate=self.resample_rate)
264 | a.to_mono()
265 |
266 | # normalize audio so both sources have similar volume at beginning before mixing
267 | loudness = a.loudness()
268 |
269 | # # mix to reference db
270 | ref_db = -40
271 | db_diff = ref_db - loudness
272 | gain = 10 ** (db_diff / 20)
273 | a = a * gain
274 |
275 | # Find min sized source to ensure something is playing at all times
276 | if len(a) < self.min_size_audio:
277 | self.min_size_audio = len(a)
278 | self.audio.append(a.audio_data.squeeze())
279 |
280 | # add sources using audio data
281 | for idx, audio in enumerate(self.audio):
282 | self.room.add_source(
283 | self.source_locs[idx], signal=audio[: self.min_size_audio])
284 |
285 | def _remove_source(self, index):
286 | """
287 | This function removes a source from the environment
288 |
289 | Args:
290 | index (int): index of the source to remove
291 | """
292 | if index < len(self.source_locs):
293 | src = self.source_locs.pop(index)
294 | src2 = self.direct_sources.pop(index)
295 |
296 | # actually remove source from the room
297 | room_src = self.room.sources.pop(index)
298 |
299 | def step(self, action, play_audio=False, show_room=False):
300 | """
301 | This function simulates the agent taking one step in the environment (and room) given an action:
302 | 0 = Move forward
303 | 1 = Move backward
304 | 2 = Turn right x degrees
305 | 3 = Turn left x degrees
306 |
307 | It calls _move_agent, checks to see if the agent has reached a source, and if not, computes the RIR.
308 |
309 | Args:
310 | action (int): direction agent is to move - 0 (L), 1 (R), 2 (U), 3 (D)
311 | play_audio (bool): whether to play the the mic audio (stored in "data")
312 | show_room (bool): Controls whether room is visually plotted at each step
313 |
314 | Returns:
315 | Tuple of the format List (empty if done, else [data]), reward, done
316 | """
317 | # return reward dictionary for each step
318 | reward = {
319 | 'step_penalty': constants.STEP_PENALTY,
320 | 'turn_off_reward': 0,
321 | 'closest_reward': 0,
322 | 'orient_penalty': 0
323 | }
324 |
325 | #print('Action:', self.action_to_string[action])
326 |
327 | # movement
328 | x, y = self.agent_loc[0], self.agent_loc[1]
329 | done = False
330 |
331 | if action in [0, 1]:
332 | if action == 0:
333 | sign = 1
334 | if action == 1:
335 | sign = -1
336 | x = x + sign * np.cos(self.cur_angle) * self.step_size
337 | y = y + sign * np.sin(self.cur_angle) * self.step_size
338 | elif action == 2:
339 | self.cur_angle = round((self.cur_angle + self.degrees) % (2 * np.pi), 4)
340 | reward['orient_penalty'] = constants.ORIENT_PENALTY
341 | elif action == 3:
342 | self.cur_angle = round((self.cur_angle - self.degrees) % (2 * np.pi), 4)
343 | reward['orient_penalty'] = constants.ORIENT_PENALTY
344 | # Check if the new points lie within the room
345 | try:
346 | if self.room.is_inside([x, y], include_borders=False):
347 | points = np.array([x, y])
348 | else:
349 | points = self.agent_loc
350 | except:
351 | # in case the is_inside func fails
352 | points = self.agent_loc
353 |
354 | # Move agent in the direction of action
355 | self._move_agent(new_agent_loc=points)
356 |
357 | # Check if goal state is reached
358 | for index, source in enumerate(self.source_locs):
359 | # Agent has found the source
360 | if euclidean(self.agent_loc, source) <= self.acceptable_radius:
361 | print(f'Agent has found source {self.direct_sources[index]}. \nAgent loc: {self.agent_loc}, Source loc: {source}')
362 | reward['turn_off_reward'] = constants.TURN_OFF_REWARD
363 | # If there is more than one source, then we want to remove this source
364 | if len(self.source_locs) > 1:
365 | # remove the source (will take effect in the next step)
366 | self._remove_source(index=index)
367 |
368 | # Calculate the impulse response
369 | self.room.compute_rir()
370 | self.room.simulate()
371 | data = self.room.mic_array.signals
372 |
373 | # Convert the data back to Nussl Audio object
374 | data = nussl.AudioSignal(
375 | audio_data_array=data, sample_rate=self.resample_rate)
376 |
377 | if play_audio or show_room:
378 | self.render(data, play_audio, show_room)
379 |
380 | done = False
381 | return data, [self.agent_loc, self.cur_angle], reward, done
382 |
383 | # This was the last source hence we can assume we are done
384 | else:
385 | done = True
386 | self.reset()
387 | return None, [self.agent_loc, self.cur_angle], reward, done
388 |
389 | if not done:
390 | # Calculate the impulse response
391 | self.room.compute_rir()
392 | self.room.simulate()
393 | data = self.room.mic_array.signals
394 |
395 | # Convert data to nussl audio signal
396 | data = nussl.AudioSignal(
397 | audio_data_array=data, sample_rate=self.resample_rate)
398 |
399 | if play_audio or show_room:
400 | self.render(data, play_audio, show_room)
401 |
402 | # be careful not to give too much reward here (i.e. 1/min_dist could be very large if min_dist is quite small)
403 | # related to acceptable radius size because this reward is only given when NOT turning off a source
404 | min_dist = euclidean_distances(
405 | np.array(self.agent_loc).reshape(1, -1), self.source_locs).min()
406 | reward['closest_reward'] = (1 / (min_dist + 1e-4))
407 | #print('agent_loc:', self.agent_loc, 'source_locs:', self.source_locs)
408 | #print('cur angle:', self.cur_angle)
409 | #print('reward:', reward)
410 |
411 | # Return the room rir and convolved signals as the new state
412 | return data, [self.agent_loc, self.cur_angle], reward, done
413 |
414 | def reset(self, removing_source=None):
415 | """
416 | This function re-creates the room, then places sources and agent randomly (but separated) in the room.
417 | To be used after each episode.
418 |
419 | Args:
420 | removing_source (int): Integer that tells us the index of sources that we will be removing
421 | """
422 | # re-create room
423 | self._create_room()
424 |
425 | if not self.same_config:
426 | # randomly add sources to the room
427 | self._add_sources()
428 | # randomly place agent in room at beginning of next episode
429 | self._move_agent(new_agent_loc=None, initial_placing=True)
430 | else:
431 | # place sources and agent at same locations to start every episode
432 | self._add_sources(new_source_locs=self.fixed_source_locs)
433 | self._move_agent(new_agent_loc=self.fixed_agent_locs, initial_placing=True)
434 |
435 | def render(self, data, play_audio, show_room):
436 | """
437 | Play the convolved sound using SimpleAudio.
438 |
439 | Args:
440 | data (AudioSignal): if 2 mics, should be of shape (x, 2)
441 | play_audio (bool): If true, audio will play
442 | show_room (bool): If true, room will be displayed to user
443 | """
444 | if play_audio:
445 | data.embed_audio(display=True)
446 |
447 | # Show the room while the audio is playing
448 | if show_room:
449 | fig, ax = self.room.plot(img_order=0)
450 | plt.pause(1)
451 |
452 | plt.close()
453 |
454 | elif show_room:
455 | fig, ax = self.room.plot(img_order=0)
456 | plt.pause(1)
457 | plt.close()
458 |
--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
1 | import nussl
2 | import json
3 | import os
4 | import constants
5 | import numpy as np
6 | from torch.utils.data import IterableDataset
7 |
8 | import warnings
9 |
10 | from torch.utils.data import Dataset
11 |
12 | from nussl.core import AudioSignal
13 | import transforms as tfm
14 | import tqdm
15 |
16 |
17 | class BaseDataset(Dataset):
18 | """
19 | The BaseDataset class is the starting point for all dataset hooks
20 | in nussl. To subclass BaseDataset, you only have to implement two
21 | functions:
22 |
23 | - ``get_items``: a function that is passed the folder and generates a
24 | list of items that will be processed by the next function. The
25 | number of items in the list will dictate len(dataset). Must return
26 | a list.
27 | - ``process_item``: this function processes a single item in the list
28 | generated by get_items. Must return a dictionary.
29 |
30 | After process_item is called, a set of Transforms can be applied to the
31 | output of process_item. If no transforms are defined (``self.transforms = None``),
32 | then the output of process_item is returned by self[i]. For implemented
33 | Transforms, see nussl.datasets.transforms. For example,
34 | PhaseSpectrumApproximation will add three new keys to the output dictionary
35 | of process_item:
36 |
37 | - mix_magnitude: the magnitude spectrogram of the mixture
38 | - source_magnitudes: the magnitude spectrogram of each source
39 | - ideal_binary_mask: the ideal binary mask for each source
40 |
41 | The transforms are applied in sequence using transforms.Compose.
42 | Not all sequences of transforms will be valid (e.g. if you pop a key in
43 | one transform but a later transform operates on that key, you will get
44 | an error).
45 |
46 | For examples of subclassing, see ``nussl.datasets.hooks``.
47 |
48 | Args:
49 | folder (str): location that should be processed to produce the list of files
50 |
51 | transform (transforms.* object, optional): A transforms to apply to the output of
52 | ``self.process_item``. If using transforms.Compose, each transform will be
53 | applied in sequence. Defaults to None.
54 |
55 | sample_rate (int, optional): Sample rate to use for each audio files. If
56 | audio file sample rate doesn't match, it will be resampled on the fly.
57 | If None, uses the default sample rate. Defaults to None.
58 |
59 | stft_params (STFTParams, optional): STFTParams object defining window_length,
60 | hop_length, and window_type that will be set for each AudioSignal object.
61 | Defaults to None (32ms window length, 8ms hop, 'hann' window).
62 |
63 | num_channels (int, optional): Number of channels to make each AudioSignal
64 | object conform to. If an audio signal in your dataset has fewer channels
65 | than ``num_channels``, a warning is raised, as the behavior in this case
66 | is undefined. Defaults to None.
67 |
68 | strict_sample_rate (bool, optional): Whether to raise an error if
69 |
70 | Raises:
71 | DataSetException: Exceptions are raised if the output of the implemented
72 | functions by the subclass don't match the specification.
73 | """
74 |
75 | def __init__(self, folder, transform=None, sample_rate=None, stft_params=None,
76 | num_channels=None, strict_sample_rate=True, cache_populated=False):
77 | self.folder = folder
78 | self.items = self.get_items(self.folder)
79 | self.transform = transform
80 |
81 | self.cache_populated = cache_populated
82 |
83 | self.stft_params = stft_params
84 | self.sample_rate = sample_rate
85 | self.num_channels = num_channels
86 | self.strict_sample_rate = strict_sample_rate
87 |
88 | if not isinstance(self.items, list):
89 | raise DataSetException("Output of self.get_items must be a list!")
90 |
91 | # getting one item in order to set up parameters for audio
92 | # signals if necessary, if there are any items
93 | if self.items:
94 | self.process_item(self.items[0])
95 |
96 | def filter_items_by_condition(self, func):
97 | """
98 | Filter the items in the list according to a function that takes
99 | in both the dataset as well as the item currently be processed.
100 | If the item in the list passes the condition, then it is kept
101 | in the list. Otherwise it is taken out of the list. For example,
102 | a function that would get rid of an item if it is below some
103 | minimum number of seconds would look like this:
104 |
105 | .. code-block:: python
106 |
107 | min_length = 1 # in seconds
108 |
109 | # self here refers to the dataset
110 | def remove_short_audio(self, item):
111 | processed_item = self.process_item(item)
112 | mix_length = processed_item['mix'].signal_duration
113 | if mix_length < min_length:
114 | return False
115 | return True
116 |
117 | dataset.items # contains all items
118 | dataset.filter_items_by_condition(remove_short_audio)
119 | dataset.items # contains only items longer than min length
120 |
121 | Args:
122 | func (function): A function that takes in two arguments: the dataset and
123 | this dataset object (self). The function must return a bool.
124 | """
125 | filtered_items = []
126 | n_removed = 0
127 | desc = f"Filtered {n_removed} items out of dataset"
128 | pbar = tqdm.tqdm(self.items, desc=desc)
129 | for item in pbar:
130 | check = func(self, item)
131 | if not isinstance(check, bool):
132 | raise DataSetException(
133 | "Output of filter function must be True or False!"
134 | )
135 | if check:
136 | filtered_items.append(item)
137 | else:
138 | n_removed += 1
139 | pbar.set_description(f"Filtered {n_removed} items out of dataset")
140 | self.items = filtered_items
141 |
142 | @property
143 | def cache_populated(self):
144 | return self._cache_populated
145 |
146 | @cache_populated.setter
147 | def cache_populated(self, value):
148 | self.post_cache_transforms = []
149 | cache_transform = None
150 |
151 | transforms = (
152 | self.transform.transforms
153 | if isinstance(self.transform, tfm.Compose)
154 | else [self.transform])
155 |
156 | found_cache_transform = False
157 | for t in transforms:
158 | if isinstance(t, tfm.Cache):
159 | found_cache_transform = True
160 | cache_transform = t
161 | if found_cache_transform:
162 | self.post_cache_transforms.append(t)
163 |
164 | if not found_cache_transform:
165 | # there is no cache transform
166 | self._cache_populated = False
167 | else:
168 | self._cache_populated = value
169 | cache_transform.cache_size = len(self)
170 | cache_transform.overwrite = not value
171 |
172 | self.post_cache_transforms = tfm.Compose(
173 | self.post_cache_transforms)
174 |
175 | def get_items(self, folder):
176 | """
177 | This function must be implemented by whatever class inherits BaseDataset.
178 | It should return a list of items in the given folder, each of which is
179 | processed by process_items in some way to produce mixes, sources, class
180 | labels, etc.
181 |
182 | Args:
183 | folder (str): location that should be processed to produce the list of files.
184 |
185 | Returns:
186 | list: list of items that should be processed
187 | """
188 | raise NotImplementedError()
189 |
190 | def __len__(self):
191 | """
192 | Gets the length of the dataset (the number of items that will be processed).
193 |
194 | Returns:
195 | int: Length of the dataset (``len(self.items)``).
196 | """
197 | return len(self.items)
198 |
199 | def __getitem__(self, i):
200 | """
201 | Processes a single item in ``self.items`` using ``self.process_item``.
202 | The output of ``self.process_item`` is further passed through bunch of
203 | of transforms if they are defined in parallel. If you want to have
204 | a set of transforms that depend on each other, then you should compose them
205 | into a single transforms and then pass it into here. The output of each
206 | transform is added to an output dictionary which is returned by this
207 | function.
208 |
209 | Args:
210 | i (int): Index of the dataset to return. Indexes ``self.items``.
211 |
212 | Returns:
213 | dict: Dictionary with keys and values corresponding to the processed
214 | item after being put through the set of transforms (if any are
215 | defined).
216 | """
217 | if self.cache_populated:
218 | data = {'index': i}
219 | data = self.post_cache_transforms(data)
220 | else:
221 | data = self.process_item(self.items[i])
222 |
223 | if not isinstance(data, dict):
224 | raise DataSetException(
225 | "The output of process_item must be a dictionary!")
226 |
227 | if self.transform:
228 | data['index'] = i
229 | data = self.transform(data)
230 |
231 | if not isinstance(data, dict):
232 | raise tfm.TransformException(
233 | "The output of transform must be a dictionary!")
234 |
235 | return data
236 |
237 | def process_item(self, item):
238 | """Each file returned by get_items is processed by this function. For example,
239 | if each file is a json file containing the paths to the mixture and sources,
240 | then this function should parse the json file and load the mixture and sources
241 | and return them.
242 |
243 | Exact behavior of this functionality is determined by implementation by subclass.
244 |
245 | Args:
246 | item (object): the item that will be processed by this function. Input depends
247 | on implementation of ``self.get_items``.
248 |
249 | Returns:
250 | This should return a dictionary that gets processed by the transforms.
251 | """
252 | raise NotImplementedError()
253 |
254 | def _load_audio_file(self, path_to_audio_file):
255 | """
256 | Loads audio file at given path. Uses AudioSignal to load the audio data
257 | from disk.
258 |
259 | Args:
260 | path_to_audio_file: relative or absolute path to file to load
261 |
262 | Returns:
263 | AudioSignal: loaded AudioSignal object of path_to_audio_file
264 | """
265 | audio_signal = AudioSignal(path_to_audio_file)
266 | self._setup_audio_signal(audio_signal)
267 | return audio_signal
268 |
269 | def _load_audio_from_array(self, audio_data, sample_rate=None):
270 | """
271 | Loads the audio data into an AudioSignal object with the appropriate
272 | sample rate.
273 |
274 | Args:
275 | audio_data (np.ndarray): numpy array containing the samples containing
276 | the audio data.
277 |
278 | sample_rate (int): the sample rate at which to load the audio file.
279 | If None, self.sample_rate or the sample rate of the actual file is used.
280 | Defaults to None.
281 |
282 | Returns:
283 | AudioSignal: loaded AudioSignal object of audio_data
284 | """
285 | sample_rate = sample_rate if sample_rate else self.sample_rate
286 | audio_signal = AudioSignal(
287 | audio_data_array=audio_data, sample_rate=sample_rate)
288 | self._setup_audio_signal(audio_signal)
289 | return audio_signal
290 |
291 | def _setup_audio_signal(self, audio_signal):
292 | """
293 | You will want every item from a dataset to be uniform in sample rate, STFT
294 | parameters, and number of channels. This function takes an audio signal
295 | object loaded by the dataset and uses it to set the sample rate, STFT parameters,
296 | and the number of channels. If ``self.sample_rate``, ``self.stft_params``, and
297 | ``self.num_channels`` are set at construction time of the dataset, then the
298 | opposite happens - attributes of the AudioSignal object are set to the desired
299 | values.
300 |
301 | Args:
302 | audio_signal (AudioSignal): AudioSignal object to query to set the parameters
303 | of this dataset or to set the parameters of, according to what is in the
304 | dataset.
305 | """
306 | if self.sample_rate and self.sample_rate != audio_signal.sample_rate:
307 | if self.strict_sample_rate:
308 | raise DataSetException(
309 | f"All audio files should have been the same sample rate already "
310 | f"because self.strict_sample_rate = True. Please resample or "
311 | f"turn set self.strict_sample_rate = False"
312 | )
313 | audio_signal.resample(self.sample_rate)
314 | else:
315 | self.sample_rate = audio_signal.sample_rate
316 |
317 | # set audio signal attributes to requested values, if they exist
318 | if self.stft_params:
319 | audio_signal.stft_params = self.stft_params
320 | else:
321 | self.stft_params = audio_signal.stft_params
322 |
323 | if self.num_channels:
324 | if audio_signal.num_channels > self.num_channels:
325 | # pick the first ``self.num_channels`` channels
326 | audio_signal.audio_data = audio_signal.audio_data[:self.num_channels]
327 | elif audio_signal.num_channels < self.num_channels:
328 | warnings.warn(
329 | f"AudioSignal had {audio_signal.num_channels} channels "
330 | f"but self.num_channels = {self.num_channels}. Unsure "
331 | f"of what to do, so warning. You might want to make sure "
332 | f"your dataset is uniform!"
333 | )
334 | else:
335 | self.num_channels = audio_signal.num_channels
336 |
337 |
338 | class DataSetException(Exception):
339 | """
340 | Exception class for errors when working with data sets in nussl.
341 | """
342 | pass
343 |
344 |
345 | class BufferData(BaseDataset):
346 | def __init__(self, folder, to_disk=False, transform=None):
347 | """
348 |
349 | Args:
350 | folder (string): File path to store the data. Put any string when not saving to disk.
351 | to_disk (bool): When true, data will be saved to disk for inspection, data will also be stored in memory
352 | regardless of whether this is True or False
353 | transform (transforms.* object): A transforms to apply to the output of
354 | ``self.process_item``. If using transforms.Compose, each transform will be
355 | applied in sequence. Defaults to None.
356 | """
357 | # Circular buffer parameters
358 | self.MAX_BUFFER_ITEMS = constants.MAX_BUFFER_ITEMS
359 | self.ptr = 0
360 | self.items = []
361 | self.metadata = {}
362 | self.full_buffer = False
363 | self.to_disk = to_disk
364 | self.last_ptr = -1 # To keep track of the latest item in the buffer
365 |
366 | # Make sure the relevant directories exist
367 | if self.to_disk:
368 | if not os.path.exists(constants.DIR_PREV_STATES):
369 | os.mkdir(constants.DIR_PREV_STATES)
370 | if not os.path.exists(constants.DIR_NEW_STATES):
371 | os.mkdir(constants.DIR_NEW_STATES)
372 | if not os.path.exists(constants.DIR_DATASET_ITEMS):
373 | os.mkdir(constants.DIR_DATASET_ITEMS)
374 |
375 | super().__init__(folder=folder, transform=transform)
376 |
377 | def get_items(self, folder):
378 | """
379 | Superclass: "This function must be implemented by whatever class inherits BaseDataset.
380 | It should return a list of items in the given folder, each of which is
381 | processed by process_items in some way to produce mixes, sources, class
382 | labels, etc."
383 |
384 | Implementation: Adds file paths to items list. Keeps list under MAX_ITEMS.
385 |
386 | Args:
387 | folder (str): location that should be processed to produce the list of files
388 | Returns:
389 | list: list of items (path to json files in our case) that should be processed
390 | """
391 | return self.items
392 |
393 | def process_item(self, item):
394 | """
395 | Superclass: Each file returned by get_items is processed by this function. For example,
396 | if each file is a json file containing the paths to the mixture and sources,
397 | then this function should parse the json file and load the mixture and sources
398 | and return them.
399 | Exact behavior of this functionality is determined by implementation by subclass."
400 |
401 | Implementation: read json of format:
402 | {'prev_state': '../data/prev_states/prev8-224.wav',
403 | 'action': 0,
404 | 'reward': -0.1,
405 | 'new_state': '../data/new_states/new8-224.wav'}
406 |
407 | convert the wav files to AudioSignals and return and output dict:
408 | {
409 | 'observations': {
410 | 'prev_state': AudioSignal,
411 | 'new_state': AudioSignal,
412 | }
413 | 'reward': -0.1
414 | 'action': 0
415 | }
416 |
417 | Args:
418 | item (object): the item that will be processed by this function. Input depends
419 | on implementation of ``self.get_items``.
420 | Returns:
421 | This should return a dictionary that gets processed by the transforms.
422 | """
423 |
424 | # load data from memory
425 | output = item.copy()
426 | prev_state, new_state = output['prev_state'], output['new_state']
427 |
428 | # convert to output dict format
429 | del output['prev_state'], output['new_state']
430 | output['prev_state'], output['new_state'] = prev_state, new_state
431 | output['reward'] = np.array([output['reward']], dtype='float32')
432 | output['action'] = np.array([output['action']], dtype='int64')
433 | output['agent_info'] = np.array(output['agent_info'], dtype='float32')
434 | return output
435 |
436 | def random_sample(self, bs):
437 | indices = np.random.choice(len(self.items), bs-1, replace=False)
438 | indices = np.append(indices, self.last_ptr)
439 | return indices
440 |
441 | def append(self, item):
442 | """
443 | Override the default append function to work as circular buffer
444 | Args:
445 | item (object): Item to append to the list
446 |
447 | Returns: Nothing (Item is appended to the circular buffer in place)
448 | """
449 | if self.full_buffer:
450 | self.items[self.ptr] = item
451 | self.last_ptr = self.ptr
452 | self.ptr = (self.ptr + 1) % self.MAX_BUFFER_ITEMS
453 | else:
454 | self.items.append(item)
455 | self.last_ptr += 1
456 | if len(self.items) == self.MAX_BUFFER_ITEMS:
457 | self.ptr = 0
458 | self.last_ptr = 0
459 | self.full_buffer = True
460 |
461 | def write_buffer_data(self, prev_state, action, reward, new_state, agent_info, episode, step):
462 | """
463 | Writes states (AudioSignal objects) to .wav files and stores this buffer data
464 | in json files with the states keys pointing to the .wav files. The json files
465 | are to be read by nussl.datasets.BaseDataset subclass as items.
466 |
467 | E.g. {
468 | 'prev_state': '/path/to/previous/mix.wav',
469 | 'reward': [the reward obtained for reaching current state],
470 | 'action': [the action taken to reach current state from previous state]
471 | 'current_state': '/path/to/current/mix.wav',
472 | }
473 |
474 | The unique file names are structured as path/[prev or new]-[episode #]-[step #]
475 |
476 | Args:
477 | prev_state (nussl.AudioSignal): previous state to be converted and saved as .wav file
478 | action (int): action
479 | reward (int): reward
480 | new_state (nussl.AudioSignal): new state to be converted and saved as wav file
481 | agent_info (list): Consists of agent location and current orientation of the agent
482 | episode (int): which episode we're on, used to create unique file name for state
483 | step (int): which step we're on within episode, used to create unique file name for state
484 |
485 | """
486 | if episode not in self.metadata:
487 | self.metadata[episode] = 1
488 | else:
489 | self.metadata[episode] += 1
490 |
491 | agent_loc, cur_angle = agent_info
492 | agent_info = np.append(agent_loc, cur_angle) # Jut make it a single list to keep things simple
493 | # create buffer dictionary
494 | buffer_dict = {
495 | 'prev_state': prev_state,
496 | 'action': action,
497 | 'reward': reward,
498 | 'new_state': new_state,
499 | 'agent_info': agent_info
500 | }
501 | self.append(buffer_dict)
502 |
503 | # write data for inspection
504 | if self.to_disk and step == 1:
505 | # Unique file names for each state
506 | cur_file = str(episode) + '-' + str(step)
507 | prev_state_file_path = os.path.join(
508 | constants.DIR_PREV_STATES, 'prev' + cur_file + '.wav'
509 | )
510 | new_state_file_path = os.path.join(
511 | constants.DIR_NEW_STATES, 'new' + cur_file + '.wav'
512 | )
513 | dataset_json_file_path = os.path.join(
514 | constants.DIR_DATASET_ITEMS, cur_file + '.json'
515 | )
516 |
517 | prev_state.write_audio_to_file(prev_state_file_path)
518 | new_state.write_audio_to_file(new_state_file_path)
519 |
520 | # write to json
521 | buffer_dict_json = {
522 | 'prev_state': prev_state_file_path,
523 | 'action': action,
524 | 'reward': reward,
525 | 'new_state': new_state_file_path,
526 | }
527 |
528 | with open(dataset_json_file_path, 'w') as json_file:
529 | json.dump(buffer_dict_json, json_file)
530 |
531 |
532 | class RLDataset(IterableDataset):
533 | """
534 | Dataset which gets updated as buffer gets filled
535 | """
536 | def __init__(self, buffer, sample_size):
537 | self.buffer = buffer
538 | self.sample_size = sample_size
539 |
540 | def __iter__(self):
541 | batch_indices = self.buffer.random_sample(self.sample_size)
542 | for index in batch_indices:
543 | yield self.buffer[index]
544 |
--------------------------------------------------------------------------------
/src/transforms.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import logging
4 | import random
5 | from collections import OrderedDict
6 |
7 | import torch
8 | import zarr
9 | import numcodecs
10 | import numpy as np
11 | from sklearn.preprocessing import OneHotEncoder
12 |
13 | from nussl.core import utils
14 |
15 | # This is for when you're running multiple
16 | # training threads
17 | if hasattr(numcodecs, 'blosc'):
18 | numcodecs.blosc.use_threads = False
19 |
20 |
21 | def compute_ideal_binary_mask(source_magnitudes):
22 | ibm = (
23 | source_magnitudes == np.max(source_magnitudes, axis=-1, keepdims=True)
24 | ).astype(float)
25 |
26 | ibm = ibm / np.sum(ibm, axis=-1, keepdims=True)
27 | ibm[ibm <= .5] = 0
28 | return ibm
29 |
30 |
31 | # Keys that correspond to the time-frequency representations after being passed through
32 | # the transforms here.
33 | time_frequency_keys = ['mix_magnitude', 'source_magnitudes', 'ideal_binary_mask', 'weights']
34 |
35 |
36 | class SumSources(object):
37 | """
38 | Sums sources together. Looks for sources in ``data[self.source_key]``. If
39 | a source belongs to a group, it is popped from the ``data[self.source_key]`` and
40 | summed with the other sources in the group. If there is a corresponding
41 | group_name in group_names, it is named that in ``data[self.source_key]``. If
42 | group_names are not given, then the names are constructed using the keys
43 | in each group (e.g. `drums+bass+other`).
44 |
45 | If using Scaper datasets, then there may be multiple sources with the same
46 | label but different counts. The Scaper dataset hook organizes the source
47 | dictionary as follows:
48 |
49 | .. code-block:: none
50 |
51 | data['sources] = {
52 | '{label}::{count}': AudioSignal,
53 | '{label}::{count}': AudioSignal,
54 | ...
55 | }
56 |
57 | SumSources sums by source label, so the ``::count`` will be ignored and only the
58 | label part will be used when grouping sources.
59 |
60 | Example:
61 | >>> tfm = transforms.SumSources(
62 | >>> groupings=[['drums', 'bass', 'other]],
63 | >>> group_names=['accompaniment],
64 | >>> )
65 | >>> # data['sources'] is a dict containing keys:
66 | >>> # ['vocals', 'drums', 'bass', 'other]
67 | >>> data = tfm(data)
68 | >>> # data['sources'] is now a dict containing keys:
69 | >>> # ['vocals', 'accompaniment']
70 |
71 | Args:
72 | groupings (list): a list of lists telling how to group each sources.
73 | group_names (list, optional): A list containing the names of each group, or None.
74 | Defaults to None.
75 | source_key (str, optional): The key to look for in the data containing the list of
76 | source AudioSignals. Defaults to 'sources'.
77 |
78 | Raises:
79 | TransformException: if groupings is not a list
80 | TransformException: if group_names is not None but
81 | len(groupings) != len(group_names)
82 |
83 | Returns:
84 | data: modified dictionary with summed sources
85 | """
86 |
87 | def __init__(self, groupings, group_names=None, source_key='sources'):
88 | if not isinstance(groupings, list):
89 | raise TransformException(
90 | f"groupings must be a list, got {type(groupings)}!")
91 |
92 | if group_names:
93 | if len(group_names) != len(groupings):
94 | raise TransformException(
95 | f"group_names and groupings must be same length or "
96 | f"group_names can be None! Got {len(group_names)} for "
97 | f"len(group_names) and {len(groupings)} for len(groupings)."
98 | )
99 |
100 | self.groupings = groupings
101 | self.source_key = source_key
102 | if group_names is None:
103 | group_names = ['+'.join(groupings[i]) for i in range(len(groupings))]
104 | self.group_names = group_names
105 |
106 | def __call__(self, data):
107 | if self.source_key not in data:
108 | raise TransformException(
109 | f"Expected {self.source_key} in dictionary "
110 | f"passed to this Transform!"
111 | )
112 | sources = data[self.source_key]
113 | source_keys = [(k.split('::')[0], k) for k in list(sources.keys())]
114 |
115 | for i, group in enumerate(self.groupings):
116 | combined = []
117 | group_name = self.group_names[i]
118 | for key1 in group:
119 | for key2 in source_keys:
120 | if key2[0] == key1:
121 | combined.append(sources[key2[1]])
122 | sources.pop(key2[1])
123 | sources[group_name] = sum(combined)
124 | sources[group_name].path_to_input_file = group_name
125 |
126 | data[self.source_key] = sources
127 | if 'metadata' in data:
128 | if 'labels' in data['metadata']:
129 | data['metadata']['labels'].extend(self.group_names)
130 |
131 | return data
132 |
133 | def __repr__(self):
134 | return (
135 | f"{self.__class__.__name__}("
136 | f"groupings = {self.groupings}, "
137 | f"group_names = {self.group_names}, "
138 | f"source_key = {self.source_key}"
139 | f")"
140 | )
141 |
142 |
143 | class LabelsToOneHot(object):
144 | """
145 | Takes a data dictionary with sources and their keys and converts the keys to
146 | a one-hot numpy array using the list in data['metadata']['labels'] to figure
147 | out which index goes where.
148 | """
149 |
150 | def __init__(self, source_key='sources'):
151 | self.source_key = source_key
152 |
153 | def __call__(self, data):
154 | if 'metadata' not in data:
155 | raise TransformException(
156 | f"Expected metadata in data, got {list(data.keys())}")
157 | if 'labels' not in data['metadata']:
158 | raise TransformException(
159 | f"Expected labels in data['metadata'], got "
160 | f"{list(data['metadata'].keys())}")
161 |
162 | enc = OneHotEncoder(categories=[data['metadata']['labels']])
163 |
164 | sources = data[self.source_key]
165 | source_keys = [k.split('::')[0] for k in list(sources.keys())]
166 | source_labels = [[l] for l in sorted(source_keys)]
167 |
168 | one_hot_labels = enc.fit_transform(source_labels)
169 | data['one_hot_labels'] = one_hot_labels.toarray()
170 |
171 | return data
172 |
173 |
174 | class MagnitudeSpectrumApproximation(object):
175 | """
176 | Takes a dictionary and looks for two special keys, defined by the
177 | arguments ``mix_key`` and ``source_key``. These default to `mix` and `sources`.
178 | These values of these keys are used to calculate the magnitude spectrum
179 | approximation [1]. The input dictionary is modified to have additional
180 | keys:
181 |
182 | - mix_magnitude: The magnitude spectrogram of the mixture audio signal.
183 | - source_magnitudes: The magnitude spectrograms of each source spectrogram.
184 | - assignments: The ideal binary assignments for each time-frequency bin.
185 |
186 | ``data[self.source_key]`` points to a dictionary containing the source names in
187 | the keys and the corresponding AudioSignal in the values. The keys are sorted
188 | in alphabetical order and then appended to the mask. ``data[self.source_key]``
189 | then points to an OrderedDict instead, where the keys are in the same order
190 | as in ``data['source_magnitudes']`` and ``data['assignments']``.
191 |
192 | This transform uses the STFTParams that are attached to the AudioSignal objects
193 | contained in ``data[mix_key]`` and ``data[source_key]``.
194 |
195 | [1] Erdogan, Hakan, John R. Hershey, Shinji Watanabe, and Jonathan Le Roux.
196 | "Phase-sensitive and recognition-boosted speech separation using
197 | deep recurrent neural networks." In 2015 IEEE International Conference
198 | on Acoustics, Speech and Signal Processing (ICASSP), pp. 708-712. IEEE,
199 | 2015.
200 |
201 | Args:
202 | mix_key (str, optional): The key to look for in data for the mixture AudioSignal.
203 | Defaults to 'mix'.
204 | source_key (str, optional): The key to look for in the data containing the dict of
205 | source AudioSignals. Defaults to 'sources'.
206 |
207 | Raises:
208 | TransformException: if the expected keys are not in the dictionary, an
209 | Exception is raised.
210 |
211 | Returns:
212 | data: Modified version of the input dictionary.
213 | """
214 |
215 | def __init__(self, mix_key='mix', source_key='sources'):
216 | self.mix_key = mix_key
217 | self.source_key = source_key
218 |
219 | def __call__(self, data):
220 | if self.mix_key not in data:
221 | raise TransformException(
222 | f"Expected {self.mix_key} in dictionary "
223 | f"passed to this Transform! Got {list(data.keys())}."
224 | )
225 |
226 | mixture = data[self.mix_key]
227 | mixture.stft()
228 | mix_magnitude = mixture.magnitude_spectrogram_data
229 |
230 | data['mix_magnitude'] = mix_magnitude
231 |
232 | if self.source_key not in data:
233 | return data
234 |
235 | _sources = data[self.source_key]
236 | source_names = sorted(list(_sources.keys()))
237 |
238 | sources = OrderedDict()
239 | for key in source_names:
240 | sources[key] = _sources[key]
241 | data[self.source_key] = sources
242 |
243 | source_magnitudes = []
244 | for key in source_names:
245 | s = sources[key]
246 | s.stft()
247 | source_magnitudes.append(s.magnitude_spectrogram_data)
248 |
249 | source_magnitudes = np.stack(source_magnitudes, axis=-1)
250 |
251 | data['ideal_binary_mask'] = compute_ideal_binary_mask(source_magnitudes)
252 | data['source_magnitudes'] = source_magnitudes
253 |
254 | return data
255 |
256 | def __repr__(self):
257 | return (
258 | f"{self.__class__.__name__}("
259 | f"mix_key = {self.mix_key}, "
260 | f"source_key = {self.source_key}"
261 | f")"
262 | )
263 |
264 |
265 | class MagnitudeWeights(object):
266 | """
267 | Applying time-frequency weights to the deep clustering objective results in a
268 | huge performance boost. This transform looks for 'mix_magnitude', which is output
269 | by either MagnitudeSpectrumApproximation or PhaseSensitiveSpectrumApproximation
270 | and puts it into the weights.
271 |
272 | [1] Wang, Zhong-Qiu, Jonathan Le Roux, and John R. Hershey.
273 | "Alternative objective functions for deep clustering." 2018 IEEE International
274 | Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018.
275 |
276 | Args:
277 | mix_magnitude_key (str): Which key to look for the mix_magnitude data in.
278 | """
279 |
280 | def __init__(self, mix_key='mix', mix_magnitude_key='mix_magnitude'):
281 | self.mix_magnitude_key = mix_magnitude_key
282 | self.mix_key = mix_key
283 |
284 | def __call__(self, data):
285 | if self.mix_magnitude_key not in data and self.mix_key not in data:
286 | raise TransformException(
287 | f"Expected {self.mix_magnitude_key} or {self.mix_key} in dictionary "
288 | f"passed to this Transform! Got {list(data.keys())}. "
289 | "Either MagnitudeSpectrumApproximation or "
290 | "PhaseSensitiveSpectrumApproximation should be called "
291 | "on the data dict prior to this transform. "
292 | )
293 | elif self.mix_magnitude_key not in data:
294 | data[self.mix_magnitude_key] = np.abs(data[self.mix_key].stft())
295 |
296 | magnitude_spectrogram = data[self.mix_magnitude_key]
297 | weights = magnitude_spectrogram / (np.sum(magnitude_spectrogram) + 1e-6)
298 | weights *= (
299 | magnitude_spectrogram.shape[0] * magnitude_spectrogram.shape[1]
300 | )
301 | data['weights'] = np.sqrt(weights)
302 | return data
303 |
304 |
305 | class PhaseSensitiveSpectrumApproximation(object):
306 | """
307 | Takes a dictionary and looks for two special keys, defined by the
308 | arguments ``mix_key`` and ``source_key``. These default to `mix` and `sources`.
309 | These values of these keys are used to calculate the phase sensitive spectrum
310 | approximation [1]. The input dictionary is modified to have additional
311 | keys:
312 |
313 | - mix_magnitude: The magnitude spectrogram of the mixture audio signal.
314 | - source_magnitudes: The magnitude spectrograms of each source spectrogram.
315 | - assignments: The ideal binary assignments for each time-frequency bin.
316 |
317 | ``data[self.source_key]`` points to a dictionary containing the source names in
318 | the keys and the corresponding AudioSignal in the values. The keys are sorted
319 | in alphabetical order and then appended to the mask. ``data[self.source_key]``
320 | then points to an OrderedDict instead, where the keys are in the same order
321 | as in ``data['source_magnitudes']`` and ``data['assignments']``.
322 |
323 | This transform uses the STFTParams that are attached to the AudioSignal objects
324 | contained in ``data[mix_key]`` and ``data[source_key]``.
325 |
326 | [1] Erdogan, Hakan, John R. Hershey, Shinji Watanabe, and Jonathan Le Roux.
327 | "Phase-sensitive and recognition-boosted speech separation using
328 | deep recurrent neural networks." In 2015 IEEE International Conference
329 | on Acoustics, Speech and Signal Processing (ICASSP), pp. 708-712. IEEE,
330 | 2015.
331 |
332 | Args:
333 | mix_key (str, optional): The key to look for in data for the mixture AudioSignal.
334 | Defaults to 'mix'.
335 | source_key (str, optional): The key to look for in the data containing the list of
336 | source AudioSignals. Defaults to 'sources'.
337 | range_min (float, optional): The lower end to use when truncating the source
338 | magnitudes in the phase sensitive spectrum approximation. Defaults to 0.0 (construct
339 | non-negative masks). Use -np.inf for untruncated source magnitudes.
340 | range_max (float, optional): The higher end of the truncated spectrum. This gets
341 | multiplied by the magnitude of the mixture. Use 1.0 to truncate the source
342 | magnitudes to `max(source_magnitudes, mix_magnitude)`. Use np.inf for untruncated
343 | source magnitudes (best performance for an oracle mask but may be beyond what a
344 | neural network is capable of masking). Defaults to 1.0.
345 |
346 | Raises:
347 | TransformException: if the expected keys are not in the dictionary, an
348 | Exception is raised.
349 |
350 | Returns:
351 | data: Modified version of the input dictionary.
352 | """
353 |
354 | def __init__(self, mix_key='mix', source_key='sources',
355 | range_min=0.0, range_max=1.0):
356 | self.mix_key = mix_key
357 | self.source_key = source_key
358 | self.range_min = range_min
359 | self.range_max = range_max
360 |
361 | def __call__(self, data):
362 | if self.mix_key not in data:
363 | raise TransformException(
364 | f"Expected {self.mix_key} in dictionary "
365 | f"passed to this Transform! Got {list(data.keys())}."
366 | )
367 |
368 | mixture = data[self.mix_key]
369 |
370 | mix_stft = mixture.stft()
371 | mix_magnitude = np.abs(mix_stft)
372 | mix_angle = np.angle(mix_stft)
373 | data['mix_magnitude'] = mix_magnitude
374 |
375 | if self.source_key not in data:
376 | return data
377 |
378 | _sources = data[self.source_key]
379 | source_names = sorted(list(_sources.keys()))
380 |
381 | sources = OrderedDict()
382 | for key in source_names:
383 | sources[key] = _sources[key]
384 | data[self.source_key] = sources
385 |
386 | source_angles = []
387 | source_magnitudes = []
388 | for key in source_names:
389 | s = sources[key]
390 | _stft = s.stft()
391 | source_magnitudes.append(np.abs(_stft))
392 | source_angles.append(np.angle(_stft))
393 |
394 | source_magnitudes = np.stack(source_magnitudes, axis=-1)
395 | source_angles = np.stack(source_angles, axis=-1)
396 | range_min = self.range_min
397 | range_max = self.range_max * mix_magnitude[..., None]
398 |
399 | # Section 3.1: https://arxiv.org/pdf/1909.08494.pdf
400 | source_magnitudes = np.minimum(
401 | np.maximum(
402 | source_magnitudes * np.cos(source_angles - mix_angle[..., None]),
403 | range_min
404 | ),
405 | range_max
406 | )
407 |
408 | data['ideal_binary_mask'] = compute_ideal_binary_mask(source_magnitudes)
409 | data['source_magnitudes'] = source_magnitudes
410 |
411 | return data
412 |
413 | def __repr__(self):
414 | return (
415 | f"{self.__class__.__name__}("
416 | f"mix_key = {self.mix_key}, "
417 | f"source_key = {self.source_key}"
418 | f")"
419 | )
420 |
421 |
422 | class IndexSources(object):
423 | """
424 | Takes in a dictionary containing Torch tensors or numpy arrays and extracts the
425 | indexed sources from the set key (usually either `source_magnitudes` or
426 | `ideal_binary_mask`). Can be used to train single-source separation models
427 | (e.g. mix goes in, vocals come out).
428 |
429 | You need to know which slice of the source magnitudes or ideal binary mask arrays
430 | to extract. The order of the sources in the source magnitudes array will be in
431 | alphabetical order according to their source labels.
432 |
433 | For example, if source magnitudes has shape `(257, 400, 1, 4)`, and the data is
434 | from MUSDB, then the four possible source labels are bass, drums, other, and vocals.
435 | The data in source magnitudes is in alphabetical order, so:
436 |
437 | .. code-block:: python
438 |
439 | # source_magnitudes is an array returned by either MagnitudeSpectrumApproximation
440 | # or PhaseSensitiveSpectrumApproximation
441 | source_magnitudes[..., 0] # bass spectrogram
442 | source_magnitudes[..., 1] # drums spectrogram
443 | source_magnitudes[..., 2] # other spectrogram
444 | source_magnitudes[..., 3] # vocals spectrogram
445 |
446 | # ideal_binary_mask is an array returned by either MagnitudeSpectrumApproximation
447 | # or PhaseSensitiveSpectrumApproximation
448 | ideal_binary_mask[..., 0] # bass ibm mask
449 | ideal_binary_mask[..., 1] # drums ibm mask
450 | ideal_binary_mask[..., 2] # other ibm mask
451 | ideal_binary_mask[..., 3] # vocals ibm mask
452 |
453 | You can apply this transform to either the `source_magnitudes` or the
454 | `ideal_binary_mask` or both.
455 |
456 |
457 | Args:
458 | object ([type]): [description]
459 | """
460 |
461 | def __init__(self, target_key, index):
462 | self.target_key = target_key
463 | self.index = index
464 |
465 | def __call__(self, data):
466 | if self.target_key not in data:
467 | raise TransformException(
468 | f"Expected {self.target_key} in dictionary, got {list(data.keys())}")
469 | if self.index >= data[self.target_key].shape[-1]:
470 | raise TransformException(
471 | f"Shape of data[{self.target_key}] is {data[self.target_key].shape} "
472 | f"but index = {self.index} out of bounds bounds of last dim.")
473 | data[self.target_key] = data[self.target_key][..., self.index, None]
474 | return data
475 |
476 |
477 | class GetExcerpt(object):
478 | """
479 | Takes in a dictionary containing Torch tensors or numpy arrays and extracts an
480 | excerpt from each tensor corresponding to a spectral representation of a specified
481 | length in frames. Can be used to get L-length spectrograms from mixture and source
482 | spectrograms. If the data is shorter than the specified length, it
483 | is padded to the specified length. If it is longer, a random offset between
484 | ``(0, data_length - specified_length)`` is chosen. This function assumes that
485 | it is being passed data AFTER ToSeparationModel. Thus the time dimension is
486 | on axis=1.
487 |
488 | Args:
489 | excerpt_length (int): Specified length of transformed data in frames.
490 |
491 | time_dim (int): Which dimension time is on (excerpts are taken along this axis).
492 | Defaults to 0.
493 |
494 | time_frequency_keys (list): Which keys to look at it in the data dictionary to
495 | take excerpts from.
496 | """
497 |
498 | def __init__(self, excerpt_length, time_dim=0,
499 | tf_keys=None):
500 | self.excerpt_length = excerpt_length
501 | self.time_dim = time_dim
502 | self.time_frequency_keys = tf_keys if tf_keys else time_frequency_keys
503 | # print('time_freqency_keys:', self.time_frequency_keys)
504 |
505 | @staticmethod
506 | def _validate(data, key):
507 | is_tensor = torch.is_tensor(data[key])
508 | is_array = isinstance(data[key], np.ndarray)
509 | if not is_tensor and not is_array:
510 | raise TransformException(
511 | f"data[{key}] was not a torch Tensor or a numpy array!")
512 | return is_tensor, is_array
513 |
514 | def _get_offset(self, data, key):
515 | self._validate(data, key)
516 | data_length = data[key].shape[self.time_dim]
517 |
518 | if data_length >= self.excerpt_length:
519 | offset = random.randint(0, data_length - self.excerpt_length)
520 | else:
521 | offset = 0
522 |
523 | pad_amount = max(0, self.excerpt_length - data_length)
524 | # print('key:', key)
525 | # print('data len:', data_length)
526 | # print('PAD:', pad_amount)
527 | return offset, pad_amount
528 |
529 | def _construct_pad_func_tuple(self, shape, pad_amount, is_tensor):
530 | if is_tensor:
531 | pad_func = torch.nn.functional.pad
532 | pad_tuple = [0 for _ in range(2 * len(shape))]
533 | pad_tuple[2 * self.time_dim] = pad_amount
534 | pad_tuple = pad_tuple[::-1]
535 | else:
536 | pad_func = np.pad
537 | pad_tuple = [(0, 0) for _ in range(len(shape))]
538 | pad_tuple[self.time_dim] = (0, pad_amount)
539 | return pad_func, pad_tuple
540 |
541 | def __call__(self, data):
542 | offset, pad_amount = self._get_offset(
543 | data, self.time_frequency_keys[0])
544 |
545 | for key in data:
546 | if key in self.time_frequency_keys:
547 | is_tensor, is_array = self._validate(data, key)
548 |
549 | if pad_amount > 0:
550 | pad_func, pad_tuple = self._construct_pad_func_tuple(
551 | data[key].shape, pad_amount, is_tensor)
552 | data[key] = pad_func(data[key], pad_tuple)
553 |
554 | data[key] = utils._slice_along_dim(
555 | data[key], self.time_dim, offset, offset + self.excerpt_length)
556 |
557 | # to verify mix lengths are the same
558 | # print(key, data[key].size())
559 |
560 | return data
561 |
562 |
563 | class Cache(object):
564 | """
565 | The Cache transform can be placed within a Compose transform. The data
566 | dictionary coming into this transform will be saved to the specified
567 | location using ``zarr``. Then instead of computing all of the transforms
568 | before the cache, one can simply read from the cache. The transforms after
569 | this will then be applied to the data dictionary that is read from the
570 | cache. A typical pipeline might look like this:
571 |
572 | .. code-block:: python
573 |
574 | dataset = datasets.Scaper('path/to/scaper/folder')
575 | tfm = transforms.Compose([
576 | transforms.PhaseSensitiveApproximation(),
577 | transforms.ToSeparationModel(),
578 | transforms.Cache('~/.nussl/cache/tag', overwrite=True),
579 | transforms.GetExcerpt()
580 | ])
581 | dataset[0] # first time will write to cache then apply GetExcerpt
582 | dataset.cache_populated = True # switches to reading from cache
583 | dataset[0] # second time will read from cache then apply GetExcerpt
584 | dataset[1] # will error out as it wasn't written to the cache!
585 |
586 | dataset.cache_populated = False
587 | for i in range(len(dataset)):
588 | dataset[i] # every item will get written to cache
589 | dataset.cache_populated = True
590 | dataset[1] # now it exists
591 |
592 | dataset = datasets.Scaper('path/to/scaper/folder') # next time around
593 | tfm = transforms.Compose([
594 | transforms.PhaseSensitiveApproximation(),
595 | transforms.ToSeparationModel(),
596 | transforms.Cache('~/.nussl/cache/tag', overwrite=False),
597 | transforms.GetExcerpt()
598 | ])
599 | dataset.cache_populated = True
600 | dataset[0] # will read from cache, which still exists from last time
601 |
602 | Args:
603 | object ([type]): [description]
604 | """
605 |
606 | def __init__(self, location, cache_size=1, overwrite=False):
607 | self.location = location
608 | self.cache_size = cache_size
609 | self.cache = None
610 | self.overwrite = overwrite
611 |
612 | @property
613 | def info(self):
614 | return self.cache.info
615 |
616 | @property
617 | def overwrite(self):
618 | return self._overwrite
619 |
620 | @overwrite.setter
621 | def overwrite(self, value):
622 | self._overwrite = value
623 | self._clear_cache(self.location)
624 | self._open_cache(self.location)
625 |
626 | def _clear_cache(self, location):
627 | if os.path.exists(location):
628 | if self.overwrite:
629 | logging.info(
630 | f"Cache {location} exists and overwrite = True, clearing cache.")
631 | shutil.rmtree(location, ignore_errors=True)
632 |
633 | def _open_cache(self, location):
634 | if self.overwrite:
635 | self.cache = zarr.open(location, mode='w', shape=(self.cache_size,),
636 | chunks=(1,), dtype=object,
637 | object_codec=numcodecs.Pickle(),
638 | synchronizer=zarr.ThreadSynchronizer())
639 | else:
640 | if os.path.exists(location):
641 | self.cache = zarr.open(location, mode='r',
642 | object_codec=numcodecs.Pickle(),
643 | synchronizer=zarr.ThreadSynchronizer())
644 |
645 | def __call__(self, data):
646 | if 'index' not in data:
647 | raise TransformException(
648 | f"Expected 'index' in dictionary, got {list(data.keys())}")
649 | index = data['index']
650 | if self.overwrite:
651 | self.cache[index] = data
652 | data = self.cache[index]
653 |
654 | if not isinstance(data, dict):
655 | raise TransformException(
656 | f"Reading from cache resulted in not a dictionary! "
657 | f"Maybe you haven't written to index {index} yet in "
658 | f"the cache?")
659 |
660 | return data
661 |
662 |
663 | class GetAudio(object):
664 | """
665 | Extracts the audio from each signal in `mix_key` and `source_key`.
666 | These will be at new keys, called `mix_audio` and `source_audio`.
667 | Can be used for training end-to-end models.
668 |
669 | Args:
670 | mix_key (str, optional): The key to look for in data for the mixture AudioSignal.
671 | Defaults to 'mix'.
672 | source_key (str, optional): The key to look for in the data containing the dict of
673 | source AudioSignals. Defaults to 'sources'.
674 | """
675 |
676 | def __init__(self, mix_key=['mix'], source_key='sources'):
677 | self.mix_key = mix_key
678 | self.source_key = source_key
679 |
680 | if not isinstance(mix_key, list):
681 | raise TransformException(
682 | f"Expected a list of mix keys"
683 | )
684 |
685 | def __call__(self, data):
686 |
687 | for key in self.mix_key:
688 | if key not in data:
689 | raise TransformException(
690 | f"Expected {key} in dictionary "
691 | f"passed to this Transform! Got {list(data.keys())}."
692 | )
693 |
694 | # Check if mix_key is a list of keys
695 | for key in self.mix_key:
696 | new_key = 'mix_audio_' + key
697 | data[new_key] = data[key].audio_data
698 |
699 | if self.source_key not in data:
700 | return data
701 |
702 | _sources = data[self.source_key]
703 | source_names = sorted(list(_sources.keys()))
704 |
705 | source_audio = []
706 | for key in source_names:
707 | source_audio.append(_sources[key].audio_data)
708 | # sources on last axis
709 | source_audio = np.stack(source_audio, axis=-1)
710 |
711 | data['source_audio'] = source_audio
712 | return data
713 |
714 |
715 | class ToSeparationModel(object):
716 | """
717 | Takes in a dictionary containing objects and removes any objects that cannot
718 | be passed to SeparationModel (e.g. not a numpy array or torch Tensor).
719 | If these objects are passed to SeparationModel, then an error will occur. This
720 | class should be the last one in your list of transforms, if you're using
721 | this dataset in a DataLoader object for training a network. If the keys
722 | correspond to numpy arrays, they are converted to tensors using
723 | ``torch.from_numpy``. Finally, the dimensions corresponding to time and
724 | frequency are swapped for all the keys in swap_tf_dims, as this is how
725 | SeparationModel expects it.
726 |
727 | Example:
728 |
729 | .. code-block:: none
730 |
731 | data = {
732 | # 2ch spectrogram for mixture
733 | 'mix_magnitude': torch.randn(513, 400, 2),
734 | # 2ch spectrogram for each source
735 | 'source_magnitudes': torch.randn(513, 400, 2, 4)
736 | 'mix': AudioSignal()
737 | }
738 |
739 | tfm = transforms.ToSeparationModel()
740 | data = tfm(data)
741 |
742 | data['mix_magnitude'].shape # (400, 513, 2)
743 | data['source_magnitudes].shape # (400, 513, 2, 4)
744 | 'mix' in data.keys() # False
745 |
746 |
747 | If this class isn't in your transforms list for the dataset, but you are
748 | using it in the Trainer class, then it is added automatically as the
749 | last transform.
750 | """
751 |
752 | def __init__(self, swap_tf_dims=None):
753 | self.swap_tf_dims = swap_tf_dims if swap_tf_dims else time_frequency_keys
754 |
755 | def __call__(self, data):
756 | keys = list(data.keys())
757 | for key in keys:
758 | if key != 'index':
759 | is_array = isinstance(data[key], np.ndarray)
760 | if is_array:
761 | data[key] = torch.from_numpy(data[key])
762 | if not torch.is_tensor(data[key]):
763 | data.pop(key)
764 | if key in self.swap_tf_dims:
765 | data[key] = data[key].transpose(1, 0)
766 | return data
767 |
768 | def __repr__(self):
769 | return f"{self.__class__.__name__}()"
770 |
771 |
772 | class Compose(object):
773 | """Composes several transforms together. Inspired by torchvision implementation.
774 |
775 | Args:
776 | transforms (list of ``Transform`` objects): list of transforms to compose.
777 |
778 | Example:
779 | >>> transforms.Compose([
780 | >>> transforms.MagnitudeSpectrumApproximation(),
781 | >>> transforms.ToSeparationModel(),
782 | >>> ])
783 | """
784 |
785 | def __init__(self, transforms):
786 | self.transforms = transforms
787 |
788 | def __call__(self, data):
789 | for t in self.transforms:
790 | data = t(data)
791 | if not isinstance(data, dict):
792 | raise TransformException(
793 | "The output of every transform must be a dictionary!")
794 | return data
795 |
796 | def __repr__(self):
797 | format_string = self.__class__.__name__ + '('
798 | for t in self.transforms:
799 | format_string += '\n'
800 | format_string += ' {0}'.format(t)
801 | format_string += '\n)'
802 | return format_string
803 |
804 |
805 | class TransformException(Exception):
806 | """
807 | Exception class for errors when working with transforms in nussl.
808 | """
809 | pass
810 |
811 |
--------------------------------------------------------------------------------