├── imbDRL ├── __init__.py ├── agents │ ├── __init__.py │ └── ddqn.py ├── examples │ ├── __init__.py │ └── ddqn │ │ ├── __init__.py │ │ ├── analyse_model.py │ │ ├── train_imdb.py │ │ ├── train_mnist.py │ │ ├── train_credit.py │ │ ├── train_famnist.py │ │ ├── train_aki.py │ │ └── train_titanic.py ├── environments │ ├── __init__.py │ └── classifierenv.py ├── utils.py ├── metrics.py └── data.py ├── images └── results.png ├── requirements.txt ├── tox.ini ├── .github └── workflows │ ├── publish-pip.yml │ └── python-build.yaml ├── setup.py ├── tests ├── test_utils.py ├── test_environments.py ├── test_ddqn.py ├── test_metrics.py └── test_data.py ├── .gitignore ├── README.md └── LICENSE /imbDRL/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbDRL/agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbDRL/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbDRL/environments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbDRL/examples/ddqn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denbergvanthijs/imbDRL/HEAD/images/results.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=2.3 2 | tf_agents>=0.6.0 3 | numpy!=1.19.4 4 | gym 5 | matplotlib 6 | pandas 7 | seaborn 8 | scikit-learn 9 | tqdm 10 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [coverage:report] 2 | skip_empty = true 3 | exclude_lines = 4 | pragma: no cover 5 | # Skip abstract methods 6 | raise NotImplementedError 7 | omit = *examples* 8 | 9 | [pytest] 10 | addopts = -v --disable-pytest-warnings --cov=imbDRL --cov-report html 11 | 12 | [flake8] 13 | ignore = D100,D104,D205,D401,I100,I201 14 | max-complexity = -1 15 | max-line-length = 140 16 | enable-extension = G 17 | show-source = True 18 | count = True -------------------------------------------------------------------------------- /.github/workflows/publish-pip.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies
22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from setuptools import find_packages, setup 4 | 5 | base = pathlib.Path(__file__).parent.resolve() 6 | long_description = (base / "README.md").read_text(encoding="utf-8") 7 | install_requires = (base / "requirements.txt").read_text(encoding="utf-8").split("\n")[:-1] # Remove empty string at last index 8 | 9 | setup(name="imbDRL", 10 | version="2021.1.26.1", 11 | author="Thijs van den Berg", 12 | author_email="denbergvanthijs@gmail.com", 13 | description="Imbalanced Classification with Deep Reinforcement Learning.", 14 | long_description=long_description, 15 | long_description_content_type="text/markdown", 16 | url="https://github.com/Denbergvanthijs/imbDRL", 17 | packages=find_packages(), 18 | classifiers=["Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: Apache Software License", 20 | "Operating System :: OS Independent", 21 | "Environment :: GPU :: NVIDIA CUDA", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence"], 23 | keywords="imbalanced classification, deep reinforcement learning, deep q-network, reward-function, classification, medical", 24 | install_requires=install_requires, 25 | python_requires=">=3.7") 26 | -------------------------------------------------------------------------------- /imbDRL/examples/ddqn/analyse_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from imbDRL.agents.ddqn import TrainDDQN 4 | from imbDRL.data import load_csv 5 | from imbDRL.metrics import (classification_metrics, network_predictions, 6 | plot_confusion_matrix, plot_pr_curve, 7 | plot_roc_curve) 8 | from imbDRL.utils import rounded_dict 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # CPU is faster than GPU on structured data 11 | 12 | min_class = [1] # Minority classes, same setup as in original paper 13 | maj_class = [0] # Majority classes 14 | fp_model = "./models/20210118_132311.pkl" 15 | 16 | X_train, y_train, X_test, y_test = load_csv("./data/credit0.csv", "./data/credit1.csv", "Class", ["Time"], normalization=True) 17 | network = TrainDDQN.load_network(fp_model) 18 | 19 | y_pred_train = network_predictions(network, X_train) 20 | y_pred_test = network_predictions(network, X_test) 21 | 22 | stats = classification_metrics(y_train, y_pred_train) 23 | print(f"Train: {rounded_dict(stats)}") 24 | stats = classification_metrics(y_test, y_pred_test) 25 | print(f"Test: {rounded_dict(stats)}") 26 | 27 | plot_pr_curve(network, X_test, y_test, X_train, y_train) 28 | plot_roc_curve(network, X_test, y_test, X_train, y_train) 29 | plot_confusion_matrix(stats.get("TP"), stats.get("FN"), stats.get("FP"), stats.get("TN")) 30 | -------------------------------------------------------------------------------- /.github/workflows/python-build.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with Python 3.7 and 3.8 2 | 3 | name: Build 4 | 5 | on: 6 | push: 7 | branches: [master] 8 | pull_request: 9
| branches: [master] 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: [3.7, 3.8] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install . 28 | pip install flake8 29 | pip install flake8-import-order 30 | pip install flake8-blind-except 31 | pip install flake8-builtins 32 | pip install flake8-docstrings 33 | pip install pytest 34 | pip install pytest-cov 35 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 36 | - name: Lint with flake8 37 | run: | 38 | # Stop the build if there are Python syntax errors or undefined names 39 | flake8 . --ignore=D100,D104,D205,D401,I100,I201 --show-source --enable-extension=G --max-line-length=140 --max-complexity=-1 --count 40 | - name: Test with pytest 41 | run: | 42 | pytest -vs 43 | -------------------------------------------------------------------------------- /imbDRL/examples/ddqn/train_imdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from imbDRL.agents.ddqn import TrainDDQN 4 | from imbDRL.data import get_train_test_val, load_imdb 5 | from imbDRL.utils import rounded_dict 6 | from tensorflow.keras.layers import Dense 7 | 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # CPU is faster than GPU 9 | 10 | episodes = 150_000 # Total number of episodes 11 | warmup_steps = 50_000 # Amount of warmup steps to collect data with random policy 12 | memory_length = 100_000 # Max length of the Replay Memory 13 | batch_size = 32 14 | collect_steps_per_episode = 1000 15 | collect_every = 1000 16 | 17 | target_update_period = 10_000 18 | target_update_tau = 1 19 | n_step_update = 4 20 | 21 | layers = [Dense(250, activation="relu"), 22 | Dense(2, activation=None)] 23 | 24 | learning_rate = 0.00025 # Learning rate 25 | gamma = 0.1 # Discount factor 26 | min_epsilon = 0.01 # Minimal and final chance of choosing random action 27 | decay_episodes = 100_000 # Number of episodes to decay from 1.0 to `min_epsilon` 28 | 29 | imb_ratio = 0.1 # Imbalance rate 30 | min_class = [0] # Minority classes 31 | maj_class = [1] # Majority classes 32 | X_train, y_train, X_test, y_test = load_imdb(config=(5_000, 500)) 33 | X_train, y_train, X_test, y_test, X_val, y_val = get_train_test_val(X_train, y_train, X_test, y_test, 34 | min_class, maj_class, val_frac=0.1, imb_ratio=imb_ratio, imb_test=False) 35 | 36 | model = TrainDDQN(episodes, warmup_steps, learning_rate, gamma, min_epsilon, decay_episodes, target_update_period=target_update_period, 37 | target_update_tau=target_update_tau, batch_size=batch_size, collect_steps_per_episode=collect_steps_per_episode, 38 | memory_length=memory_length, collect_every=collect_every, n_step_update=n_step_update) 39 | 40 | model.compile_model(X_train, y_train, layers, imb_ratio=imb_ratio) 41 | model.train(X_test, y_test, "Gmean") 42 | 43 | stats = model.evaluate(X_test, y_test, X_train, y_train) 44 | print(rounded_dict(stats)) 45 | # {'Gmean': 0.286451, 'F1': 0.152846, 'Precision': 0.4967, 'Recall': 0.09032, 'TP': 1129, 'TN': 11356, 'FP': 1144, 'FN': 11371} 46 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import 
imbDRL.utils as utils 4 | import numpy as np 5 | import pytest 6 | 7 | 8 | def test_split_csv(tmp_path): 9 | """Tests imbDRL.utils.split_csv.""" 10 | cols = "V1,V2,Class\n" 11 | row0 = "0,0,0\n" 12 | row1 = "1,1,1\n" 13 | 14 | with pytest.raises(FileNotFoundError) as exc: 15 | utils.split_csv(fp=tmp_path / "thisfiledoesnotexist.csv", fp_dest=tmp_path) 16 | assert "File at" in str(exc.value) 17 | 18 | data_file = tmp_path / "data_file.csv" 19 | with open(data_file, "w") as f: 20 | f.writelines([cols, row0, row0, row1, row1]) 21 | 22 | with pytest.raises(ValueError) as exc: 23 | utils.split_csv(fp=data_file, fp_dest=tmp_path / "thisfolderdoesnotexist") 24 | assert "Directory at" in str(exc.value) 25 | 26 | with pytest.raises(ValueError) as exc: 27 | utils.split_csv(fp=data_file, fp_dest=tmp_path, test_size=0.0) 28 | assert "is not in interval" in str(exc.value) 29 | 30 | with pytest.raises(ValueError) as exc: 31 | utils.split_csv(fp=data_file, fp_dest=tmp_path, test_size=1) 32 | assert "is not in interval" in str(exc.value) 33 | 34 | with pytest.raises(ValueError) as exc: 35 | utils.split_csv(fp=data_file, fp_dest=tmp_path, strat_col="ThisColDoesNotExist") 36 | assert "not found in DataFrame" in str(exc.value) 37 | 38 | utils.split_csv(fp=data_file, fp_dest=tmp_path, test_size=0.5) 39 | assert os.path.isfile(tmp_path / "credit0.csv") 40 | assert os.path.isfile(tmp_path / "credit1.csv") 41 | 42 | 43 | def test_rounded_dict(): 44 | """Tests imbDRL.utils.rounded_dict.""" 45 | d = {"A": 10.123456789, "B": 100} 46 | assert utils.rounded_dict(d) == {"A": 10.123457, "B": 100} 47 | 48 | 49 | def test_imbalance_ratio(): 50 | """Tests imbDRL.utils.imbalance_ratio.""" 51 | y = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) 52 | assert utils.imbalance_ratio(y) == 0.5 53 | 54 | y = np.array([2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]) 55 | assert utils.imbalance_ratio(y, 2, 3) == 0.5 56 | 57 | y = np.array(["a", "a", "a", "a", "a", "b", "b", "b", "b", "b", "b", "b", "b", "b", "b"]) 58 | assert utils.imbalance_ratio(y, "a", "b") == 0.5 59 | -------------------------------------------------------------------------------- /tests/test_environments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from imbDRL.environments.classifierenv import ClassifierEnv 3 | from tf_agents.environments.utils import validate_py_environment 4 | 5 | 6 | def test_ClassifierEnv(): 7 | """Tests imbDRL.environments.classifierenv.ClassifierEnv.""" 8 | X = np.arange(10, dtype=np.float32) 9 | y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=np.int32) 10 | 11 | env = ClassifierEnv(X, y, 0.2) 12 | validate_py_environment(env, episodes=5) 13 | 14 | 15 | def test_reset(): 16 | """Tests imbDRL.environments.classifierenv.ClassifierEnv._reset.""" 17 | X = np.arange(10, dtype=np.float32) 18 | y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=np.int32) 19 | 20 | env = ClassifierEnv(X, y, 0.2) 21 | env.reset() 22 | env.step([1]) 23 | ts_restart = env.reset() 24 | 25 | assert env.episode_step == 0 26 | assert not env._episode_ended 27 | assert ts_restart.step_type == 0 28 | assert ts_restart.reward == 0 29 | assert ts_restart.discount == 1 30 | assert ts_restart.observation in X # Next observation is any of the values in X since it's shuffled each reset 31 | 32 | 33 | def test_step(): 34 | """Tests imbDRL.environments.classifierenv.ClassifierEnv._step.""" 35 | X = np.arange(10, dtype=np.float32) 36 | y = np.ones(10, dtype=np.int32) # All labels are positive 37 | 38 | env =
ClassifierEnv(X, y, 0.2) 39 | env.reset() 40 | time_step = env.step([1]) # True Positive 41 | assert time_step.reward == 1 42 | time_step = env.step([0]) # False Negative 43 | assert time_step.reward == -1 44 | 45 | time_step = env.step([1]) 46 | assert time_step.step_type == 0 # Reset since last step was False Negative 47 | 48 | X = np.arange(10, dtype=np.float32) 49 | y = np.zeros(10, dtype=np.int32) # All labels are negative 50 | 51 | env = ClassifierEnv(X, y, 0.2) 52 | env.reset() 53 | time_step = env.step([0]) # True Negative 54 | assert time_step.reward == np.array([0.2], dtype=np.float32) 55 | time_step = env.step([1]) # False Positive 56 | assert time_step.reward == np.array([-0.2], dtype=np.float32) 57 | 58 | env.reset() 59 | for _ in range(X.size): 60 | time_step = env.step([0]) # Step through the entire dataset 61 | assert time_step.step_type == 0 # Reset since the last step reached the end of the dataset 62 | -------------------------------------------------------------------------------- /imbDRL/examples/ddqn/train_mnist.py: -------------------------------------------------------------------------------- 1 | from imbDRL.agents.ddqn import TrainDDQN 2 | from imbDRL.data import get_train_test_val, load_image 3 | from imbDRL.utils import rounded_dict 4 | from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D 5 | 6 | episodes = 120_000 # Total number of episodes 7 | warmup_steps = 50_000 # Amount of warmup steps to collect data with random policy 8 | memory_length = 100_000 # Max length of the Replay Memory 9 | batch_size = 32 10 | collect_steps_per_episode = 1000 11 | collect_every = 1000 12 | 13 | target_update_period = 10_000 14 | target_update_tau = 1 15 | n_step_update = 4 16 | 17 | layers = [Conv2D(32, (5, 5), padding="Same", activation="relu"), 18 | MaxPooling2D(pool_size=(2, 2)), 19 | Conv2D(32, (5, 5), padding="Same", activation="relu"), 20 | MaxPooling2D(pool_size=(2, 2)), 21 | Flatten(), 22 | Dense(256, activation="relu"), 23 | Dense(2, activation=None)] 24 | 25 | learning_rate = 0.00025 # Learning rate 26 | gamma = 0.1 # Discount factor 27 | min_epsilon = 0.01 # Minimal and final chance of choosing random action 28 | decay_episodes = 100_000 # Number of episodes to decay from 1.0 to `min_epsilon` 29 | 30 | imb_ratio = 0.01 # Imbalance rate 31 | min_class = [2] # Minority classes 32 | maj_class = [0, 1, 3, 4, 5, 6, 7, 8, 9] # Majority classes 33 | X_train, y_train, X_test, y_test = load_image("mnist") 34 | X_train, y_train, X_test, y_test, X_val, y_val = get_train_test_val(X_train, y_train, X_test, y_test, 35 | min_class, maj_class, val_frac=0.1, imb_ratio=imb_ratio, imb_test=False) 36 | 37 | model = TrainDDQN(episodes, warmup_steps, learning_rate, gamma, min_epsilon, decay_episodes, target_update_period=target_update_period, 38 | target_update_tau=target_update_tau, batch_size=batch_size, collect_steps_per_episode=collect_steps_per_episode, 39 | memory_length=memory_length, collect_every=collect_every, n_step_update=n_step_update) 40 | 41 | model.compile_model(X_train, y_train, layers, imb_ratio=imb_ratio) 42 | model.train(X_val, y_val, "Gmean") 43 | 44 | stats = model.evaluate(X_test, y_test, X_train, y_train) 45 | print(rounded_dict(stats)) 46 | # {'Gmean': 0.987032, 'F1': 0.955702, 'Precision': 0.930275, 'Recall': 0.982558, 'TP': 1014, 'TN': 8892, 'FP': 76, 'FN': 18} 47 | -------------------------------------------------------------------------------- /imbDRL/examples/ddqn/train_credit.py: -------------------------------------------------------------------------------- 1 | import os
2 | 3 | from imbDRL.agents.ddqn import TrainDDQN 4 | from imbDRL.data import get_train_test_val, load_csv 5 | from imbDRL.utils import rounded_dict 6 | from tensorflow.keras.layers import Dense, Dropout 7 | 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # CPU is faster than GPU on structured data 9 | 10 | episodes = 100_000 # Total number of episodes 11 | warmup_steps = 170_000 # Amount of warmup steps to collect data with random policy 12 | memory_length = warmup_steps # Max length of the Replay Memory 13 | batch_size = 32 14 | collect_steps_per_episode = 2000 15 | collect_every = 500 16 | 17 | target_update_period = 800 # Period to overwrite the target Q-network with the default Q-network 18 | target_update_tau = 1 # Soften the target model update 19 | n_step_update = 1 20 | 21 | layers = [Dense(256, activation="relu"), Dropout(0.2), 22 | Dense(256, activation="relu"), Dropout(0.2), 23 | Dense(2, activation=None)] # No activation, pure Q-values 24 | 25 | learning_rate = 0.00025 # Learning rate 26 | gamma = 0.0 # Discount factor 27 | min_epsilon = 0.5 # Minimal and final chance of choosing random action 28 | decay_episodes = episodes // 10 # Number of episodes to decay from 1.0 to `min_epsilon` 29 | 30 | min_class = [1] # Minority classes 31 | maj_class = [0] # Majority classes 32 | X_train, y_train, X_test, y_test = load_csv("./data/credit0.csv", "./data/credit1.csv", "Class", ["Time"], normalization=True) 33 | X_train, y_train, X_test, y_test, X_val, y_val = get_train_test_val(X_train, y_train, X_test, y_test, 34 | min_class, maj_class, val_frac=0.2) 35 | 36 | model = TrainDDQN(episodes, warmup_steps, learning_rate, gamma, min_epsilon, decay_episodes, target_update_period=target_update_period, 37 | target_update_tau=target_update_tau, batch_size=batch_size, collect_steps_per_episode=collect_steps_per_episode, 38 | memory_length=memory_length, collect_every=collect_every, n_step_update=n_step_update) 39 | 40 | model.compile_model(X_train, y_train, layers) 41 | model.q_net.summary() 42 | model.train(X_val, y_val, "F1") 43 | 44 | stats = model.evaluate(X_test, y_test, X_train, y_train) 45 | print(rounded_dict(stats)) 46 | # {'Gmean': 0.886281, 'F1': 0.806283, 'Precision': 0.827957, 'Recall': 0.785714, 'TP': 77, 'TN': 56848, 'FP': 16, 'FN': 21} 47 | -------------------------------------------------------------------------------- /imbDRL/examples/ddqn/train_famnist.py: -------------------------------------------------------------------------------- 1 | from imbDRL.agents.ddqn import TrainDDQN 2 | from imbDRL.data import get_train_test_val, load_image 3 | from imbDRL.utils import rounded_dict 4 | from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D 5 | from tf_agents.utils import common 6 | 7 | episodes = 120_000 # Total number of episodes 8 | warmup_steps = 50_000 # Amount of warmup steps to collect data with random policy 9 | memory_length = 100_000 # Max length of the Replay Memory 10 | batch_size = 32 11 | collect_steps_per_episode = 1000 12 | collect_every = 1000 13 | 14 | target_update_period = 10_000 15 | target_update_tau = 1 16 | n_step_update = 4 17 | 18 | layers = [Conv2D(32, (5, 5), padding="Same", activation="relu"), 19 | MaxPooling2D(pool_size=(2, 2)), 20 | Conv2D(32, (5, 5), padding="Same", activation="relu"), 21 | MaxPooling2D(pool_size=(2, 2)), 22 | Flatten(), 23 | Dense(256, activation="relu"), 24 | Dense(2, activation=None)] 25 | 26 | learning_rate = 0.00025 # Learning rate 27 | gamma = 0.1 # Discount factor 28 | min_epsilon = 0.01 # Minimal and final
chance of choosing random action 29 | decay_episodes = 100_000 # Number of episodes to decay from 1.0 to `min_epsilon` 30 | 31 | loss_fn = common.element_wise_huber_loss 32 | 33 | imb_ratio = 0.04 # Imbalance rate 34 | min_class = [4, 5, 6] # Minority classes 35 | maj_class = [7, 8, 9] # Majority classes 36 | X_train, y_train, X_test, y_test = load_image("famnist") 37 | X_train, y_train, X_test, y_test, X_val, y_val = get_train_test_val(X_train, y_train, X_test, y_test, min_class, maj_class, 38 | imb_ratio=imb_ratio, imb_test=False, val_frac=0.1) 39 | 40 | model = TrainDDQN(episodes, warmup_steps, learning_rate, gamma, min_epsilon, decay_episodes, target_update_period=target_update_period, 41 | target_update_tau=target_update_tau, batch_size=batch_size, collect_steps_per_episode=collect_steps_per_episode, 42 | memory_length=memory_length, collect_every=collect_every, n_step_update=n_step_update) 43 | 44 | model.compile_model(X_train, y_train, layers, imb_ratio=imb_ratio, loss_fn=loss_fn) 45 | model.train(X_val, y_val, "Gmean") 46 | 47 | stats = model.evaluate(X_test, y_test, X_train, y_train) 48 | print(rounded_dict(stats)) 49 | # {'Gmean': 0.964648, 'F1': 0.964877, 'Precision': 0.959157, 'Recall': 0.970667, 'TP': 2912, 'TN': 2876, 'FP': 124, 'FN': 88} 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom ignores 2 | .vscode 3 | data/ 4 | models/ 5 | logs/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /imbDRL/examples/ddqn/train_aki.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from imbDRL.agents.ddqn import TrainDDQN 4 | from imbDRL.data import get_train_test_val, load_csv 5 | from imbDRL.utils import rounded_dict 6 | from tensorflow.keras.layers import Dense, Dropout 7 | 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # CPU is faster than GPU on structured data 9 | 10 | episodes = 25_000 # Total number of episodes 11 | warmup_steps = 32_000 # Amount of warmup steps to collect data with random policy 12 | memory_length = warmup_steps # Max length of the Replay Memory 13 | batch_size = 32 14 | collect_steps_per_episode = 2000 15 | collect_every = episodes // 100 16 | 17 | target_update_period = episodes // 30 # Period to overwrite the target Q-network with the default Q-network 18 | target_update_tau = 1 # Soften the target model update 19 | n_step_update = 4 20 | 21 | layers = [Dense(256, activation="relu"), Dropout(0.2), 22 | Dense(256, activation="relu"), Dropout(0.2), 23 | Dense(2, activation=None)] # No activation, pure Q-values 24 | 25 | learning_rate = 0.001 # Learning rate 26 | gamma = 0.0 # Discount factor 27 | min_epsilon = 0.5 # Minimal and final chance of choosing random action 28 | decay_episodes = episodes // 10 # Number of episodes to decay from 1.0 to `min_epsilon` 29 | 30 | min_class = [1] # Minority classes 31 | maj_class = [0] # Majority classes 32 | X_train, y_train, X_test, y_test = load_csv("./data/aki0.csv", "./data/aki1.csv", "aki", ["hadm_id"], normalization=True) 33 | X_train, y_train, X_test, y_test, X_val, y_val = get_train_test_val(X_train, y_train, X_test, y_test, 34 | min_class, maj_class, val_frac=0.2) 35 | 36 | # X_train = X_train.reshape(X_train.shape[:-1] + (18, 7)) # RNN reshaping 37 | # X_test = X_test.reshape(X_test.shape[:-1] + (18, 7)) 38 | # X_val = X_val.reshape(X_val.shape[:-1] + (18, 7)) 39 | 40 | model = TrainDDQN(episodes, warmup_steps, learning_rate, gamma, min_epsilon, decay_episodes, target_update_period=target_update_period, 41 | target_update_tau=target_update_tau, batch_size=batch_size, collect_steps_per_episode=collect_steps_per_episode, 42 | memory_length=memory_length, collect_every=collect_every, n_step_update=n_step_update) 43 | 44 | model.compile_model(X_train, y_train, layers) 45 | model.train(X_val, y_val, "F1") 46 | 47 | stats = model.evaluate(X_test, y_test, X_train, y_train) 48 | print(rounded_dict(stats)) 49 | # {'Gmean': 0.736675, 'F1': 0.514218, 'Precision': 0.396468, 'Recall': 0.731461, 'TP': 651, 'TN': 2849, 'FP': 991, 'FN': 239} 50 | -------------------------------------------------------------------------------- /imbDRL/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import numpy as np 5 | import pandas 
as pd 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | def split_csv(fp: str = "./data/creditcard.csv", fp_dest: str = "./data", 10 | name: str = "credit", test_size: float = 0.5, strat_col: str = "Class") -> None: 11 | """Splits a csv file in two, in a stratified fashion. 12 | Format for filenames will be `{name}0.csv` and `{name}1.csv`. 13 | 14 | :param fp: The path at which the csv file is located. 15 | :type fp: str 16 | :param fp_dest: The path to save the train and test files. 17 | :type fp_dest: str 18 | :param name: The prefix for the files. 19 | :type name: str 20 | :param test_size: The fraction of total size for the test file. 21 | :type test_size: float 22 | :param strat_col: The column in the original csv file to stratify. 23 | 24 | :return: None, two files located at `fp_dest`. 25 | :rtype: NoneType 26 | """ 27 | if not os.path.isfile(fp): 28 | raise FileNotFoundError(f"File at {fp} does not exist.") 29 | if not os.path.isdir(fp_dest): 30 | raise ValueError(f"Directory at {fp_dest} does not exist.") 31 | if not 0 < test_size < 1: 32 | raise ValueError(f"{test_size} is not in interval 0 < x < 1.") 33 | 34 | df = pd.read_csv(fp) 35 | 36 | if not (strat_col in df.columns): 37 | raise ValueError(f"Stratify column {strat_col} not found in DataFrame.") 38 | 39 | train, test = train_test_split(df, test_size=test_size, stratify=df[strat_col]) 40 | 41 | train.to_csv(f"{fp_dest}/{name}0.csv", index=False) 42 | test.to_csv(f"{fp_dest}/{name}1.csv", index=False) 43 | 44 | 45 | def rounded_dict(d: dict, precision: int = 6) -> dict: 46 | """Rounds all values in a dictionary to `precision` digits after the decimal point. 47 | 48 | :param d: Dictionary containing only floats or ints as values 49 | :type d: dict 50 | 51 | :return: Rounded dictionary 52 | :rtype: dict 53 | """ 54 | return {k: round(v, precision) for k, v in d.items()} 55 | 56 | 57 | def imbalance_ratio(y: np.ndarray, min_classes: List[int] = [1], maj_classes: List[int] = [0]) -> float: 58 | """Calculates imbalance ratio of minority class(es) and majority class(es). 59 | 60 | :param y: y-vector with labels.
61 | :type y: np.ndarray 62 | :param min_classes: The labels of the minority classes 63 | :type min_classes: list 64 | :param maj_classes: The labels of the majority classes 65 | :type maj_classes: list 66 | 67 | :return: The imbalance ratio 68 | :rtype: float 69 | """ 70 | return np.isin(y, min_classes).sum() / np.isin(y, maj_classes).sum() 71 | -------------------------------------------------------------------------------- /imbDRL/examples/ddqn/train_titanic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import tensorflow_datasets as tfds 5 | from imbDRL.agents.ddqn import TrainDDQN 6 | from imbDRL.data import get_train_test_val 7 | from imbDRL.utils import rounded_dict 8 | from sklearn.model_selection import train_test_split 9 | from tensorflow.keras.layers import Dense, Dropout 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # CPU is faster than GPU on structured data 12 | 13 | episodes = 16_000 # Total number of episodes 14 | warmup_steps = 16_000 # Amount of warmup steps to collect data with random policy 15 | memory_length = warmup_steps # Max length of the Replay Memory 16 | batch_size = 32 17 | collect_steps_per_episode = 500 18 | collect_every = 500 19 | 20 | target_update_period = 400 # Period to overwrite the target Q-network with the default Q-network 21 | target_update_tau = 1 # Soften the target model update 22 | n_step_update = 1 23 | 24 | layers = [Dense(256, activation="relu"), Dropout(0.2), 25 | Dense(256, activation="relu"), Dropout(0.2), 26 | Dense(2, activation=None)] # No activation, pure Q-values 27 | 28 | learning_rate = 0.00025 # Learning rate 29 | gamma = 0.0 # Discount factor 30 | min_epsilon = 0.5 # Minimal and final chance of choosing random action 31 | decay_episodes = episodes // 10 # Number of episodes to decay from 1.0 to `min_epsilon` 32 | 33 | min_class = [1] # Minority classes 34 | maj_class = [0] # Majority classes 35 | 36 | df = tfds.as_dataframe(*tfds.load("titanic", split='train', with_info=True)) 37 | y = df.survived.values 38 | df = df.drop(columns=["survived", "features/boat", "features/cabin", "features/home.dest", "features/name", "features/ticket"]) 39 | df = df.astype(np.float64) 40 | df = (df - df.min()) / (df.max() - df.min()) # Note: ideally, normalization should happen after splitting into train and test sets 41 | 42 | X_train, X_test, y_train, y_test = train_test_split(df.to_numpy(), y, stratify=y, test_size=0.2) 43 | X_train, y_train, X_test, y_test, X_val, y_val = get_train_test_val(X_train, y_train, X_test, y_test, 44 | min_class, maj_class, val_frac=0.2) 45 | 46 | model = TrainDDQN(episodes, warmup_steps, learning_rate, gamma, min_epsilon, decay_episodes, target_update_period=target_update_period, 47 | target_update_tau=target_update_tau, batch_size=batch_size, collect_steps_per_episode=collect_steps_per_episode, 48 | memory_length=memory_length, collect_every=collect_every, n_step_update=n_step_update) 49 | 50 | model.compile_model(X_train, y_train, layers) 51 | model.q_net.summary() 52 | model.train(X_val, y_val, "F1") 53 | 54 | stats = model.evaluate(X_test, y_test, X_train, y_train) 55 | print(rounded_dict(stats)) 56 | # {'Gmean': 0.824172, 'F1': 0.781395, 'Precision': 0.730435, 'Recall': 0.84, 'TP': 84, 'TN': 131, 'FP': 31, 'FN': 16} 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # imbDRL 2 | 3 | ![GitHub Workflow
Status](https://img.shields.io/github/workflow/status/Denbergvanthijs/imbDRL/Build) ![License](https://img.shields.io/github/license/Denbergvanthijs/imbDRL) 4 | 5 | ***Imbalanced Classification with Deep Reinforcement Learning.*** 6 | 7 | This repository contains a (Double) Deep Q-Network implementation of binary classification on imbalanced datasets using [TensorFlow 2.3+](https://www.tensorflow.org/) and [TF Agents 0.6+](https://www.tensorflow.org/agents). The Double DQN, as published in [this paper](https://arxiv.org/abs/1509.06461) by *van Hasselt et al. (2015)*, uses a custom environment based on [this paper](https://arxiv.org/abs/1901.01379) by *Lin, Chen & Qi (2019)*. 8 | 9 | Example scripts for the [Mnist](http://yann.lecun.com/exdb/mnist/), [Fashion Mnist](https://github.com/zalandoresearch/fashion-mnist), [Credit Card Fraud](https://www.kaggle.com/mlg-ulb/creditcardfraud) and [Titanic](https://www.tensorflow.org/datasets/catalog/titanic) datasets can be found in the `./imbDRL/examples/ddqn/` folder. 10 | 11 | ## Results 12 | 13 | The following results are collected with the scripts in the appendix: [imbDRLAppendix](https://github.com/Denbergvanthijs/imbDRLAppendix). Experiments were conducted on the latest release of *imbDRL* and are based on [this paper](https://arxiv.org/abs/1901.01379) by *Lin, Chen & Qi (2019)*. 14 | 15 | ![Results](./images/results.png) 16 | 17 | ## Requirements 18 | 19 | * [Python 3.7+](https://www.python.org/) 20 | * The required packages as listed in: `requirements.txt` 21 | * Logs are by default saved in `./logs/` 22 | * Trained models are by default saved in `./models/` 23 | * Optional: `./data/` folder located at the root of this repository. 24 | * This folder must contain ```creditcard.csv``` downloaded from [Kaggle](https://www.kaggle.com/mlg-ulb/creditcardfraud) if you would like to use the [Credit Card Fraud](https://www.kaggle.com/mlg-ulb/creditcardfraud) dataset. 25 | * Note: `creditcard.csv` needs to be split into separate train and test files. Please use the function `imbDRL.utils.split_csv`; a usage sketch is included at the end of this README. 26 | 27 | ## Getting started 28 | 29 | Install via `pip`: 30 | 31 | * `pip install imbDRL` 32 | 33 | Run any of the following scripts: 34 | 35 | * `python .\imbDRL\examples\ddqn\train_credit.py` 36 | * `python .\imbDRL\examples\ddqn\train_famnist.py` 37 | * `python .\imbDRL\examples\ddqn\train_mnist.py` 38 | * `python .\imbDRL\examples\ddqn\train_titanic.py` 39 | 40 | ## TensorBoard 41 | 42 | To enable [TensorBoard](https://www.tensorflow.org/tensorboard), run ```tensorboard --logdir logs``` 43 | 44 | ## Tests and linting 45 | 46 | Extra arguments are handled with the `./tox.ini` file. 47 | 48 | * Pytest: `python -m pytest` 49 | * Flake8: `flake8` 50 | * Coverage can be found in the generated `./htmlcov` folder 51 | 52 | ## Appendix 53 | 54 | The appendix can be found in the [imbDRLAppendix](https://github.com/Denbergvanthijs/imbDRLAppendix) repository.
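## Example usage

The snippets below form a minimal sketch of the full workflow for the Credit Card Fraud dataset, condensed from the bundled `imbDRL/examples/ddqn/train_credit.py` example; the hyperparameters are illustrative rather than tuned. First, split the downloaded `creditcard.csv` once into stratified train and test files with `imbDRL.utils.split_csv`:

```python
from imbDRL.utils import split_csv

# Creates ./data/credit0.csv (train) and ./data/credit1.csv (test), stratified on the `Class` column
split_csv(fp="./data/creditcard.csv", fp_dest="./data", name="credit", test_size=0.5, strat_col="Class")
```

Then load the data, compile a Q-network and train the DDQN agent. The network architecture and arguments below mirror `train_credit.py`; see that script for the full set of keyword arguments accepted by `TrainDDQN`:

```python
from imbDRL.agents.ddqn import TrainDDQN
from imbDRL.data import get_train_test_val, load_csv
from imbDRL.utils import rounded_dict
from tensorflow.keras.layers import Dense, Dropout

min_class = [1]  # Minority class: fraudulent transactions
maj_class = [0]  # Majority class: regular transactions

# Drop the `Time` column and apply min-max normalization
X_train, y_train, X_test, y_test = load_csv("./data/credit0.csv", "./data/credit1.csv",
                                            "Class", ["Time"], normalization=True)
X_train, y_train, X_test, y_test, X_val, y_val = get_train_test_val(X_train, y_train, X_test, y_test,
                                                                    min_class, maj_class, val_frac=0.2)

layers = [Dense(256, activation="relu"), Dropout(0.2),
          Dense(256, activation="relu"), Dropout(0.2),
          Dense(2, activation=None)]  # No final activation, pure Q-values

# Positional arguments: episodes, warmup_steps, learning_rate, gamma, min_epsilon, decay_episodes
model = TrainDDQN(100_000, 170_000, 0.00025, 0.0, 0.5, 10_000)
model.compile_model(X_train, y_train, layers)
model.train(X_val, y_val, "F1")  # Keep the model with the best validation F1

stats = model.evaluate(X_test, y_test, X_train, y_train)
print(rounded_dict(stats))
```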
55 | -------------------------------------------------------------------------------- /tests/test_ddqn.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import numpy as np 4 | import pytest 5 | from imbDRL.agents.ddqn import TrainDDQN 6 | from tensorflow.keras.layers import Dense 7 | 8 | 9 | def test_TrainDDQN(tmp_path): 10 | """Tests imbDRL.agents.ddqn.TrainDDQN.""" 11 | tmp_models = str(tmp_path / "test_model") # No support for pathLib https://github.com/tensorflow/tensorflow/issues/37357 12 | tmp_logs = str(tmp_path / "test_log") 13 | 14 | model = TrainDDQN(10, 10, 0.001, 0.0, 0.1, 5, model_path=tmp_models, log_dir=tmp_logs) 15 | assert model.model_path == tmp_models 16 | assert not model.compiled 17 | 18 | NOW = datetime.now().strftime("%Y%m%d") # yyyymmdd 19 | model = TrainDDQN(10, 10, 0.001, 0.0, 0.1, 5) 20 | assert "./models/" + NOW in model.model_path # yyyymmdd in yyyymmdd_hhmmss 21 | 22 | 23 | def test_compile_model(tmp_path): 24 | """Tests imbDRL.agents.ddqn.TrainDDQN.compile_model.""" 25 | tmp_models = str(tmp_path / "test_model") # No support for pathLib https://github.com/tensorflow/tensorflow/issues/37357 26 | tmp_logs = str(tmp_path / "test_log") 27 | 28 | model = TrainDDQN(10, 10, 0.001, 0.0, 0.1, 5, model_path=tmp_models, log_dir=tmp_logs) 29 | assert not model.compiled 30 | model.compile_model(np.random.rand(4, 12).astype(np.float32), np.random.choice(2, size=4).astype(np.int32), [Dense(4), Dense(2)]) 31 | assert model.compiled 32 | 33 | 34 | def test_train(tmp_path): 35 | """Tests imbDRL.agents.ddqn.TrainDDQN.train.""" 36 | tmp_models = str(tmp_path / "test_model") # No support for pathLib https://github.com/tensorflow/tensorflow/issues/37357 37 | tmp_logs = str(tmp_path / "test_log") 38 | 39 | model = TrainDDQN(10, 10, 0.001, 0.0, 0.1, 5, model_path=tmp_models, log_dir=tmp_logs, val_every=2, memory_length=30) 40 | 41 | with pytest.raises(Exception) as exc: 42 | model.train() 43 | assert "must be compiled" in str(exc.value) 44 | 45 | model.compile_model(np.random.rand(4, 10).astype(np.float32), np.random.choice(2, size=4).astype(np.int32), [Dense(4), Dense(2)]) 46 | model.train(np.random.rand(4, 10).astype(np.float32), np.random.choice(2, size=4).astype(np.int32)) 47 | assert model.replay_buffer.num_frames().numpy() >= 10 + 10 # 10 for warmup + 1 for each episode 48 | assert model.global_episode == 10 49 | 50 | model = TrainDDQN(10, 10, 0.001, 0.0, 0.1, 5, model_path=tmp_models, log_dir=tmp_logs, val_every=2) 51 | 52 | with pytest.raises(Exception) as exc: 53 | model.train() 54 | assert "must be compiled" in str(exc.value) 55 | 56 | model.compile_model(np.random.rand(4, 10).astype(np.float32), np.random.choice(2, size=4).astype(np.int32), [Dense(4), Dense(2)]) 57 | model.train(np.random.rand(4, 10).astype(np.float32), np.random.choice(2, size=4).astype(np.int32)) 58 | assert model.replay_buffer.num_frames() >= 10 # 10 in total since no memory length is defined 59 | assert model.global_episode == 10 60 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import imbDRL.metrics as metrics 2 | import numpy as np 3 | import pytest 4 | import tensorflow as tf 5 | from imbDRL.utils import rounded_dict 6 | 7 | 8 | def test_network_predictions(): 9 | """Tests imbDRL.metrics.network_predictions.""" 10 | X = [7, 7, 7, 8, 8, 8] 11 | 12 | with pytest.raises(ValueError) 
as exc: 13 | metrics.network_predictions([], X) 14 | assert "`X` must be of type" in str(exc.value) 15 | 16 | X = np.array([[1, 2], [2, 1], [3, 4], [4, 3]]) 17 | y_pred = metrics.network_predictions(lambda x, step_type, training: (tf.convert_to_tensor(x), None), X) 18 | assert np.array_equal(y_pred, [1, 0, 1, 0]) 19 | 20 | 21 | def test_decision_function(): 22 | """Tests imbDRL.metrics.decision_function.""" 23 | X = [7, 7, 7, 8, 8, 8] 24 | 25 | with pytest.raises(ValueError) as exc: 26 | metrics.decision_function([], X) 27 | assert "`X` must be of type" in str(exc.value) 28 | 29 | X = np.array([[1, 2], [2, 1], [3, 4], [4, 3], [-1, 0], [-1, -10]]) 30 | y_pred = metrics.decision_function(lambda x, step_type, training: (tf.convert_to_tensor(x), None), X) 31 | assert np.array_equal(y_pred, [2, 2, 4, 4, 0, -1]) 32 | 33 | 34 | def test_classification_metrics(): 35 | """Tests imbDRL.metrics.classification_metrics.""" 36 | y_true = [1, 1, 1, 1, 1, 1] 37 | y_pred = [1, 1, 1, 0, 0, 0] 38 | 39 | with pytest.raises(ValueError) as exc: 40 | metrics.classification_metrics(1, y_pred) 41 | assert "`y_true` must be of type" in str(exc.value) 42 | 43 | with pytest.raises(ValueError) as exc: 44 | metrics.classification_metrics(y_true, -1) 45 | assert "`y_pred` must be of type" in str(exc.value) 46 | 47 | with pytest.raises(ValueError) as exc: 48 | metrics.classification_metrics(y_true, y_pred + [1]) 49 | assert "must be of same length" in str(exc.value) 50 | 51 | stats = metrics.classification_metrics(y_true, y_pred) 52 | approx = rounded_dict(stats) 53 | assert approx == {"Gmean": 0.0, "F1": 0.666667, "Precision": 1.0, "Recall": 0.5, "TP": 3, "TN": 0, "FP": 0, "FN": 3} 54 | 55 | y_true = [1, 1, 1, 1, 1, 1] 56 | y_pred = [0, 0, 0, 0, 0, 0] 57 | stats = metrics.classification_metrics(y_true, y_pred) 58 | approx = rounded_dict(stats) 59 | assert approx == {"Gmean": 0.0, "F1": 0.0, "Precision": 0.0, "Recall": 0.0, "TP": 0, "TN": 0, "FP": 0, "FN": 6} 60 | 61 | y_true = [0, 0, 0, 0, 0, 0] 62 | y_pred = [1, 1, 1, 1, 1, 1] 63 | stats = metrics.classification_metrics(y_true, y_pred) 64 | approx = rounded_dict(stats) 65 | assert approx == {"Gmean": 0.0, "F1": 0.0, "Precision": 0.0, "Recall": 0.0, "TP": 0, "TN": 0, "FP": 6, "FN": 0} 66 | 67 | 68 | def test_plot_confusion_matrix(): 69 | """Tests imbDRL.utils.plot_confusion_matrix.""" 70 | with pytest.raises(ValueError) as exc: 71 | metrics.plot_confusion_matrix(1, 2, 3, "test") 72 | assert "Not all arguments are integers" in str(exc.value) 73 | -------------------------------------------------------------------------------- /imbDRL/environments/classifierenv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tf_agents.environments.py_environment import PyEnvironment 3 | from tf_agents.specs.array_spec import ArraySpec, BoundedArraySpec 4 | from tf_agents.trajectories import time_step as ts 5 | 6 | 7 | class ClassifierEnv(PyEnvironment): 8 | """ 9 | Custom `PyEnvironment` environment for imbalanced classification. 10 | Based on https://www.tensorflow.org/agents/tutorials/2_environments_tutorial 11 | """ 12 | 13 | def __init__(self, X_train: np.ndarray, y_train: np.ndarray, imb_ratio: float): 14 | """Initialization of environment with X_train and y_train. 
15 | 16 | :param X_train: Features shaped: [samples, ..., ] 17 | :type X_train: np.ndarray 18 | :param y_train: Labels shaped: [samples] 19 | :type y_train: np.ndarray 20 | :param imb_ratio: Imbalance ratio of the data 21 | :type imb_ratio: float 22 | 23 | :returns: None 24 | :rtype: NoneType 25 | 26 | """ 27 | self._action_spec = BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=1, name="action") 28 | self._observation_spec = ArraySpec(shape=X_train.shape[1:], dtype=X_train.dtype, name="observation") 29 | self._episode_ended = False 30 | 31 | self.X_train = X_train 32 | self.y_train = y_train 33 | self.imb_ratio = imb_ratio # Imbalance ratio: 0 < imb_ratio < 1 34 | self.id = np.arange(self.X_train.shape[0]) # List of IDs to connect X and y data 35 | 36 | self.episode_step = 0 # Episode step, resets every episode 37 | self._state = self.X_train[self.id[self.episode_step]] 38 | 39 | def action_spec(self): 40 | """ 41 | Definition of the discrete action space. 42 | 1 for the positive/minority class, 0 for the negative/majority class. 43 | """ 44 | return self._action_spec 45 | 46 | def observation_spec(self): 47 | """Definition of the continuous state space, e.g. the observations in typical RL environments.""" 48 | return self._observation_spec 49 | 50 | def _reset(self): 51 | """Shuffles the data and returns the first state of the shuffled data to begin training on a new episode.""" 52 | np.random.shuffle(self.id) # Shuffle the X and y data 53 | self.episode_step = 0 # Reset episode step counter at the end of every episode 54 | self._state = self.X_train[self.id[self.episode_step]] 55 | self._episode_ended = False # Reset terminal condition 56 | 57 | return ts.restart(self._state) 58 | 59 | def _step(self, action: int): 60 | """ 61 | Take one step in the environment. 62 | If the action is correct, the environment will either return 1 or `imb_ratio` depending on the current class. 63 | If the action is incorrect, the environment will either return -1 or -`imb_ratio` depending on the current class. 64 | """ 65 | if self._episode_ended: 66 | # The last action ended the episode.
Ignore the current action and start a new episode 67 | return self.reset() 68 | 69 | env_action = self.y_train[self.id[self.episode_step]] # The label of the current state 70 | self.episode_step += 1 71 | 72 | if action == env_action: # Correct action 73 | if env_action: # Minority 74 | reward = 1 # True Positive 75 | else: # Majority 76 | reward = self.imb_ratio # True Negative 77 | 78 | else: # Incorrect action 79 | if env_action: # Minority 80 | reward = -1 # False Negative 81 | self._episode_ended = True # Stop episode when minority class is misclassified 82 | else: # Majority 83 | reward = -self.imb_ratio # False Positive 84 | 85 | if self.episode_step == self.X_train.shape[0] - 1: # If last step in data 86 | self._episode_ended = True 87 | 88 | self._state = self.X_train[self.id[self.episode_step]] # Update state with new datapoint 89 | 90 | if self._episode_ended: 91 | return ts.termination(self._state, reward) 92 | else: 93 | return ts.transition(self._state, reward) 94 | -------------------------------------------------------------------------------- /imbDRL/metrics.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import seaborn as sns 4 | from sklearn.metrics import (auc, average_precision_score, confusion_matrix, 5 | f1_score, precision_recall_curve, roc_curve) 6 | from tensorflow import constant 7 | from tf_agents.trajectories import time_step 8 | 9 | 10 | def network_predictions(network, X: np.ndarray) -> np.ndarray: 11 | """Computes y_pred using a given network. 12 | Input is an array of data entries. 13 | 14 | :param network: The network to use to calculate metrics 15 | :type network: (Q)Network 16 | :param X: X data, input to network 17 | :type X: np.ndarray 18 | 19 | :return: Numpy array of predicted targets for given X 20 | :rtype: np.ndarray 21 | """ 22 | if not isinstance(X, np.ndarray): 23 | raise ValueError(f"`X` must be of type `np.ndarray` not {type(X)}") 24 | 25 | q, _ = network(X, step_type=constant([time_step.StepType.FIRST] * X.shape[0]), training=False) 26 | return np.argmax(q.numpy(), axis=1) # Max action for each x in X 27 | 28 | 29 | def decision_function(network, X: np.ndarray) -> np.ndarray: 30 | """Computes the score for the predicted class of each x in X using a given network. 31 | Input is an array of data entries. 32 | 33 | :param network: The network to use to calculate the score per x in X 34 | :type network: (Q)Network 35 | :param X: X data, input to network 36 | :type X: np.ndarray 37 | 38 | :return: Numpy array of scores for given X 39 | :rtype: np.ndarray 40 | """ 41 | if not isinstance(X, np.ndarray): 42 | raise ValueError(f"`X` must be of type `np.ndarray` not {type(X)}") 43 | 44 | q, _ = network(X, step_type=constant([time_step.StepType.FIRST] * X.shape[0]), training=False) 45 | return np.max(q.numpy(), axis=1) # Value of max action for each x in X 46 | 47 | 48 | def classification_metrics(y_true: list, y_pred: list) -> dict: 49 | """Computes metrics using y_true and y_pred.
50 | 51 | :param y_true: True labels 52 | :type y_true: np.ndarray 53 | :param y_pred: Predicted labels, corresponding to y_true 54 | :type y_pred: np.ndarray 55 | 56 | :return: Dictionary containing Geometric Mean, F1, Precision, Recall, TP, TN, FP, FN 57 | :rtype: dict 58 | """ 59 | if not isinstance(y_true, (list, tuple, np.ndarray)): 60 | raise ValueError(f"`y_true` must be of type `list` not {type(y_true)}") 61 | if not isinstance(y_pred, (list, tuple, np.ndarray)): 62 | raise ValueError(f"`y_pred` must be of type `list` not {type(y_pred)}") 63 | if len(y_true) != len(y_pred): 64 | raise ValueError("`y_true` and `y_pred` must be of same length.") 65 | 66 | # labels=[0, 1] to ensure 4 elements are returned: https://stackoverflow.com/a/46230267 67 | TN, FP, FN, TP = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel() 68 | 69 | precision = TP / (TP + FP) if TP + FP else 0 # Positive predictive value 70 | recall = TP / (TP + FN) if TP + FN else 0 # Sensitivity, True Positive Rate (TPR) 71 | specificity = TN / (TN + FP) if TN + FP else 0 # Specificity, selectivity, True Negative Rate (TNR) 72 | 73 | G_mean = np.sqrt(recall * specificity) # Geometric mean of recall and specificity 74 | F1 = f1_score(y_true, y_pred, zero_division=0) # Default F-measure 75 | 76 | return {"Gmean": G_mean, "F1": F1, "Precision": precision, "Recall": recall, "TP": TP, "TN": TN, "FP": FP, "FN": FN} 77 | 78 | 79 | def plot_pr_curve(network, X_test: np.ndarray, y_test: np.ndarray, 80 | X_train: np.ndarray = None, y_train: np.ndarray = None) -> None: # pragma: no cover 81 | """Plots PR curve of X_test and y_test of given network. 82 | Optionally plots PR curve of X_train and y_train. 83 | Average precision is shown in the legend. 84 | 85 | :param network: The network to use to calculate the PR curve 86 | :type network: (Q)Network 87 | :param X_test: X data, input to network 88 | :type X_test: np.ndarray 89 | :param y_test: True labels for `X_test` 90 | :type y_test: np.ndarray 91 | :param X_train: Optional X data to plot validation PR curve 92 | :type X_train: np.ndarray 93 | :param y_train: True labels for `X_train` 94 | :type y_train: np.ndarray 95 | 96 | :return: None 97 | :rtype: NoneType 98 | """ 99 | plt.plot((0, 1), (1, 0), color="black", linestyle="--", label="Baseline") 100 | # TODO: Consider changing baseline 101 | 102 | if X_train is not None and y_train is not None: 103 | y_val_score = decision_function(network, X_train) 104 | val_precision, val_recall, _ = precision_recall_curve(y_train, y_val_score) 105 | val_AP = average_precision_score(y_train, y_val_score) 106 | plt.plot(val_recall, val_precision, label=f"Train AP: {val_AP:.3f}") 107 | 108 | y_test_score = decision_function(network, X_test) 109 | test_precision, test_recall, _ = precision_recall_curve(y_test, y_test_score) 110 | test_AP = average_precision_score(y_test, y_test_score) 111 | 112 | plt.plot(test_recall, test_precision, label=f"Test AP: {test_AP:.3f}") 113 | plt.xlim((-0.05, 1.05)) 114 | plt.ylim((-0.05, 1.05)) 115 | plt.xlabel("Recall") 116 | plt.ylabel("Precision") 117 | plt.title("PR Curve") 118 | plt.gca().set_aspect("equal", adjustable="box") 119 | plt.legend(loc="lower left") 120 | plt.grid(True) 121 | plt.show() 122 | 123 | 124 | def plot_roc_curve(network, X_test: np.ndarray, y_test: np.ndarray, 125 | X_train: np.ndarray = None, y_train: np.ndarray = None) -> None: # pragma: no cover 126 | """Plots ROC curve of X_test and y_test of given network. 127 | Optionally plots ROC curve of X_train and y_train.
128 | The AUROC is shown in the legend. 129 | 130 | :param network: The network to use to calculate the ROC curve 131 | :type network: (Q)Network 132 | :param X_test: X data, input to network 133 | :type X_test: np.ndarray 134 | :param y_test: True labels for `X_test` 135 | :type y_test: np.ndarray 136 | :param X_train: Optional X data to plot validation ROC curve 137 | :type X_train: np.ndarray 138 | :param y_train: True labels for `X_train` 139 | :type y_train: np.ndarray 140 | 141 | :return: None 142 | :rtype: NoneType 143 | """ 144 | plt.plot((0, 1), (0, 1), color="black", linestyle="--", label="Baseline") 145 | # TODO: Consider changing baseline 146 | 147 | if X_train is not None and y_train is not None: 148 | y_train_score = decision_function(network, X_train) 149 | fpr_train, tpr_train, _ = roc_curve(y_train, y_train_score) 150 | plt.plot(fpr_train, tpr_train, label=f"Train AUROC: {auc(fpr_train, tpr_train):.2f}") 151 | 152 | y_test_score = decision_function(network, X_test) 153 | fpr_test, tpr_test, _ = roc_curve(y_test, y_test_score) 154 | 155 | plt.plot(fpr_test, tpr_test, label=f"Test AUROC: {auc(fpr_test, tpr_test):.2f}") 156 | plt.xlim((-0.05, 1.05)) 157 | plt.ylim((-0.05, 1.05)) 158 | plt.xlabel("False Positive Rate") 159 | plt.ylabel("True Positive Rate") 160 | plt.title("ROC Curve") 161 | plt.gca().set_aspect("equal", adjustable="box") 162 | plt.legend(loc="lower right") 163 | plt.grid(True) 164 | plt.show() 165 | 166 | 167 | def plot_confusion_matrix(TP: int, FN: int, FP: int, TN: int) -> None: # pragma: no cover 168 | """Plots the confusion matrix of the given TP, FN, FP, TN. 169 | 170 | :param TP: True Positive 171 | :type TP: int 172 | :param FN: False Negative 173 | :type FN: int 174 | :param FP: False Positive 175 | :type FP: int 176 | :param TN: True Negative 177 | :type TN: int 178 | 179 | :return: None 180 | :rtype: NoneType 181 | """ 182 | if not all(isinstance(i, (int, np.integer)) for i in (TP, FN, FP, TN)): 183 | raise ValueError("Not all arguments are integers.") 184 | 185 | ticklabels = ("Minority", "Majority") 186 | sns.heatmap(((TP, FN), (FP, TN)), annot=True, fmt="_d", cmap="viridis", xticklabels=ticklabels, yticklabels=ticklabels) 187 | 188 | plt.title("Confusion matrix") 189 | plt.xlabel("Predicted labels") 190 | plt.ylabel("True labels") 191 | plt.show() 192 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | import imbDRL.data as data 2 | import numpy as np 3 | import pytest 4 | 5 | 6 | def test_load_image(): 7 | """Tests imbDRL.data.load_image.""" 8 | # Empty `data_source` 9 | with pytest.raises(ValueError) as exc: 10 | data.load_image("") 11 | assert "No valid" in str(exc.value) 12 | 13 | # Integer `data_source` 14 | with pytest.raises(ValueError) as exc: 15 | data.load_image(1234) 16 | assert "No valid" in str(exc.value) 17 | 18 | # Non-existing `data_source` 19 | with pytest.raises(ValueError) as exc: 20 | data.load_image("credit") 21 | assert "No valid" in str(exc.value) 22 | 23 | image_data = data.load_image("mnist") 24 | assert [x.shape for x in image_data] == [(60000, 28, 28, 1), (60000, ), (10000, 28, 28, 1), (10000, )] 25 | assert [x.dtype for x in image_data] == ["float32", "int32", "float32", "int32"] 26 | 27 | image_data = data.load_image("famnist") 28 | assert [x.shape for x in image_data] == [(60000, 28, 28, 1), (60000, ), (10000, 28, 28, 1), (10000, )] 29 | assert [x.dtype for x in image_data] ==
["float32", "int32", "float32", "int32"] 30 | 31 | image_data = data.load_image("cifar10") 32 | assert [x.shape for x in image_data] == [(50000, 32, 32, 3), (50000, ), (10000, 32, 32, 3), (10000, )] 33 | assert [x.dtype for x in image_data] == ["float32", "int32", "float32", "int32"] 34 | 35 | 36 | def test_load_csv(tmp_path): 37 | """Tests imbDRL.data.load_csv.""" 38 | cols = "Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15,V16,V17,V18,V19,V20,V21,V22,V23,V24,V25,V26,V27,V28,Amount,Class\n" 39 | row1 = str(list(range(0, 31))).strip("[]") + "\n" 40 | row2 = str(list(range(31, 62))).strip("[]") + "\n" 41 | row3 = str(list(range(62, 93))).strip("[]") + "\n" 42 | 43 | with pytest.raises(FileNotFoundError) as exc: 44 | data.load_csv(tmp_path / "thisfiledoesnotexist.csv", tmp_path / "thisfiledoesnotexist.csv", "Class", ["Time"]) 45 | assert "fp_train" in str(exc.value) 46 | 47 | data_file = tmp_path / "data_file.csv" 48 | with open(data_file, "w") as f: 49 | f.writelines([cols, row1, row2, row3]) 50 | 51 | with pytest.raises(FileNotFoundError) as exc: 52 | data.load_csv(data_file, tmp_path / "thisfiledoesnotexist.csv", "Class", ["Time"]) 53 | assert "fp_test" in str(exc.value) 54 | 55 | with pytest.raises(TypeError) as exc: 56 | data.load_csv(data_file, data_file, "Class", ["Time"], normalization=1234) 57 | assert "must be of type `bool`" in str(exc.value) 58 | 59 | credit_data = data.load_csv(data_file, data_file, "Class", ["Time"], normalization=False) 60 | assert [x.shape for x in credit_data] == [(3, 29), (3, ), (3, 29), (3, )] 61 | assert [x.dtype for x in credit_data] == ["float32", "int32", "float32", "int32"] 62 | assert np.array_equal(credit_data[0][0], np.arange(1, 30, dtype=np.float32)) # No normalization 63 | assert np.array_equal(credit_data[0][1], np.arange(32, 61, dtype=np.float32)) 64 | assert np.array_equal(credit_data[0][2], np.arange(63, 92, dtype=np.float32)) 65 | 66 | credit_data = data.load_csv(data_file, data_file, "Class", ["Time"], normalization=True) 67 | assert [x.shape for x in credit_data] == [(3, 29), (3, ), (3, 29), (3, )] 68 | assert [x.dtype for x in credit_data] == ["float32", "int32", "float32", "int32"] 69 | assert np.array_equal(credit_data[0][0], np.zeros(29, dtype=np.float32)) # Min value 70 | assert np.array_equal(credit_data[0][1], np.full(29, 0.5, dtype=np.float32)) # Halfway 71 | assert np.array_equal(credit_data[0][2], np.ones(29, dtype=np.float32)) # Max value 72 | 73 | 74 | def test_load_imdb(): 75 | """Tests imbDRL.data.load_imdb.""" 76 | # Integer `config` 77 | with pytest.raises(TypeError) as exc: 78 | data.load_imdb(config=500) 79 | assert "is no valid datatype" in str(exc.value) 80 | 81 | # Wrong tuple length `config` 82 | with pytest.raises(ValueError) as exc: 83 | data.load_imdb(config=(100, 100, 100)) 84 | assert "must be 2" in str(exc.value) 85 | 86 | # Negative `config` 87 | with pytest.raises(ValueError) as exc: 88 | data.load_imdb(config=(-100, 10)) 89 | assert "must be > 0" in str(exc.value) 90 | 91 | # Negative `config` 92 | with pytest.raises(ValueError) as exc: 93 | data.load_imdb(config=(100, -10)) 94 | assert "must be > 0" in str(exc.value) 95 | 96 | imdb_data = data.load_imdb(config=(10, 5)) 97 | assert [x.shape for x in imdb_data] == [(25000, 5), (25000, ), (25000, 5), (25000, )] 98 | assert [x.dtype for x in imdb_data] == ["int32", "int32", "int32", "int32"] 99 | 100 | 101 | def test_get_train_test_val(capsys): 102 | """Tests imbDRL.data.get_train_test_val.""" 103 | X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) 104 | y 
= np.array([1, 0, 0, 0])
105 |
106 |     with pytest.raises(ValueError) as exc:
107 |         data.get_train_test_val(X, y, X, y, [1], [0], imb_ratio=0.2, val_frac=0.0)
108 |     assert "not in interval" in str(exc.value)
109 |
110 |     with pytest.raises(ValueError) as exc:
111 |         data.get_train_test_val(X, y, X, y, [1], [0], imb_ratio=0.2, val_frac=1)
112 |     assert "not in interval" in str(exc.value)
113 |
114 |     with pytest.raises(TypeError) as exc:
115 |         data.get_train_test_val(X, y, X, y, [1], [0], imb_ratio=0.2, print_stats=1234)
116 |     assert "must be of type" in str(exc.value)
117 |
118 |     X_train, y_train, X_test, y_test, X_val, y_val = data.get_train_test_val(X, y, X, y, [1], [0], imb_ratio=0.25, print_stats=False)
119 |     assert X_train.shape == (2, 2)
120 |     assert X_test.shape == (3, 2)
121 |     assert X_val.shape == (1, 2)
122 |     assert y_train.shape == (2, )
123 |     assert y_test.shape == (3, )
124 |     assert y_val.shape == (1, )
125 |
126 |     data.get_train_test_val(X, y, X, y, [1], [0], imb_ratio=0.25, print_stats=True)  # Check if printing
127 |     captured = capsys.readouterr()
128 |     assert captured.out == ("Imbalance ratio `p`:\n"
129 |                             "\ttrain: n=0, p=0.000000\n"
130 |                             "\ttest: n=0, p=0.000000\n"
131 |                             "\tvalidation: n=0, p=0.000000\n")
132 |
133 |     data.get_train_test_val(X, y, X, y, [1], [0], imb_ratio=0.25, print_stats=False)  # Check if not printing
134 |     captured = capsys.readouterr()
135 |     assert captured.out == ""
136 |
137 |
138 | def test_imbalance_data():
139 |     """Tests imbDRL.data.imbalance_data."""
140 |     X = [[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]]
141 |     y = [2, 2, 2, 3, 3, 3]
142 |
143 |     with pytest.raises(TypeError) as exc:
144 |         data.imbalance_data(X, np.array(y), [2], [3], imb_ratio=0.5)
145 |     assert "`X` must be of type" in str(exc.value)
146 |
147 |     with pytest.raises(TypeError) as exc:
148 |         data.imbalance_data(np.array(X), y, [2], [3], imb_ratio=0.5)
149 |     assert "`y` must be of type" in str(exc.value)
150 |
151 |     X = np.array(X)
152 |     y = np.array(y)
153 |
154 |     with pytest.raises(ValueError) as exc:
155 |         data.imbalance_data(X, y, [2], [3], imb_ratio=0.0)  # Lower bound of the open interval
156 |     assert "not in interval" in str(exc.value)
157 |
158 |     with pytest.raises(ValueError) as exc:
159 |         data.imbalance_data(X, y, [2], [3], imb_ratio=1.0)  # Upper bound of the open interval
160 |     assert "not in interval" in str(exc.value)
161 |
162 |     with pytest.raises(TypeError) as exc:
163 |         data.imbalance_data(X, y, 2, [3], imb_ratio=0.0)
164 |     assert "`min_class` must be of type list or tuple" in str(exc.value)
165 |
166 |     with pytest.raises(TypeError) as exc:
167 |         data.imbalance_data(X, y, [2], 3, imb_ratio=0.0)
168 |     assert "`maj_class` must be of type list or tuple" in str(exc.value)
169 |
170 |     X = np.arange(10).reshape(5, 2)
171 |     y = np.arange(6)
172 |
173 |     with pytest.raises(ValueError) as exc:
174 |         data.imbalance_data(X, y, [1], [0], imb_ratio=0.2)
175 |     assert "must contain the same amount of rows" in str(exc.value)
176 |
177 |     X = np.random.rand(100, 2)
178 |     y = np.concatenate([np.ones(50), np.zeros(50)])
179 |     X, y = data.imbalance_data(X, y, [1], [0], imb_ratio=0.2)
180 |     assert [(60, 2), (60, ), 10] == [X.shape, y.shape, y.sum()]  # 50/50 is original imb_ratio, 10/50 (=0.2) is new imb_ratio
181 |
182 |     X = np.random.rand(100, 2)
183 |     y = np.concatenate([np.ones(50), np.zeros(50)])
184 |     X, y = data.imbalance_data(X, y, [1], [0])
185 |     assert [(100, 2), (100, ), 50] == [X.shape, y.shape, y.sum()]  # 50/50 is original imb_ratio, 50/50 (=1) is new imb_ratio
186 |
--------------------------------------------------------------------------------
/imbDRL/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Tuple
3 |
4 | import numpy as np
5 | from pandas import read_csv
6 | from sklearn.model_selection import train_test_split
7 | from sklearn.utils import shuffle
8 | from tensorflow.keras.datasets import cifar10, fashion_mnist, imdb, mnist
9 | from tensorflow.keras.preprocessing.sequence import pad_sequences
10 |
11 | from imbDRL.utils import imbalance_ratio
12 |
13 | TrainTestData = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
14 | TrainTestValData = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
15 |
16 |
17 | def load_image(data_source: str) -> TrainTestData:
18 |     """
19 |     Loads one of the following image datasets: {mnist, famnist, cifar10}.
20 |     Normalizes the data. Returns X and y for both train and test datasets.
21 |     Dtypes of X's and y's will be `float32` and `int32` to be compatible with `tf_agents`.
22 |
23 |     :param data_source: Either mnist, famnist or cifar10
24 |     :type data_source: str
25 |
26 |     :return: Tuple of (X_train, y_train, X_test, y_test) containing original split of train/test
27 |     :rtype: tuple
28 |     """
29 |     reshape_shape = -1, 28, 28, 1
30 |
31 |     if data_source == "mnist":
32 |         (X_train, y_train), (X_test, y_test) = mnist.load_data()
33 |
34 |     elif data_source == "famnist":
35 |         (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
36 |
37 |     elif data_source == "cifar10":
38 |         (X_train, y_train), (X_test, y_test) = cifar10.load_data()
39 |         reshape_shape = -1, 32, 32, 3
40 |
41 |     else:
42 |         raise ValueError("No valid `data_source`.")
43 |
44 |     X_train = X_train.reshape(reshape_shape).astype(np.float32)  # Float32 is the expected dtype for the observation spec in the env
45 |     X_test = X_test.reshape(reshape_shape).astype(np.float32)
46 |
47 |     X_train /= 255  # /= is not available when casting int to float: https://stackoverflow.com/a/48948461/10603874
48 |     X_test /= 255
49 |
50 |     y_train = y_train.reshape(y_train.shape[0], ).astype(np.int32)
51 |     y_test = y_test.reshape(y_test.shape[0], ).astype(np.int32)
52 |
53 |     return X_train, y_train, X_test, y_test
54 |
55 |
56 | def load_csv(fp_train: str, fp_test: str, label_col: str, drop_cols: List[str], normalization: bool = False) -> TrainTestData:
57 |     """
58 |     Loads any csv-file from local filepaths. Returns X and y for both train and test datasets.
59 |     Option to normalize the data with min-max normalization.
60 |     Only csv-files with float32 values for the features and int32 values for the labels are supported.
61 |     Source for dataset: https://mimic-iv.mit.edu/
62 |
63 |     :param fp_train: Location of the train csv-file
64 |     :type fp_train: str
65 |     :param fp_test: Location of the test csv-file
66 |     :type fp_test: str
67 |     :param label_col: The name of the column containing the labels of the data
68 |     :type label_col: str
69 |     :param drop_cols: List of the names of the columns to be dropped. `label_col` gets dropped automatically
70 |     :type drop_cols: List[str]
71 |     :param normalization: Normalize the data with min-max normalization?
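        The minimum and maximum are fitted on the train data only and applied to both the train and test data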
72 |     :type normalization: bool
73 |
74 |     :return: Tuple of (X_train, y_train, X_test, y_test) containing original split of train/test
75 |     :rtype: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
76 |     """
77 |     if not os.path.isfile(fp_train):
78 |         raise FileNotFoundError(f"`fp_train` {fp_train} does not exist.")
79 |     if not os.path.isfile(fp_test):
80 |         raise FileNotFoundError(f"`fp_test` {fp_test} does not exist.")
81 |     if not isinstance(normalization, bool):
82 |         raise TypeError(f"`normalization` must be of type `bool`, not {type(normalization)}")
83 |
84 |     X_train = read_csv(fp_train).astype(np.float32)  # DataFrames directly converted to float32
85 |     X_test = read_csv(fp_test).astype(np.float32)
86 |
87 |     y_train = X_train[label_col].astype(np.int32)
88 |     y_test = X_test[label_col].astype(np.int32)
89 |     X_train.drop(columns=drop_cols + [label_col], inplace=True)  # Dropping cols and label column
90 |     X_test.drop(columns=drop_cols + [label_col], inplace=True)
91 |
92 |     # Other data sources are already normalized. RGB values are always in range 0 to 255.
93 |     if normalization:
94 |         mini, maxi = X_train.min(axis=0), X_train.max(axis=0)
95 |         X_train -= mini
96 |         X_train /= maxi - mini
97 |         X_test -= mini
98 |         X_test /= maxi - mini
99 |
100 |     return X_train.values, y_train.values, X_test.values, y_test.values  # Numpy arrays
101 |
102 |
103 | def load_imdb(config: Tuple[int, int] = (5_000, 500)) -> TrainTestData:
104 |     """Loads the IMDB dataset. Returns X and y for both train and test datasets.
105 |
106 |     :param config: Tuple of the number of most frequent words and the max length of each sequence
107 |     :type config: Tuple[int, int]
108 |
109 |     :return: Tuple of (X_train, y_train, X_test, y_test) containing original split of train/test
110 |     :rtype: tuple
111 |     """
112 |     if not isinstance(config, (tuple, list)):
113 |         raise TypeError(f"{type(config)} is no valid datatype for `config`.")
114 |     if len(config) != 2:
115 |         raise ValueError("Tuple length of `config` must be 2.")
116 |     if not all(i > 0 for i in config):
117 |         raise ValueError("All integers of `config` must be > 0.")
118 |
119 |     (X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=config[0])
120 |
121 |     X_train = pad_sequences(X_train, maxlen=config[1])
122 |     X_test = pad_sequences(X_test, maxlen=config[1])
123 |
124 |     y_train = y_train.astype(np.int32)
125 |     y_test = y_test.astype(np.int32)
126 |
127 |     return X_train, y_train, X_test, y_test
128 |
129 |
130 | def get_train_test_val(X_train: np.ndarray, y_train: np.ndarray, X_test: np.ndarray, y_test: np.ndarray, min_classes: List[int],
131 |                        maj_classes: List[int], imb_ratio: float = None, imb_test: bool = True, val_frac: float = 0.25,
132 |                        print_stats: bool = True) -> TrainTestValData:
133 |     """
134 |     Imbalances data and divides the data into train, test and validation sets.
135 |     The imbalance rate of each individual dataset is approx. the same as the given `imb_ratio`.
136 |
137 |     :param X_train: The X_train data
138 |     :type X_train: np.ndarray
139 |     :param y_train: The y_train data
140 |     :type y_train: np.ndarray
141 |     :param X_test: The X_test data
142 |     :type X_test: np.ndarray
143 |     :param y_test: The y_test data
144 |     :type y_test: np.ndarray
145 |     :param min_classes: List of labels of all minority classes
146 |     :type min_classes: list
147 |     :param maj_classes: List of labels of all majority classes.
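        Rows whose label is in neither `min_classes` nor `maj_classes` are dropped by `imbalance_data()`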
148 | :type maj_classes: list 149 | :param imb_ratio: Imbalance ratio for minority to majority class: len(minority datapoints) / len(majority datapoints) 150 | If the `imb_ratio` is None, data will not be imbalanced and will only be relabeled to 1's and 0's. 151 | :type imb_ratio: float 152 | :param imb_test: Imbalance the test dataset? 153 | :type imb_test: bool 154 | :param val_frac: Fraction to take from X_train and y_train for X_val and y_val 155 | :type val_frac: float 156 | :param print_stats: Print the imbalance ratio of the imbalanced data? 157 | :type print_stats: bool 158 | 159 | :return: Tuple of (X_train, y_train, X_test, y_test, X_val, y_val) 160 | :rtype: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] 161 | """ 162 | if not 0 < val_frac < 1: 163 | raise ValueError(f"{val_frac} is not in interval 0 < x < 1.") 164 | if not isinstance(print_stats, bool): 165 | raise TypeError(f"`print_stats` must be of type `bool`, not {type(print_stats)}.") 166 | 167 | X_train, y_train = imbalance_data(X_train, y_train, min_classes, maj_classes, imb_ratio=imb_ratio) 168 | # Only imbalance test-data if imb_test is True 169 | X_test, y_test = imbalance_data(X_test, y_test, min_classes, maj_classes, imb_ratio=imb_ratio if imb_test else None) 170 | 171 | # stratify=y_train to ensure class balance is kept between train and validation datasets 172 | X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=val_frac, stratify=y_train) 173 | 174 | if print_stats: 175 | p_train, p_test, p_val = [((y == 1).sum(), imbalance_ratio(y)) for y in (y_train, y_test, y_val)] 176 | print(f"Imbalance ratio `p`:\n" 177 | f"\ttrain: n={p_train[0]}, p={p_train[1]:.6f}\n" 178 | f"\ttest: n={p_test[0]}, p={p_test[1]:.6f}\n" 179 | f"\tvalidation: n={p_val[0]}, p={p_val[1]:.6f}") 180 | 181 | return X_train, y_train, X_test, y_test, X_val, y_val 182 | 183 | 184 | def imbalance_data(X: np.ndarray, y: np.ndarray, min_class: List[int], maj_class: List[int], 185 | imb_ratio: float = None) -> Tuple[np.ndarray, np.ndarray]: 186 | """ 187 | Split data in minority and majority, only values in {min_class, maj_class} will be kept. 188 | (Possibly) decrease minority rows to match the imbalance rate. 189 | If initial imb_ratio of dataset is lower than given `imb_ratio`, the imb_ratio of the returned data will not be changed. 190 | If the `imb_ratio` is None, data will not be imbalanced and will only be relabeled to 1's and 0's. 
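    :param X: The features of the data
    :type X: np.ndarray
    :param y: The labels of the data
    :type y: np.ndarray
    :param min_class: List of labels of all minority classes
    :type min_class: list
    :param maj_class: List of labels of all majority classes
    :type maj_class: list
    :param imb_ratio: Imbalance ratio for minority to majority class: len(minority datapoints) / len(majority datapoints)
    :type imb_ratio: float

    :return: Tuple of the imbalanced (X, y), with minority labels mapped to 1 and majority labels mapped to 0
    :rtype: Tuple[np.ndarray, np.ndarray]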
191 | """ 192 | if not isinstance(X, np.ndarray): 193 | raise TypeError(f"`X` must be of type `np.ndarray` not {type(X)}") 194 | if not isinstance(y, np.ndarray): 195 | raise TypeError(f"`y` must be of type `np.ndarray` not {type(y)}") 196 | if X.shape[0] != y.shape[0]: 197 | raise ValueError("`X` and `y` must contain the same amount of rows.") 198 | if not isinstance(min_class, (list, tuple)): 199 | raise TypeError("`min_class` must be of type list or tuple.") 200 | if not isinstance(maj_class, (list, tuple)): 201 | raise TypeError("`maj_class` must be of type list or tuple.") 202 | 203 | if (imb_ratio is not None) and not (0 < imb_ratio < 1): 204 | raise ValueError(f"{imb_ratio} is not in interval 0 < imb_ratio < 1.") 205 | 206 | if imb_ratio is None: # Do not imbalance data if no `imb_ratio` is given 207 | imb_ratio = 1 208 | 209 | X_min = X[np.isin(y, min_class)] # Mask the correct indexes 210 | X_maj = X[np.isin(y, maj_class)] # Only keep data/labels for x in {min_class, maj_class} and forget all other 211 | 212 | min_len = int(X_maj.shape[0] * imb_ratio) # Amount of rows to select from minority classes to get to correct imbalance ratio 213 | # Keep all majority rows, decrease minority rows to match `imb_ratio` 214 | X_min = X_min[np.random.choice(X_min.shape[0], min(min_len, X_min.shape[0]), replace=False), :] 215 | 216 | X_imb = np.concatenate([X_maj, X_min]).astype(np.float32) 217 | y_imb = np.concatenate((np.zeros(X_maj.shape[0]), np.ones(X_min.shape[0]))).astype(np.int32) 218 | X_imb, y_imb = shuffle(X_imb, y_imb) 219 | 220 | return X_imb, y_imb 221 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | https://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2021 Thijs van den Berg 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | http://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 
191 |
--------------------------------------------------------------------------------
/imbDRL/agents/ddqn.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from datetime import datetime
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 | from imbDRL.environments.classifierenv import ClassifierEnv
7 | from imbDRL.metrics import (classification_metrics, decision_function,
8 |                             network_predictions, plot_pr_curve, plot_roc_curve)
9 | from imbDRL.utils import imbalance_ratio
10 | from tensorflow import data
11 | from tensorflow.keras.optimizers import Adam
12 | from tf_agents.agents.dqn.dqn_agent import DdqnAgent
13 | from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver
14 | from tf_agents.environments.tf_py_environment import TFPyEnvironment
15 | from tf_agents.networks.sequential import Sequential
16 | from tf_agents.policies.random_tf_policy import RandomTFPolicy
17 | from tf_agents.replay_buffers.tf_uniform_replay_buffer import \
18 |     TFUniformReplayBuffer
19 | from tf_agents.utils import common
20 | from tqdm import tqdm
21 |
22 |
23 | class TrainDDQN():
24 |     """Wrapper for DDQN training, validation, saving etc."""
25 |
26 |     def __init__(self, episodes: int, warmup_steps: int, learning_rate: float, gamma: float, min_epsilon: float, decay_episodes: int,
27 |                  model_path: str = None, log_dir: str = None, batch_size: int = 64, memory_length: int = None,
28 |                  collect_steps_per_episode: int = 1, val_every: int = None, target_update_period: int = 1, target_update_tau: float = 1.0,
29 |                  progressbar: bool = True, n_step_update: int = 1, gradient_clipping: float = 1.0, collect_every: int = 1) -> None:
30 |         """
31 |         Wrapper to make training easier.
32 |         Code is partly based on https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
33 |
34 |         :param episodes: Number of training episodes
35 |         :type episodes: int
36 |         :param warmup_steps: Number of steps to fill the Replay Buffer with random state-action pairs before training starts
37 |         :type warmup_steps: int
38 |         :param learning_rate: Learning Rate for the Adam Optimizer
39 |         :type learning_rate: float
40 |         :param gamma: Discount factor for the Q-values
41 |         :type gamma: float
42 |         :param min_epsilon: Lowest and final value for epsilon
43 |         :type min_epsilon: float
44 |         :param decay_episodes: Number of episodes to decay from 1 to `min_epsilon`
45 |         :type decay_episodes: int
46 |         :param model_path: Location to save the trained model
47 |         :type model_path: str
48 |         :param log_dir: Location to save the logs, useful for TensorBoard
49 |         :type log_dir: str
50 |         :param batch_size: Number of samples in minibatch to train on each step
51 |         :type batch_size: int
52 |         :param memory_length: Maximum size of the Replay Buffer
53 |         :type memory_length: int
54 |         :param collect_steps_per_episode: Amount of data to collect for the Replay Buffer each episode
55 |         :type collect_steps_per_episode: int
56 |         :param collect_every: Step interval to collect data during training
57 |         :type collect_every: int
58 |         :param val_every: Validate the model every X episodes using the `collect_metrics()` function
59 |         :type val_every: int
60 |         :param target_update_period: Update the target Q-network every X episodes
61 |         :type target_update_period: int
62 |         :param target_update_tau: Parameter for softening the `target_update_period`
63 |         :type target_update_tau: float
64 |         :param progressbar: Enable or disable the progressbar for collecting data and training
65 |         :type progressbar: bool
        :param n_step_update: Number of steps to consider when computing the TD-error and TD-loss (n-step updates)
        :type n_step_update: int
        :param gradient_clipping: Global norm at which gradients are clipped during training
        :type gradient_clipping: float
66 |
67 |         :return: None
68 |         :rtype: NoneType
69 |         """
70 |         self.episodes = episodes  # Total episodes
71 |         self.warmup_steps = warmup_steps  # Amount of warmup steps before training
72 |         self.batch_size = batch_size  # Batch size of Replay Memory
73 |         self.collect_steps_per_episode = collect_steps_per_episode  # Amount of steps to collect data each episode
74 |         self.collect_every = collect_every  # Step interval to collect data during training
75 |         self.learning_rate = learning_rate  # Learning Rate
76 |         self.gamma = gamma  # Discount factor
77 |         self.min_epsilon = min_epsilon  # Minimal chance of choosing random action
78 |         self.decay_episodes = decay_episodes  # Number of episodes to decay from 1.0 to `min_epsilon`
79 |         self.target_update_period = target_update_period  # Period for soft updates
80 |         self.target_update_tau = target_update_tau
81 |         self.progressbar = progressbar  # Enable or disable the progressbar for collecting data and training
82 |         self.n_step_update = n_step_update
83 |         self.gradient_clipping = gradient_clipping  # Clip gradients by this global norm
84 |         self.compiled = False
85 |         NOW = datetime.now().strftime("%Y%m%d_%H%M%S")
86 |
87 |         if memory_length is not None:
88 |             self.memory_length = memory_length  # Max Replay Memory length
89 |         else:
90 |             self.memory_length = warmup_steps
91 |
92 |         if val_every is not None:
93 |             self.val_every = val_every  # Validate the policy every `val_every` episodes
94 |         else:
95 |             self.val_every = self.episodes // min(50, self.episodes)  # Can't validate the model 50 times if self.episodes < 50
96 |
97 |         if model_path is not None:
98 |             self.model_path = model_path
99 |         else:
100 |             self.model_path = "./models/" + NOW + ".pkl"
101 |
102 |         if log_dir is None:
103 |             log_dir = "./logs/" + NOW
104 |         self.writer = tf.summary.create_file_writer(log_dir)
105 |
106 |     def compile_model(self, X_train, y_train, layers: list = [], imb_ratio: float = None, loss_fn=common.element_wise_squared_loss) -> None:
107 |         """Initializes the neural networks, DDQN-agent, collect policies and replay buffer.
108 |
109 |         :param X_train: Training data for the model.
110 |         :type X_train: np.ndarray
111 |         :param y_train: Labels corresponding to `X_train`. 1 for the positive class, 0 for the negative class.
112 |         :type y_train: np.ndarray
113 |         :param layers: List of Keras layers to feed into the tf_agents `Sequential` network (note: not Keras' own `Sequential`).
114 |         :type layers: list
115 |         :param imb_ratio: The imbalance ratio of the data.
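            Computed from `y_train` via `imbalance_ratio()` when None; passed on to the `ClassifierEnv`.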
116 |         :type imb_ratio: float
117 |         :param loss_fn: Callable loss function
118 |         :type loss_fn: Callable
119 |
120 |         :return: None
121 |         :rtype: NoneType
122 |         """
123 |         if imb_ratio is None:
124 |             imb_ratio = imbalance_ratio(y_train)
125 |
126 |         self.train_env = TFPyEnvironment(ClassifierEnv(X_train, y_train, imb_ratio))
127 |         self.global_episode = tf.Variable(0, name="global_episode", dtype=np.int64, trainable=False)  # Global train episode counter
128 |
129 |         # Custom epsilon decay: https://github.com/tensorflow/agents/issues/339
130 |         epsilon_decay = tf.compat.v1.train.polynomial_decay(
131 |             1.0, self.global_episode, self.decay_episodes, end_learning_rate=self.min_epsilon)
132 |
133 |         self.q_net = Sequential(layers, self.train_env.observation_spec())
134 |
135 |         self.agent = DdqnAgent(self.train_env.time_step_spec(),
136 |                                self.train_env.action_spec(),
137 |                                q_network=self.q_net,
138 |                                optimizer=Adam(learning_rate=self.learning_rate),
139 |                                td_errors_loss_fn=loss_fn,
140 |                                train_step_counter=self.global_episode,
141 |                                target_update_period=self.target_update_period,
142 |                                target_update_tau=self.target_update_tau,
143 |                                gamma=self.gamma,
144 |                                epsilon_greedy=epsilon_decay,
145 |                                n_step_update=self.n_step_update,
146 |                                gradient_clipping=self.gradient_clipping)
147 |         self.agent.initialize()
148 |
149 |         self.random_policy = RandomTFPolicy(self.train_env.time_step_spec(), self.train_env.action_spec())
150 |         self.replay_buffer = TFUniformReplayBuffer(data_spec=self.agent.collect_data_spec,
151 |                                                    batch_size=self.train_env.batch_size,
152 |                                                    max_length=self.memory_length)
153 |
154 |         self.warmup_driver = DynamicStepDriver(self.train_env,
155 |                                                self.random_policy,
156 |                                                observers=[self.replay_buffer.add_batch],
157 |                                                num_steps=self.warmup_steps)  # Uses a random policy
158 |
159 |         self.collect_driver = DynamicStepDriver(self.train_env,
160 |                                                 self.agent.collect_policy,
161 |                                                 observers=[self.replay_buffer.add_batch],
162 |                                                 num_steps=self.collect_steps_per_episode)  # Uses the epsilon-greedy policy of the agent
163 |
164 |         self.agent.train = common.function(self.agent.train)  # Optimization: compile to a TF graph
165 |         self.warmup_driver.run = common.function(self.warmup_driver.run)
166 |         self.collect_driver.run = common.function(self.collect_driver.run)
167 |
168 |         self.compiled = True
169 |
170 |     def train(self, *args) -> None:
171 |         """Starts the training of the model. Includes warmup period, metrics collection and model saving.
172 |
173 |         :param *args: All arguments will be passed to `collect_metrics()`.
174 |             This can be useful to pass callables, testing environments or validation data.
175 |             Override the TrainDDQN.collect_metrics() function to use your own *args.
176 |         :type *args: Any
177 |
178 |         :return: None
179 |         :rtype: NoneType, last step is saving the model as a side-effect
180 |         """
181 |         assert self.compiled, "Model must be compiled with model.compile_model(X_train, y_train, layers) before training."
182 |
183 |         # Warmup period, fill memory with random actions
184 |         if self.progressbar:
185 |             print(f"\033[92mCollecting data for {self.warmup_steps:_} steps...
This might take a few minutes...\033[0m")
186 |
187 |         self.warmup_driver.run(time_step=None, policy_state=self.random_policy.get_initial_state(self.train_env.batch_size))
188 |
189 |         if self.progressbar:
190 |             print(f"\033[92m{self.replay_buffer.num_frames():_} frames collected!\033[0m")
191 |
192 |         dataset = self.replay_buffer.as_dataset(sample_batch_size=self.batch_size, num_steps=self.n_step_update + 1,
193 |                                                 num_parallel_calls=data.experimental.AUTOTUNE).prefetch(data.experimental.AUTOTUNE)
194 |         iterator = iter(dataset)
195 |
196 |         def _train():
197 |             experiences, _ = next(iterator)
198 |             return self.agent.train(experiences).loss
199 |         _train = common.function(_train)  # Optimization: compile to a TF graph
200 |
201 |         ts = None
202 |         policy_state = self.agent.collect_policy.get_initial_state(self.train_env.batch_size)
203 |         self.collect_metrics(*args)  # Initial collection for step 0
204 |         pbar = tqdm(total=self.episodes, disable=(not self.progressbar), desc="Training the DDQN")  # TQDM progressbar
205 |         for _ in range(self.episodes):
206 |             if not self.global_episode % self.collect_every:
207 |                 # Collect a few steps using collect_policy and save to `replay_buffer`
208 |                 if self.collect_steps_per_episode != 0:
209 |                     ts, policy_state = self.collect_driver.run(time_step=ts, policy_state=policy_state)
210 |                 pbar.update(self.collect_every)  # More stable TQDM updates, collecting could take some time
211 |
212 |             # Sample a batch of data from `replay_buffer` and update the agent's network
213 |             train_loss = _train()
214 |
215 |             if not self.global_episode % self.val_every:
216 |                 with self.writer.as_default():
217 |                     tf.summary.scalar("train_loss", train_loss, step=self.global_episode)
218 |
219 |                 self.collect_metrics(*args)
220 |         pbar.close()
221 |
222 |     def collect_metrics(self, X_val: np.ndarray, y_val: np.ndarray, save_best: str = None):
223 |         """Collects metrics using the trained Q-network.
224 |
225 |         :param X_val: Features of validation data, same shape as X_train
226 |         :type X_val: np.ndarray
227 |         :param y_val: Labels of validation data, same shape as y_train
228 |         :type y_val: np.ndarray
229 |         :param save_best: Save the best model of all validation runs based on the given metric:
230 |             Choose one of: {Gmean, F1, Precision, Recall, TP, TN, FP, FN}
231 |             This improves stability since the model at the last episode is not guaranteed to be the best model.
232 |         :type save_best: str
233 |         """
234 |         y_pred = network_predictions(self.agent._target_q_network, X_val)
235 |         stats = classification_metrics(y_val, y_pred)
236 |         avgQ = np.mean(decision_function(self.agent._target_q_network, X_val))  # Mean of the highest Q-value for each x in X_val
237 |
238 |         if save_best is not None:
239 |             if not hasattr(self, "best_score"):  # If no best model yet
240 |                 self.best_score = 0.0
241 |
242 |             if stats.get(save_best) >= self.best_score:  # Overwrite best model
243 |                 self.save_network()  # Saving directly to avoid a shallow copy without trained weights
244 |                 self.best_score = stats.get(save_best)
245 |
246 |         with self.writer.as_default():
247 |             tf.summary.scalar("AverageQ", avgQ, step=self.global_episode)  # Average Q-value for this episode
248 |             for k, v in stats.items():
249 |                 tf.summary.scalar(k, v, step=self.global_episode)
250 |
251 |     def evaluate(self, X_test, y_test, X_train=None, y_train=None):
252 |         """
253 |         Final evaluation of the trained Q-network with X_test and y_test.
254 |         Optional PR and ROC curve comparison to X_train, y_train to ensure no overfitting is taking place.
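        Returns the classification metrics of the best (or latest) Q-network on the test data.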
255 |
256 |         :param X_test: Features of test data, same shape as X_train
257 |         :type X_test: np.ndarray
258 |         :param y_test: Labels of test data, same shape as y_train
259 |         :type y_test: np.ndarray
260 |         :param X_train: Features of train data
261 |         :type X_train: np.ndarray
262 |         :param y_train: Labels of train data
263 |         :type y_train: np.ndarray
264 |         """
265 |         if hasattr(self, "best_score"):
266 |             print(f"\033[92mBest score: {self.best_score:.6f}!\033[0m")
267 |             network = self.load_network(self.model_path)  # Load best saved model
268 |         else:
269 |             network = self.agent._target_q_network  # Load latest target model
270 |
271 |         if (X_train is not None) and (y_train is not None):
272 |             plot_pr_curve(network, X_test, y_test, X_train, y_train)
273 |             plot_roc_curve(network, X_test, y_test, X_train, y_train)
274 |
275 |         y_pred = network_predictions(network, X_test)
276 |         return classification_metrics(y_test, y_pred)
277 |
278 |     def save_network(self):
279 |         """Saves the Q-network as a pickle to `model_path`."""
280 |         with open(self.model_path, "wb") as f:  # Save Q-network as pickle
281 |             pickle.dump(self.agent._target_q_network, f)
282 |
283 |     @staticmethod
284 |     def load_network(fp: str):
285 |         """Static method to load a Q-network pickle from the given filepath.
286 |
287 |         :param fp: Filepath to the saved pickle of the network
288 |         :type fp: str
289 |
290 |         :return: The network object loaded from the pickle file
291 |         :rtype: tf_agents.networks.Network
292 |         """
293 |         with open(fp, "rb") as f:  # Load the Q-network
294 |             network = pickle.load(f)
295 |         return network
296 |
--------------------------------------------------------------------------------
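
Usage sketch (not part of the repository): a minimal, hypothetical end-to-end run that wires the modules above together, shown only to make the API concrete. The dataset choice (MNIST with digit 2 as the minority class), the imbalance ratio, all hyperparameters and the small Q-network are illustrative assumptions, not tuned values; the code above only requires that the last layer outputs one Q-value per action of the `ClassifierEnv` (two for binary classification).

from imbDRL.agents.ddqn import TrainDDQN
from imbDRL.data import get_train_test_val, load_image
from tensorflow.keras.layers import Dense, Flatten

min_class, maj_class = [2], [0, 1, 3, 4, 5, 6, 7, 8, 9]  # Digit 2 is the minority class, all other digits are majority
X_train, y_train, X_test, y_test = load_image("mnist")
X_train, y_train, X_test, y_test, X_val, y_val = get_train_test_val(
    X_train, y_train, X_test, y_test, min_class, maj_class, imb_ratio=0.01)

model = TrainDDQN(episodes=25_000, warmup_steps=50_000, learning_rate=0.001,
                  gamma=0.0, min_epsilon=0.5, decay_episodes=25_000)

# `layers` is handed to the tf_agents `Sequential` network; the final Dense layer outputs one Q-value per action
model.compile_model(X_train, y_train, layers=[Flatten(), Dense(256, activation="relu"), Dense(2)])
model.train(X_val, y_val, "F1")  # Extra args are forwarded to collect_metrics(); saves the best model on validation F1
print(model.evaluate(X_test, y_test, X_train, y_train))  # Dict of classification metrics on the test data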