├── .gitignore
├── src
│   ├── data
│   │   ├── failed.npy
│   │   ├── pruned.npy
│   │   ├── states.npy
│   │   ├── steps.npy
│   │   ├── accuracy.npy
│   │   ├── distance.npy
│   │   ├── metrics.npy
│   │   ├── complexity.npy
│   │   ├── active_accuracy.npy
│   │   ├── passive_accuracy.npy
│   │   ├── reversed_states.npy
│   │   └── reversed_complexity.npy
│   ├── core
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── trials.py
│   │   ├── env.py
│   │   ├── config.py
│   │   └── mdp.py
│   ├── plot_scripts
│   │   ├── plot_failed_models.py
│   │   ├── plot_metrics.py
│   │   ├── plot_distance.py
│   │   ├── plot_active_accuracy.py
│   │   ├── plot_accuracy.py
│   │   ├── plot_states.py
│   │   ├── plot_complexity.py
│   │   └── plot_model_reduction.py
│   └── scripts
│       ├── active_accuracy_script.py
│       ├── complexity_script.py
│       ├── states_script.py
│       └── model_reduction_script.py
└── README.md


/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]

--------------------------------------------------------------------------------
/src/data/failed.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/failed.npy

--------------------------------------------------------------------------------
/src/data/pruned.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/pruned.npy

--------------------------------------------------------------------------------
/src/data/states.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/states.npy

--------------------------------------------------------------------------------
/src/data/steps.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/steps.npy
--------------------------------------------------------------------------------
/src/data/accuracy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/accuracy.npy

--------------------------------------------------------------------------------
/src/data/distance.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/distance.npy

--------------------------------------------------------------------------------
/src/data/metrics.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/metrics.npy

--------------------------------------------------------------------------------
/src/data/complexity.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/complexity.npy

--------------------------------------------------------------------------------
/src/data/active_accuracy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/active_accuracy.npy

--------------------------------------------------------------------------------
/src/data/passive_accuracy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/passive_accuracy.npy

--------------------------------------------------------------------------------
/src/data/reversed_states.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/reversed_states.npy
-------------------------------------------------------------------------------- /src/data/reversed_complexity.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alec-tschantz/action-oriented/HEAD/src/data/reversed_complexity.npy -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .trials import * 3 | from .utils import * 4 | from .mdp import * 5 | from .env import * 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning action-oriented models through active inference 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | 5 | Repository for "Learning action-oriented models through active inference" [[Paper](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1007805)] 6 | 7 | 8 | 9 | The core code can be found in `src/core`, whereas scripts for running experiments can be found un `src/scripts` 10 | -------------------------------------------------------------------------------- /src/plot_scripts/plot_failed_models.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import core 4 | from core.config import * 5 | 6 | plt.rcParams["axes.edgecolor"] = "#333F4B" 7 | plt.rcParams["axes.linewidth"] = 0.8 8 | plt.rcParams["xtick.color"] = "#333F4B" 9 | plt.rcParams["ytick.color"] = "#333F4B" 10 | 11 | if __name__ == "__main__": 12 | failed_raw = np.load("data/failed.npy") 13 | failed_sum = np.sum(failed_raw, axis=1) 14 | 15 | colors = core.get_color_palette() 16 | x_ticks = range(N_AGENTS) 17 | x_labels = AGENT_NAMES 18 | f, ax = 
plt.subplots(1, 1, figsize=(5, 4)) 19 | ax.bar(range(N_AGENTS), failed_sum, color=colors, edgecolor="grey", align="center") 20 | plt.xticks(x_ticks) 21 | ax.set_xticklabels(x_labels) 22 | ax.spines["top"].set_visible(False) 23 | ax.spines["right"].set_visible(False) 24 | ax.spines["left"].set_smart_bounds(True) 25 | ax.spines["bottom"].set_smart_bounds(True) 26 | 27 | f.savefig(FAILED_MODELS, dpi=600, bbox_inches="tight") 28 | plt.show() 29 | -------------------------------------------------------------------------------- /src/core/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from core.config import * 3 | from core.mdp import MDP 4 | 5 | 6 | def get_mdp(agent_id, reverse_prior=False): 7 | a = np.eye(N_OBS) 8 | b = np.random.rand(N_CONTROL, N_STATES, N_STATES) 9 | c = np.zeros([N_OBS, 1]) 10 | 11 | if reverse_prior: 12 | c[0] = 1 13 | else: 14 | c[PRIOR_ID] = 1 15 | 16 | kwargs = {} 17 | if agent_id == FULL_ID: 18 | kwargs = {"alpha": ALPHA, "beta": 1, "lr": LR} 19 | elif agent_id == INST_ID: 20 | kwargs = {"alpha": ALPHA, "beta": 0, "lr": LR} 21 | elif agent_id == EPIS_ID: 22 | kwargs = {"alpha": 0, "beta": 1, "lr": LR} 23 | elif agent_id == RAND_ID: 24 | kwargs = {"alpha": 0, "beta": 0, "lr": LR} 25 | 26 | mdp = MDP(a, b, c, **kwargs) 27 | return mdp 28 | 29 | 30 | def get_true_model(): 31 | b = np.zeros([N_CONTROL, N_STATES, N_STATES]) 32 | b[TUMBLE, :, :] = np.array([[0.5, 0.5], [0.5, 0.5]]) 33 | b[RUN, :, :] = np.array([[1, 0], [0, 1]]) 34 | b += np.exp(-16) 35 | return b 36 | -------------------------------------------------------------------------------- /src/scripts/active_accuracy_script.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import src.core as core 3 | from src.core.config import * 4 | 5 | TRAIN_STEPS = 300 6 | TEST_STEPS = 100 7 | 8 | 9 | if __name__ == "__main__": 10 | passive_accuracy = np.zeros([4, 
N_AVERAGES]) 11 | active_accuracy = np.zeros([4, N_AVERAGES]) 12 | 13 | for n in range(N_AVERAGES): 14 | 15 | if n % 20 == 0: 16 | print("> Processing average {}".format(n)) 17 | 18 | full = core.get_mdp(FULL_ID) 19 | inst = core.get_mdp(INST_ID) 20 | epis = core.get_mdp(EPIS_ID) 21 | rand = core.get_mdp(RAND_ID) 22 | 23 | full = core.learn_trial(full, TRAIN_STEPS) 24 | inst = core.learn_trial(inst, TRAIN_STEPS) 25 | epis = core.learn_trial(epis, TRAIN_STEPS) 26 | rand = core.learn_trial(rand, TRAIN_STEPS) 27 | 28 | passive_accuracy[FULL_ID, n] = core.test_passive_accuracy(full, TEST_STEPS) 29 | passive_accuracy[INST_ID, n] = core.test_passive_accuracy(inst, TEST_STEPS) 30 | passive_accuracy[EPIS_ID, n] = core.test_passive_accuracy(epis, TEST_STEPS) 31 | passive_accuracy[RAND_ID, n] = core.test_passive_accuracy(rand, TEST_STEPS) 32 | 33 | active_accuracy[FULL_ID, n] = core.test_active_accuracy(full, TEST_STEPS) 34 | active_accuracy[INST_ID, n] = core.test_active_accuracy(inst, TEST_STEPS) 35 | active_accuracy[EPIS_ID, n] = core.test_active_accuracy(epis, TEST_STEPS) 36 | active_accuracy[RAND_ID, n] = core.test_active_accuracy(rand, TEST_STEPS) 37 | 38 | np.save(ACTIVE_ACCURACY_PATH, active_accuracy) 39 | np.save(PASSIVE_ACCURACY_PATH, passive_accuracy) 40 | print("> Data saved") 41 | -------------------------------------------------------------------------------- /src/scripts/complexity_script.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import src.core as core 4 | from src.core.config import * 5 | 6 | 7 | def get_factor_complexity(a, b): 8 | kls = np.zeros(4) 9 | kls[TUMBLE_NEG_ID] = np.sum(a[0, :, 0] * np.log(a[0, :, 0] / b[0, :, 0]), axis=0) 10 | kls[TUMBLE_POS_ID] = np.sum(a[0, :, 1] * np.log(a[0, :, 1] / b[0, :, 1]), axis=0) 11 | kls[RUN_NEG_ID] = np.sum(a[1, :, 0] * np.log(a[1, :, 0] / b[1, :, 0]), axis=0) 12 | kls[RUN_POS_ID] = np.sum(a[1, :, 1] * np.log(a[1, :, 1] / b[1, :, 1]), axis=0) 
13 | return kls 14 | 15 | 16 | def process_agent(agent, reverse_prior=False): 17 | mdp = core.get_mdp(agent, reverse_prior=reverse_prior) 18 | original_model = np.copy(mdp.B) 19 | mdp = core.learn_trial(mdp, TEST_TRIAL_LEN) 20 | return get_factor_complexity(original_model, mdp.B) 21 | 22 | 23 | def get_complexity(reverse_prior): 24 | _complexity = np.zeros([N_AGENTS, N_DISTRIBUTIONS, N_AVERAGES]) 25 | 26 | for agent_id in range(N_AGENTS): 27 | print("> Processing agent: {}".format(AGENT_NAMES[agent_id])) 28 | 29 | for n in range(N_AVERAGES): 30 | if n % 50 == 0: 31 | print("> Processing average [{}/{}]".format(n, N_AVERAGES)) 32 | 33 | kl_divs = process_agent(agent_id, reverse_prior=reverse_prior) 34 | _complexity[agent_id, :, n] = kl_divs 35 | 36 | return _complexity 37 | 38 | 39 | if __name__ == "__main__": 40 | print("\n> Processing complexity") 41 | complexity = get_complexity(False) 42 | print("\n> Processing complexity (reverse prior)") 43 | reversed_complexity = get_complexity(True) 44 | 45 | np.save(COMPLEXITY_PATH, complexity) 46 | np.save(REVERSED_COMPLEXITY_PATH, reversed_complexity) 47 | print("> Data saved") 48 | -------------------------------------------------------------------------------- /src/core/trials.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from core.env import Environment 3 | from core.config import * 4 | 5 | 6 | def learn_trial(mdp, n_steps, record_states=False): 7 | env = Environment() 8 | obv = env.observe() 9 | mdp.reset(obv) 10 | states = np.zeros([N_CONTROL, N_STATES, N_STATES]) 11 | 12 | for step in range(n_steps): 13 | prev_obv = obv 14 | action = mdp.step(obv) 15 | obv = env.act(action) 16 | mdp.update(action, obv, prev_obv) 17 | if record_states: 18 | states[action, obv, prev_obv] += 1 19 | 20 | if record_states: 21 | return mdp, states 22 | return mdp 23 | 24 | 25 | def test_distance(mdp, steps): 26 | env = Environment() 27 | obv = env.observe() 28 | 
mdp.reset(obv) 29 | 30 | for _ in range(steps): 31 | action = mdp.step(obv) 32 | obv = env.act(action) 33 | 34 | return (env.distance() - env.source_size) + 1 35 | 36 | 37 | def test_passive_accuracy(mdp, n_steps): 38 | env = Environment() 39 | obv = env.observe() 40 | mdp.reset(obv) 41 | acc = 0 42 | 43 | for _ in range(n_steps): 44 | random_action = np.random.choice([0, 1]) 45 | pred, t_pred = mdp.predict_obv(random_action, obv) 46 | _ = mdp.step(obv) 47 | obv = env.act(random_action) 48 | acc += diff(t_pred, pred) 49 | 50 | return acc 51 | 52 | 53 | def test_active_accuracy(mdp, n_steps): 54 | env = Environment() 55 | obv = env.observe() 56 | mdp.reset(obv) 57 | acc = 0 58 | 59 | for _ in range(n_steps): 60 | action = mdp.step(obv) 61 | pred, t_pred = mdp.predict_obv(action, obv) 62 | acc += diff(t_pred, pred) 63 | obv = env.act(action) 64 | 65 | return acc 66 | 67 | 68 | def diff(p, q): 69 | return np.mean(np.square(p - q)) 70 | -------------------------------------------------------------------------------- /src/scripts/states_script.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import src.core as core 3 | from src.core.config import * 4 | 5 | 6 | def get_state_dist(agent_id, reverse_prior=False): 7 | print("> Processing agent: {}".format(AGENT_NAMES[agent_id])) 8 | state_ensemble = np.zeros([N_AVERAGES, N_CONTROL, N_STATES, N_STATES]) 9 | for sample in range(N_AVERAGES): 10 | if sample % 50 == 0: 11 | print("> Processing average [{}/{}]".format(sample, N_AVERAGES)) 12 | mdp = core.get_mdp(agent_id, reverse_prior=reverse_prior) 13 | mdp, states_trial = core.learn_trial(mdp, TEST_TRIAL_LEN, record_states=True) 14 | state_ensemble[sample, :, :, :] = states_trial 15 | 16 | states_sum = np.sum(state_ensemble, axis=0) 17 | states_dist = states_sum / np.sum(states_sum) 18 | return np.round(states_dist, 3) 19 | 20 | 21 | if __name__ == "__main__": 22 | print("\n> Processing state ensemble") 23 | states 
= np.zeros([N_AGENTS, N_CONTROL, N_STATES, N_STATES]) 24 | states[FULL_ID, :, :, :] = get_state_dist(FULL_ID) 25 | states[INST_ID, :, :, :] = get_state_dist(INST_ID) 26 | states[EPIS_ID, :, :, :] = get_state_dist(EPIS_ID) 27 | states[RAND_ID, :, :, :] = get_state_dist(RAND_ID) 28 | 29 | print("\n> Processing state ensemble (reversed prior)") 30 | reversed_states = np.zeros([N_AGENTS, N_CONTROL, N_STATES, N_STATES]) 31 | reversed_states[FULL_ID, :, :, :] = get_state_dist(FULL_ID, reverse_prior=True) 32 | reversed_states[INST_ID, :, :, :] = get_state_dist(INST_ID, reverse_prior=True) 33 | reversed_states[EPIS_ID, :, :, :] = get_state_dist(EPIS_ID, reverse_prior=True) 34 | reversed_states[RAND_ID, :, :, :] = get_state_dist(RAND_ID, reverse_prior=True) 35 | 36 | np.save(STATES_PATH, states) 37 | np.save(REVERSED_STATES_PATH, reversed_states) 38 | print("> Data saved") 39 | -------------------------------------------------------------------------------- /src/plot_scripts/plot_metrics.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import numpy as np 4 | 5 | plt.rc("legend", fontsize=16) 6 | 7 | plt.rcParams["axes.edgecolor"] = "#333F4B" 8 | plt.rcParams["axes.linewidth"] = 0.8 9 | plt.rcParams["xtick.color"] = "#333F4B" 10 | plt.rcParams["ytick.color"] = "#333F4B" 11 | 12 | plt.rcParams["font.family"] = "sans-serif" 13 | plt.rcParams["font.sans-serif"] = "Arial" 14 | plt.rc("ytick", labelsize=17) 15 | plt.rc("xtick", labelsize=17) 16 | 17 | if __name__ == "__main__": 18 | T = 2500 19 | buffer = 25 20 | Tb = int(T / buffer) 21 | 22 | x_ticks = np.arange(0, Tb, 20) 23 | labels = [int(i) * 25 for i in x_ticks] 24 | 25 | x = np.load("data/metrics.npy", allow_pickle=True) 26 | change_x = x[0][0] 27 | t = x[1] 28 | efe_run = x[2] 29 | efe_tumble = x[3] 30 | util_tumble = x[4] 31 | epi_tumble = x[5] 32 | util_run = x[6] 33 | epi_run = x[7] 34 | 35 | fig, ax = plt.subplots() 
36 | fig.set_size_inches(12, 7) 37 | 38 | np.save("data/metrics", x) 39 | 40 | ax.plot(t, efe_run, lw=5, label="EFE Run") 41 | ax.plot(t, efe_tumble, lw=5, label="EFE Tumble") 42 | 43 | ax.plot(t, util_tumble, lw=4, linestyle="--", label="Instrumental Tumble ") 44 | ax.plot(t, epi_tumble, lw=4, linestyle="--", label="Epistemic Tumble") 45 | 46 | ax.plot(t, util_run, lw=4, linestyle="--", label="Instrumental Run") 47 | ax.plot(t, epi_run, lw=4, linestyle="--", label="Epistemic Run") 48 | 49 | plt.axvline(x=0, color="#7b7b7b", linestyle="-.", lw=3) 50 | plt.axvline(x=change_x, color="#7b7b7b", linestyle="-.", lw=3) 51 | 52 | ax.spines["top"].set_visible(False) 53 | ax.spines["right"].set_visible(False) 54 | 55 | legend = plt.legend() 56 | frame = legend.get_frame() 57 | frame.set_facecolor("1.0") 58 | frame.set_edgecolor("1.0") 59 | 60 | plt.xticks(x_ticks) 61 | ax.set_xticklabels(labels) 62 | ax.spines["left"].set_smart_bounds(True) 63 | ax.spines["bottom"].set_smart_bounds(True) 64 | 65 | ax.set_xlabel("Number of time steps", {"size": 19}) 66 | ax.set_ylabel("Bits", {"size": 19}) 67 | 68 | fig.savefig("figs/metrics.pdf", dpi=600, bbox_inches="tight") 69 | plt.show() 70 | -------------------------------------------------------------------------------- /src/plot_scripts/plot_distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from core.config import * 4 | 5 | TICK_SIZE = 14 6 | LEGEND_SIZE = 14 7 | LABEL_SIZE = 16 8 | FIG_SIZE = [9, 7] 9 | 10 | plt.rc("xtick", labelsize=TICK_SIZE) 11 | plt.rc("ytick", labelsize=TICK_SIZE) 12 | plt.rc("legend", fontsize=LEGEND_SIZE) 13 | 14 | plt.rcParams["axes.edgecolor"] = "#333F4B" 15 | plt.rcParams["axes.linewidth"] = 0.8 16 | plt.rcParams["xtick.color"] = "#333F4B" 17 | plt.rcParams["ytick.color"] = "#333F4B" 18 | 19 | plt.rcParams["font.family"] = "sans-serif" 20 | plt.rcParams["font.sans-serif"] = "Arial" 21 | 22 | 
N_PLOT_EPOCHS = 40 23 | N_AVERAGES = 300 24 | 25 | 26 | def convert_steps_to_ticks(_steps, _epochs): 27 | _steps = _steps[0:_epochs] 28 | _steps = _steps - 20 29 | _steps = [item for item in _steps.astype(str)] 30 | _x_ticks = np.arange(0, _epochs, 10) 31 | _labels = [_steps[int(i)] for i in _x_ticks] 32 | return _x_ticks, _labels 33 | 34 | 35 | if __name__ == "__main__": 36 | steps = np.load(STEPS_PATH) 37 | raw_distances = np.load(DISTANCE_PATH) 38 | x_range = range(N_PLOT_EPOCHS) 39 | 40 | colors = get_color_palette() 41 | x_ticks, x_labels = convert_steps_to_ticks(steps, N_PLOT_EPOCHS) 42 | 43 | fig, ax = plt.subplots() 44 | fig.set_size_inches(FIG_SIZE[0], FIG_SIZE[1]) 45 | 46 | for agent_id in range(N_AGENTS): 47 | mean = np.mean(raw_distances[agent_id, 0:N_PLOT_EPOCHS, :], axis=1) 48 | std = np.std(raw_distances[agent_id, 0:N_PLOT_EPOCHS, :], axis=1) / np.sqrt(N_AVERAGES) 49 | high = mean + std 50 | low = mean - std 51 | 52 | ax.plot(x_range, mean, color=colors[agent_id], lw=2.5, label=AGENT_NAMES[agent_id]) 53 | ax.fill_between(x_range, high, low, color=colors[agent_id], alpha=0.2, linewidth=0) 54 | 55 | ax.spines["top"].set_visible(False) 56 | ax.spines["right"].set_visible(False) 57 | 58 | legend = plt.legend() 59 | frame = legend.get_frame() 60 | frame.set_facecolor("1.0") 61 | frame.set_edgecolor("1.0") 62 | 63 | plt.xlabel("Number of learning steps", {"size": LABEL_SIZE}) 64 | plt.ylabel("Final distance from source", {"size": LABEL_SIZE}) 65 | 66 | plt.xticks(x_ticks) 67 | ax.set_xticklabels(x_labels) 68 | plt.savefig(DISTANCE, dpi=600, bbox_inches="tight") 69 | plt.show() 70 | -------------------------------------------------------------------------------- /src/plot_scripts/plot_active_accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import core 3 | from core.config import * 4 | 5 | FONT_SIZE = 20 6 | plt.rc("ytick", labelsize=16) 7 | plt.rc("xtick", labelsize=16) 8 | 
plt.rc("legend", fontsize=FONT_SIZE) 9 | 10 | plt.rcParams["axes.edgecolor"] = "#333F4B" 11 | plt.rcParams["axes.linewidth"] = 0.8 12 | plt.rcParams["xtick.color"] = "#333F4B" 13 | plt.rcParams["ytick.color"] = "#333F4B" 14 | 15 | plt.rcParams["font.family"] = "sans-serif" 16 | plt.rcParams["font.sans-serif"] = "Arial" 17 | 18 | if __name__ == "__main__": 19 | active_accuracy = np.load(ACTIVE_ACCURACY_PATH) 20 | passive_accuracy = np.load(PASSIVE_ACCURACY_PATH) 21 | 22 | colors = core.get_color_palette() 23 | positions = [-0.24, -0.08, 0.08, 0.24] 24 | x = np.array([0, 1]) 25 | x_ticks = np.arange(2) 26 | 27 | f, ax = plt.subplots(1, 1, figsize=(10, 7)) 28 | 29 | for agent_id in range(N_AGENTS): 30 | avg = np.round(np.mean(passive_accuracy[agent_id, :]), 1) 31 | sem = np.std(passive_accuracy[agent_id, :]) / np.sqrt(N_AVERAGES) 32 | plt.bar( 33 | 0 + positions[agent_id], 34 | avg, 35 | width=0.16, 36 | color=colors[agent_id], 37 | align="center", 38 | label=AGENT_NAMES[agent_id], 39 | edgecolor="white", 40 | ) 41 | 42 | for agent_id in range(N_AGENTS): 43 | avg = np.round(np.mean(active_accuracy[agent_id, :]), 1) 44 | sem = np.std(active_accuracy[agent_id, :]) / np.sqrt(N_AVERAGES) 45 | plt.bar( 46 | 1 + positions[agent_id], 47 | avg, 48 | width=0.16, 49 | color=colors[agent_id], 50 | align="center", 51 | edgecolor="white", 52 | ) 53 | 54 | plt.xticks(x_ticks, fontsize=FONT_SIZE) 55 | ax.set_xticklabels(["Passive error", "Active error"]) 56 | 57 | legend = plt.legend() 58 | frame = legend.get_frame() 59 | frame.set_facecolor("1.0") 60 | frame.set_edgecolor("1.0") 61 | 62 | ax.spines["top"].set_visible(False) 63 | ax.spines["right"].set_visible(False) 64 | ax.spines["left"].set_smart_bounds(True) 65 | ax.spines["bottom"].set_smart_bounds(True) 66 | 67 | ax.set_ylabel("Total M.S.E error", {"size": FONT_SIZE}, labelpad=10) 68 | f.savefig(ACTIVE_ACCURACY, dpi=600, bbox_inches="tight") 69 | plt.show() 70 | 
-------------------------------------------------------------------------------- /src/plot_scripts/plot_accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | from core.config import * 5 | 6 | TICK_SIZE = 16 7 | LEGEND_SIZE = 14 8 | LABEL_SIZE = 20 9 | FIG_SIZE = [9, 7] 10 | 11 | FONT_SIZE = 18 12 | plt.rc("ytick", labelsize=16) 13 | plt.rc("xtick", labelsize=16) 14 | plt.rc("legend", fontsize=FONT_SIZE) 15 | 16 | plt.rcParams["axes.edgecolor"] = "#333F4B" 17 | plt.rcParams["axes.linewidth"] = 1.2 18 | plt.rcParams["xtick.color"] = "#333F4B" 19 | plt.rcParams["ytick.color"] = "#333F4B" 20 | 21 | plt.rcParams["font.family"] = "sans-serif" 22 | plt.rcParams["font.sans-serif"] = "Arial" 23 | 24 | 25 | N_PLOT_EPOCHS = 40 26 | N_AVERAGES = 300 27 | 28 | 29 | def convert_steps_to_ticks(_steps, _epochs): 30 | _steps = _steps[0:_epochs] 31 | _steps = _steps - 20 32 | _steps = [item for item in _steps.astype(str)] 33 | _x_ticks = np.arange(0, _epochs, 10) 34 | _labels = [_steps[int(i)] for i in _x_ticks] 35 | return _x_ticks, _labels 36 | 37 | 38 | if __name__ == "__main__": 39 | steps = np.load(STEPS_PATH) 40 | raw_accuracy = np.load(ACCURACY_PATH) 41 | x_range = range(N_PLOT_EPOCHS) 42 | 43 | colors = get_color_palette() 44 | x_ticks, x_labels = convert_steps_to_ticks(steps, N_PLOT_EPOCHS) 45 | 46 | fig, ax = plt.subplots() 47 | fig.set_size_inches(FIG_SIZE[0], FIG_SIZE[1]) 48 | 49 | for agent_id in range(N_AGENTS): 50 | mean = 2.2 - np.mean(raw_accuracy[agent_id, 0:N_PLOT_EPOCHS, :], axis=1) 51 | std = np.std(raw_accuracy[agent_id, 0:N_PLOT_EPOCHS, :], axis=1) / np.sqrt(N_AVERAGES) 52 | high = mean + std 53 | low = mean - std 54 | 55 | ax.plot(x_range, mean, color=colors[agent_id], lw=2.5, label=AGENT_NAMES[agent_id]) 56 | ax.fill_between(x_range, high, low, color=colors[agent_id], alpha=0.2, linewidth=0) 57 | 58 | 
ax.spines["top"].set_visible(False) 59 | ax.spines["right"].set_visible(False) 60 | ax.spines["left"].set_smart_bounds(True) 61 | ax.spines["bottom"].set_smart_bounds(True) 62 | 63 | legend = plt.legend() 64 | frame = legend.get_frame() 65 | frame.set_facecolor("1.0") 66 | frame.set_edgecolor("1.0") 67 | 68 | plt.xlabel("Number of learning steps", {"size": LABEL_SIZE}) 69 | plt.ylabel("(-ve) KL from true model", {"size": LABEL_SIZE}) 70 | 71 | plt.xticks(x_ticks) 72 | ax.set_xticklabels(x_labels) 73 | plt.savefig(ACCURACY, dpi=600, bbox_inches="tight") 74 | plt.show() 75 | -------------------------------------------------------------------------------- /src/core/env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from core.config import * 3 | 4 | 5 | class Environment(object): 6 | def __init__( 7 | self, 8 | env_size=ENVIRONMENT_SIZE, 9 | init_distance=INIT_DISTANCE, 10 | source_size=SOURCE_SIZE, 11 | agent_size=AGENT_SIZE, 12 | velocity=VELOCITY, 13 | ): 14 | 15 | self.env_size = env_size 16 | self.init_distance = init_distance 17 | self.source_size = source_size 18 | self.agent_size = agent_size 19 | self.vel = velocity 20 | 21 | self.pos = None 22 | self.s_pos = None 23 | self.theta = None 24 | self.reset() 25 | 26 | def reset(self): 27 | rand_loc = np.random.rand() * (2 * np.pi) 28 | fx = self.env_size / 2 + (self.init_distance * np.cos(rand_loc)) 29 | fy = self.env_size / 2 + (self.init_distance * np.sin(rand_loc)) 30 | 31 | self.pos = [fx, fy] 32 | self.s_pos = [self.env_size / 2, self.env_size / 2] 33 | self.theta = np.random.rand() * (2 * np.pi) 34 | self.observe() 35 | 36 | def observe(self): 37 | fx = self.pos[0] + (self.agent_size * np.cos(self.theta)) 38 | fy = self.pos[1] + (self.agent_size * np.sin(self.theta)) 39 | f_dis = self.dis(fx, fy, self.s_pos[0], self.s_pos[1]) 40 | b_dis = self.dis(self.pos[0], self.pos[1], self.s_pos[0], self.s_pos[1]) 41 | if f_dis > b_dis: 42 | o = 
NEG_GRADIENT 43 | else: 44 | o = POS_GRADIENT 45 | return o 46 | 47 | def act(self, a): 48 | if a == RUN and self.distance() > self.source_size: 49 | self.pos[0] += self.vel * np.cos(self.theta) 50 | self.pos[1] += self.vel * np.sin(self.theta) 51 | self.check_bounds() 52 | elif a == TUMBLE: 53 | self.theta = np.random.rand() * (2 * np.pi) 54 | 55 | return self.observe() 56 | 57 | def distance(self): 58 | return self.dis(self.pos[0], self.pos[1], self.s_pos[0], self.s_pos[1]) 59 | 60 | def check_bounds(self): 61 | if self.pos[0] > self.env_size: 62 | self.pos[0] = self.env_size 63 | if self.pos[0] < 0: 64 | self.pos[0] = 0 65 | if self.pos[1] > self.env_size: 66 | self.pos[1] = self.env_size 67 | if self.pos[1] < 0: 68 | self.pos[1] = 0 69 | 70 | @staticmethod 71 | def dis(x1, y1, x2, y2): 72 | return np.sqrt(((x1 - x2) * (x1 - x2)) + ((y1 - y2) * (y1 - y2))) 73 | -------------------------------------------------------------------------------- /src/plot_scripts/plot_states.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | from core.config import * 5 | 6 | plt.rc("text", usetex=True) 7 | 8 | 9 | def create_heatmap(matrix, title, save_path, color_bar=True): 10 | f, (ax1, ax2) = plt.subplots(2, 1, figsize=(2.7, 4)) 11 | 12 | x_labels = [r"$s_{t-1}^{neg}$", r"$s_{t-1}^{pos}$"] 13 | y_labels = [r"$s_{t}^{neg}$", r"$s_{t}^{pos}$"] 14 | g1 = sns.heatmap( 15 | matrix[0, :, :] * 100, 16 | cmap="OrRd", 17 | ax=ax1, 18 | vmin=0.0, 19 | vmax=70.0, 20 | linewidth=2.5, 21 | annot=True, 22 | xticklabels=x_labels, 23 | yticklabels=y_labels, 24 | cbar=color_bar, 25 | ) 26 | g2 = sns.heatmap( 27 | matrix[1, :, :] * 100, 28 | cmap="OrRd", 29 | ax=ax2, 30 | vmin=0.0, 31 | vmax=70.0, 32 | linewidth=2.5, 33 | annot=True, 34 | xticklabels=x_labels, 35 | yticklabels=y_labels, 36 | cbar=color_bar, 37 | ) 38 | g1.set_yticklabels(g1.get_yticklabels(), rotation=0, 
fontsize=14) 39 | g1.set_xticklabels(g1.get_xticklabels(), fontsize=14) 40 | g2.set_yticklabels(g2.get_yticklabels(), rotation=0, fontsize=14) 41 | g2.set_xticklabels(g2.get_xticklabels(), fontsize=14) 42 | 43 | f.savefig(save_path, dpi=600, bbox_inches="tight") 44 | plt.show() 45 | 46 | 47 | if __name__ == "__main__": 48 | states = np.load(STATES_PATH) 49 | create_heatmap(states[FULL_ID, :, :, :], AGENT_NAMES[FULL_ID], FULL_STATES, color_bar=False) 50 | create_heatmap(states[INST_ID, :, :, :], AGENT_NAMES[INST_ID], INST_STATES, color_bar=False) 51 | create_heatmap(states[EPIS_ID, :, :, :], AGENT_NAMES[EPIS_ID], EPIS_STATES, color_bar=False) 52 | create_heatmap(states[RAND_ID, :, :, :], AGENT_NAMES[RAND_ID], RAND_STATES, color_bar=False) 53 | create_heatmap(states[RAND_ID, :, :, :], AGENT_NAMES[RAND_ID], COLOR_BAR, color_bar=True) 54 | 55 | reversed_states = np.load(REVERSED_STATES_PATH) 56 | create_heatmap( 57 | reversed_states[FULL_ID, :, :, :], 58 | AGENT_NAMES[FULL_ID] + " (Reversed prior)", 59 | FULL_STATES_REVERSED, 60 | color_bar=False, 61 | ) 62 | create_heatmap( 63 | reversed_states[INST_ID, :, :, :], 64 | AGENT_NAMES[INST_ID] + " (Reversed prior)", 65 | INST_STATES_REVERSED, 66 | color_bar=False, 67 | ) 68 | create_heatmap( 69 | reversed_states[EPIS_ID, :, :, :], 70 | AGENT_NAMES[EPIS_ID] + " (Reversed prior)", 71 | EPIS_STATES_REVERSED, 72 | color_bar=False, 73 | ) 74 | create_heatmap( 75 | reversed_states[RAND_ID, :, :, :], 76 | AGENT_NAMES[RAND_ID] + " (Reversed prior)", 77 | RAND_STATES_REVERSED, 78 | color_bar=False, 79 | ) 80 | -------------------------------------------------------------------------------- /src/plot_scripts/plot_complexity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from core.config import * 4 | 5 | TICK_SIZE = 14 6 | 7 | FONT_SIZE = 16 8 | plt.rc("ytick", labelsize=16) 9 | plt.rc("xtick", labelsize=16) 10 | plt.rc("legend", 
fontsize=13) 11 | 12 | plt.rcParams["axes.edgecolor"] = "#333F4B" 13 | plt.rcParams["axes.linewidth"] = 0.8 14 | plt.rcParams["xtick.color"] = "#333F4B" 15 | plt.rcParams["ytick.color"] = "#333F4B" 16 | 17 | plt.rcParams["font.family"] = "sans-serif" 18 | plt.rcParams["font.sans-serif"] = "Arial" 19 | 20 | plt.rc("text", usetex=True) 21 | 22 | 23 | def plot_complexity(matrix, title, save_path): 24 | colors = get_color_palette() 25 | f, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 7)) 26 | positions = [-0.3, -0.1, 0.1, 0.3] 27 | x_axis = np.array([0.0, 1.0]) 28 | 29 | x_labels_a = [ 30 | r"$P_{\theta}(s_{t}|s_{t-1}^{neg}, u_{t-1}^{tumble})$", 31 | r"$P_{\theta}(s_{t}|s_{t-1}^{pos}, u_{t-1}^{tumble})$", 32 | ] 33 | 34 | x_labels_b = [ 35 | r"$P_{\theta}(s_{t}|s_{t-1}^{neg}, u_{t-1}^{run})$", 36 | r"$P_{\theta}(s_{t}|s_{t-1}^{pos}, u_{t-1}^{run})$", 37 | ] 38 | 39 | for i in range(N_AGENTS): 40 | values = matrix[i, :, :] 41 | mean_values = np.mean(values, axis=1) 42 | _ = np.std(values, axis=1) / np.sqrt(N_AVERAGES) 43 | 44 | x_values = x_axis + positions[i] 45 | ax1.bar( 46 | x_values, 47 | mean_values[:2], 48 | align="center", 49 | width=0.2, 50 | color=colors[i], 51 | edgecolor="white", 52 | label=AGENT_NAMES[i], 53 | ) 54 | ax1.spines["top"].set_visible(False) 55 | ax1.spines["right"].set_visible(False) 56 | ax1.spines["left"].set_smart_bounds(True) 57 | ax1.spines["bottom"].set_smart_bounds(True) 58 | ax1.set_xticks(x_axis) 59 | ax1.set_xticklabels(x_labels_a) 60 | ax1.tick_params(axis="x", which="major", pad=15) 61 | 62 | ax2.bar( 63 | x_values, 64 | mean_values[2:], 65 | align="center", 66 | width=0.2, 67 | color=colors[i], 68 | edgecolor="white", 69 | label=AGENT_NAMES[i], 70 | ) 71 | ax2.spines["top"].set_visible(False) 72 | ax2.spines["right"].set_visible(False) 73 | ax2.spines["left"].set_smart_bounds(True) 74 | ax2.spines["bottom"].set_smart_bounds(True) 75 | ax2.set_xticks(x_axis) 76 | ax2.set_xticklabels(x_labels_b) 77 | ax2.tick_params(axis="x", 
which="major", pad=15) 78 | 79 | legend = plt.legend() 80 | f.savefig(save_path, dpi=600, bbox_inches="tight") 81 | plt.show() 82 | 83 | 84 | if __name__ == "__main__": 85 | complexity = np.load(COMPLEXITY_PATH) 86 | plot_complexity(complexity, "Change in distributions", COMPLEXITY) 87 | reversed_complexity = np.load(REVERSED_COMPLEXITY_PATH) 88 | plot_complexity( 89 | reversed_complexity, "Change in distributions (reversed prior)", REVERSED_COMPLEXITY 90 | ) 91 | -------------------------------------------------------------------------------- /src/core/config.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import numpy as np 4 | 5 | ############################## 6 | # Simulation config # 7 | ############################## 8 | 9 | N_AVERAGES = 200 10 | TEST_TRIAL_LEN = 2000 11 | MODEL_REDUCTION_TRIAL_LEN = 500 12 | 13 | ############################## 14 | # Environment config # 15 | ############################## 16 | 17 | ENVIRONMENT_SIZE = 500 18 | INIT_DISTANCE = 400 19 | SOURCE_SIZE = 25 20 | AGENT_SIZE = 5 21 | VELOCITY = 1 22 | 23 | TUMBLE = 0 24 | RUN = 1 25 | NEG_GRADIENT = 0 26 | POS_GRADIENT = 1 27 | 28 | ############################## 29 | # Agent config # 30 | ############################## 31 | 32 | FULL_ID = 0 33 | INST_ID = 1 34 | EPIS_ID = 2 35 | RAND_ID = 3 36 | AGENT_NAMES = ["E.F.E", "Instrumental", "Epistemic", "Random"] 37 | N_AGENTS = 4 38 | 39 | ############################## 40 | # MDP config # 41 | ############################## 42 | 43 | N_OBS = 2 44 | N_CONTROL = 2 45 | N_STATES = 2 46 | 47 | N_DISTRIBUTIONS = 4 48 | TUMBLE_NEG_ID = 0 49 | TUMBLE_POS_ID = 1 50 | RUN_NEG_ID = 2 51 | RUN_POS_ID = 3 52 | 53 | PRIOR_ID = 1 54 | ALPHA = 1 / 10 55 | LR = 0.005 56 | 57 | ############################## 58 | # File config # 59 | ############################## 60 | 61 | DISTANCE_PATH = "data/distance.npy" 62 | ACCURACY_PATH = "data/accuracy.npy" 
STEPS_PATH = "data/steps.npy"

STATES_PATH = "data/states.npy"
REVERSED_STATES_PATH = "data/reversed_states.npy"
COMPLEXITY_PATH = "data/complexity.npy"
REVERSED_COMPLEXITY_PATH = "data/reversed_complexity.npy"

ACTIVE_ACCURACY_PATH = "data/active_accuracy.npy"
PASSIVE_ACCURACY_PATH = "data/passive_accuracy.npy"

PRUNED_PATH = "data/pruned.npy"

FAILED_PATH = "data/failed.npy"

##############################
#       Figures config       #
##############################
# Output figure paths. Like the data paths above, these are relative paths,
# so they resolve against the current working directory of the running script.

DISTANCE = "figs/distance.pdf"
ACCURACY = "figs/accuracy.pdf"

FULL_STATES = "figs/full_states.pdf"
INST_STATES = "figs/inst_states.pdf"
EPIS_STATES = "figs/epis_states.pdf"
RAND_STATES = "figs/rand_states.pdf"
COLOR_BAR = "figs/color_bar.pdf"

FULL_STATES_REVERSED = "figs/full_states_reversed.pdf"
INST_STATES_REVERSED = "figs/inst_states_reversed.pdf"
EPIS_STATES_REVERSED = "figs/epis_states_reversed.pdf"
RAND_STATES_REVERSED = "figs/rand_states_reversed.pdf"

FULL_PRUNED = "figs/full_pruned.pdf"
INST_PRUNED = "figs/inst_pruned.pdf"
EPIS_PRUNED = "figs/epis_pruned.pdf"
RAND_PRUNED = "figs/rand_pruned.pdf"
COLOR_BAR_PRUNED = "figs/color_bar_pruned.pdf"
TOTAL_PRUNED = "figs/total_pruned.pdf"

COMPLEXITY = "figs/complexity.pdf"
REVERSED_COMPLEXITY = "figs/reversed_complexity.pdf"

ACTIVE_ACCURACY = "figs/active_accuracy.pdf"

FAILED_MODELS = "figs/failed_models.pdf"


##############################
#        Plot config         #
##############################


def get_color_palette():
    """Return four colours taken from seaborn's 12-colour "Paired" palette.

    The picks are Paired entries 5, 7, 0 and 2, in that order. The plot scripts
    use them as one colour per agent, in agent-id order.
    """
    paired = sns.color_palette("Paired", 12)
    agent_slots = (5, 7, 0, 2)
    return [paired[slot] for slot in agent_slots]


# --------------------------------------------------------------------------------
# /src/plot_scripts/plot_model_reduction.py:
# --------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import core
from core.config import *

plt.rc("text", usetex=True)
plt.rcParams["axes.edgecolor"] = "#333F4B"
plt.rcParams["axes.linewidth"] = 0.8
plt.rcParams["xtick.color"] = "#333F4B"
plt.rcParams["ytick.color"] = "#333F4B"

plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Arial"
plt.rc("ytick", labelsize=12)
plt.rc("xtick", labelsize=12)


def create_heatmap(matrix, title, save_path, color_bar=False):
    """Draw two stacked, annotated heatmaps from ``matrix`` and save the figure.

    Args:
        matrix: 3-D array. ``matrix[0]`` fills the top panel and ``matrix[1]``
            fills the bottom panel (the first axis is presumably the action;
            confirm against the producer of the data). In each panel, rows are
            labelled s_t and columns s_{t-1}. Values are multiplied by 100 and
            drawn on a fixed 0-100 colour scale, so they are presumably
            fractions in [0, 1].
        title: agent name. NOTE(review): currently not drawn on the figure;
            kept so existing callers keep working.
        save_path: output file passed to ``Figure.savefig``.
        color_bar: whether seaborn draws a colour bar beside each panel.
    """
    f, axes = plt.subplots(2, 1, figsize=(2.7, 4))

    x_labels = [r"$s_{t-1}^{neg}$", r"$s_{t-1}^{pos}$"]
    y_labels = [r"$s_{t}^{neg}$", r"$s_{t}^{pos}$"]

    # One panel per slice of the first axis; both panels are styled identically.
    for u, ax in enumerate(axes):
        g = sns.heatmap(
            matrix[u, :, :] * 100,
            cmap="OrRd",
            ax=ax,
            vmin=0.0,
            vmax=100.0,
            linewidth=2.5,
            annot=True,
            xticklabels=x_labels,
            yticklabels=y_labels,
            cbar=color_bar,
        )
        g.set_yticklabels(g.get_yticklabels(), rotation=0, fontsize=14)
        g.set_xticklabels(g.get_xticklabels(), fontsize=14)

    f.savefig(save_path, dpi=600, bbox_inches="tight")
    plt.show()
    # The script draws several figures; release each one once it has been shown.
    plt.close(f)


if __name__ == "__main__":
    pruned = np.load(PRUNED_PATH)

    # (agent id, heatmap output path), listed in the bar order used below.
    agent_figures = (
        (FULL_ID, FULL_PRUNED),
        (INST_ID, INST_PRUNED),
        (EPIS_ID, EPIS_PRUNED),
        (RAND_ID, RAND_PRUNED),
    )
    for agent_id, fig_path in agent_figures:
        create_heatmap(pruned[agent_id, :, :, :], AGENT_NAMES[agent_id], fig_path)

    # Extra copy of the random agent's figure with colour bars enabled.
    create_heatmap(
        pruned[RAND_ID, :, :, :], AGENT_NAMES[RAND_ID], COLOR_BAR_PRUNED, color_bar=True
    )

    totals = [np.sum(pruned[agent_id, :, :, :]) for agent_id, _ in agent_figures]

    colors = core.get_color_palette()
    f, ax = plt.subplots(1, 1, figsize=(5, 4))
    ax.bar(
        range(N_AGENTS),
        totals,
        color=colors,
        edgecolor="grey",
        align="center",
    )
    ax.set_xticks(range(N_AGENTS))
    ax.set_xticklabels(AGENT_NAMES)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    for side in ("left", "bottom"):
        spine = ax.spines[side]
        # Spine.set_smart_bounds was deprecated in Matplotlib 3.2 and removed in 3.4.
        # Apply it only where it still exists instead of crashing with AttributeError.
        if hasattr(spine, "set_smart_bounds"):
            spine.set_smart_bounds(True)

    f.savefig(TOTAL_PRUNED, dpi=600, bbox_inches="tight")
    plt.show()

    # model does not need to know about what happens when you run in negative gradients, or tumble in positive gradients
    # we are looking at the learned model - what redundant priors does it entail *in the presence of action*?
# --------------------------------------------------------------------------------
# /src/scripts/model_reduction_script.py:
# --------------------------------------------------------------------------------
import numpy as np
from scipy.special import gamma, gammaln
import pprint
import src.core as core
from src.core.config import *


def ln_beta_fn(vec):
    """Return ln B(vec) = sum_i ln Gamma(vec_i) - ln Gamma(sum_i vec_i).

    This is the log of the multivariate Beta function (Dirichlet normaliser).
    It is evaluated in log space with ``gammaln`` rather than as a ratio of
    products of ``gamma`` values. ``gamma`` alone exceeds the float64 range for
    arguments above ~171, and a product of several moderately large gammas
    overflows much sooner. The direct form therefore yields inf/nan once
    Dirichlet counts grow, e.g. with longer trials.

    Args:
        vec: 1-D sequence of positive concentration parameters.

    Returns:
        The log-Beta value as a float.
    """
    vec = np.asarray(vec, dtype=float)
    return np.sum(gammaln(vec)) - gammaln(np.sum(vec))


def calc_df(prior, posterior, reduced_prior):
    """Free-energy change from replacing ``prior`` with ``reduced_prior`` (Bayesian model reduction).

    For Dirichlet parameters, the reduced posterior is
    ``posterior + reduced_prior - prior``, and

        dF = ln B(posterior) + ln B(reduced_prior) - ln B(prior) - ln B(reduced_posterior)

    ``perform_model_reduction`` marks an entry as prunable when dF < 0.
    """
    prior = np.squeeze(prior)
    posterior = np.squeeze(posterior)
    posterior_reduced = posterior + reduced_prior - prior

    term_1 = ln_beta_fn(posterior)
    term_2 = ln_beta_fn(reduced_prior)
    term_3 = ln_beta_fn(prior)
    term_4 = ln_beta_fn(posterior_reduced)

    df = term_1 + term_2 - term_3 - term_4
    return df


def calc_dfs(prior, posterior):
    """Return dF for shrinking each single entry of ``prior`` to ``p_0``.

    Returns an array of shape ``(N_CONTROL, N_STATES, N_STATES)``. Entry
    ``[u, s_t0, s_t1]`` is the dF of setting ``prior[u, s_t0, s_t1]`` to ``p_0``,
    with ``prior[u, s_t0, :]`` treated as one Dirichlet vector.
    """
    # Near-zero concentration standing in for a removed parameter (same value as MDP.p0).
    p_0 = np.exp(-16)
    _dfs = np.zeros([N_CONTROL, N_STATES, N_STATES])

    # NOTE(review): MDP.update increments Ba[action, new, previous], and each B[u]
    # is normalised over its first axis, so there a Dirichlet vector is a column
    # Ba[u, :, previous]. This loop instead takes the row prior[u, s_t0, :].
    # Confirm the intended axis (and how core.learn_trial calls update, presumably
    # on core.mdp.MDP) before interpreting the pruning pattern.
    for u in range(N_CONTROL):
        for s_t0 in range(N_STATES):
            _prior = prior[u, s_t0, :]
            _posterior = posterior[u, s_t0, :]

            for s_t1 in range(N_STATES):
                reduced_prior = np.copy(_prior)
                reduced_prior[s_t1] = p_0
                df = calc_df(_prior, _posterior, reduced_prior)
                _dfs[u, s_t0, s_t1] = df
    return _dfs


def perform_model_reduction(agent_id):
    """Run ``N_AVERAGES`` trials for one agent and score every transition parameter.

    Each trial runs in three stages:
      1. Learn for ``TEST_TRIAL_LEN * 4`` steps and snapshot ``Ba`` as the prior.
      2. Learn for ``MODEL_REDUCTION_TRIAL_LEN`` more steps.
      3. Compute dF of the final ``Ba`` against that snapshot.

    Returns:
        ``(dfs, pruned)``, both of shape
        ``(N_CONTROL, N_STATES, N_STATES, N_AVERAGES)``. ``dfs`` holds the
        per-trial dF values; ``pruned`` is 1 where dF < 0 and 0 elsewhere.
    """
    print("> Processing agent {}".format(AGENT_NAMES[agent_id]))
    _dfs = np.zeros([N_CONTROL, N_STATES, N_STATES, N_AVERAGES])
    _pruned_priors = np.zeros([N_CONTROL, N_STATES, N_STATES, N_AVERAGES])

    for n in range(N_AVERAGES):
        if n % 10 == 0:
            print("> Processing average [{}/{}]".format(n, N_AVERAGES))
        mdp = core.get_mdp(agent_id)
        mdp = core.learn_trial(mdp, TEST_TRIAL_LEN * 4)
        prior = np.copy(mdp.Ba)
        mdp = core.learn_trial(mdp, MODEL_REDUCTION_TRIAL_LEN)
        _trial_dfs = calc_dfs(prior, mdp.Ba)
        _trial_pruned = np.zeros([N_CONTROL, N_STATES, N_STATES])
        _trial_pruned[_trial_dfs < 0.0] = 1
        _dfs[:, :, :, n] = _trial_dfs
        _pruned_priors[:, :, :, n] = _trial_pruned

    return _dfs, _pruned_priors


if __name__ == "__main__":
    print("> Processing model reduction")
    pruned = np.zeros([N_AGENTS, N_CONTROL, N_STATES, N_STATES])

    # Same per-agent report and aggregation for every agent, in id order.
    for agent_id in (FULL_ID, INST_ID, EPIS_ID, RAND_ID):
        dfs, pruned_trials = perform_model_reduction(agent_id)
        dfs_mean = np.mean(dfs, axis=-1)
        pruned_sum = np.sum(pruned_trials, axis=-1)
        pprint.pprint(dfs_mean)
        pprint.pprint(pruned_sum / N_AVERAGES)
        print(np.sum(pruned_sum) / N_AVERAGES)
        # Fraction of trials in which each parameter was pruned.
        pruned[agent_id, :, :, :] = pruned_sum / N_AVERAGES

    np.save(PRUNED_PATH, pruned)
    print("> Data saved")

    # TODO: maybe increase trial len (ln_beta_fn works in log space, so larger counts no longer overflow)
# --------------------------------------------------------------------------------
# /src/core/mdp.py:
# --------------------------------------------------------------------------------
import numpy as np


class MDP(object):
    """Discrete-state agent with a learnable transition model.

    The agent infers hidden states, scores actions with an expected-free-energy
    style objective, and learns its transition model from Dirichlet counts.

    Array layout, as fixed by the indexing below:
      * ``A``  -- likelihood ``(No, Ns)``; columns are normalised.
      * ``B``  -- transitions ``(Nu, Ns, Ns)``; ``B[u][:, prev]`` is a
        distribution over the next state.
      * ``Ba`` -- Dirichlet counts for ``B``, same shape; initialised to the
        normalised ``B``.
      * ``C``  -- preferences over observations, stored as a column ``(No, 1)``.

    Args:
        a, b, c: likelihood, transition and preference arrays. They are not
            modified; ``p0`` is added to copies before normalising. ``c`` may
            be given as a row vector.
        lr: amount added to one count of ``Ba`` per call to ``update``.
        alpha: weight on the divergence between predicted observations and ``C``.
        beta: weight on the Bayesian-surprise (novelty) term.
    """

    def __init__(self, a, b, c, lr=0.1, alpha=1, beta=1):
        self.A = a
        self.B = b
        self.C = c

        self.alpha = alpha
        self.beta = beta
        self.lr = lr
        # Floor added before logs/normalisation so no probability is exactly zero.
        self.p0 = np.exp(-16)

        # Accept preferences as a row vector and store them as a column.
        if np.size(self.C, 1) > np.size(self.C, 0):
            self.C = self.C.T

        self.Ns = self.A.shape[1]
        self.No = self.A.shape[0]
        self.Nu = self.B.shape[0]

        self.A = self.A + self.p0
        self.A = self.normdist(self.A)
        self.lnA = np.log(self.A)

        self.B = self.B + self.p0
        for u in range(self.Nu):
            self.B[u] = self.normdist(self.B[u])
        self.Ba = np.copy(self.B)
        self.wB = 0
        self.calc_wb()

        self.true_B = self.get_true_model()

        self.C = self.C + self.p0
        self.C = self.normdist(self.C)

        self.sQ = np.zeros([self.Ns, 1])
        self.uQ = np.zeros([self.Nu, 1])
        self.EFE = np.zeros([self.Nu, 1])

        self.action_range = np.arange(0, self.Nu)
        self.obv = 0
        self.action = 0

    def reset(self, obv):
        """Start an episode: set ``sQ`` from observation ``obv`` alone and pick a random action."""
        self.obv = obv
        likelihood = self.lnA[obv, :][:, np.newaxis]
        self.sQ = self.softmax(likelihood)
        self.action = int(np.random.choice(self.action_range))

    def step(self, obv):
        """Update beliefs with ``obv``, score every action, and return the chosen action index."""
        self.obv = obv
        self.infer_sQ(obv)
        self.evaluate_efe()
        self.infer_uq()
        return self.act()

    def infer_sQ(self, obv):
        """Set ``sQ = softmax(ln A[obv] + ln(B[previous action] @ previous sQ))``."""
        likelihood = self.lnA[obv, :][:, np.newaxis]
        prior = np.log(np.dot(self.B[self.action], self.sQ))
        self.sQ = self.softmax(likelihood + prior)

    def evaluate_efe(self):
        """Fill ``EFE`` with one score per action; ``act`` prefers the largest.

        ``EFE[u] = -alpha * KL(predicted observations || C) + beta * bayesian_surprise(u)``
        """
        self.EFE = np.zeros([self.Nu, 1])

        for u in range(self.Nu):
            fs = np.dot(self.B[u], self.sQ)
            fo = np.dot(self.A, fs)
            fo = self.normdist(fo + self.p0)

            utility = np.sum(fo * np.log(fo / self.C), axis=0)[0] * self.alpha
            surprise = self.bayesian_surprise(u, fs) * self.beta

            self.EFE[u] -= utility
            self.EFE[u] += surprise

    def infer_uq(self):
        """Set the action posterior ``uQ`` to ``softmax(EFE)``."""
        self.uQ = self.softmax(self.EFE)

    def update(self, action, new, previous):
        """Add ``lr`` to the count for ``previous -> new`` under ``action``; refresh ``B[action]`` and ``wB``."""
        self.Ba[action, new, previous] += self.lr
        self.B[action] = self.normdist(self.Ba[action])
        self.calc_wb()

    def calc_expectation(self):
        """Recompute every ``B[u]`` as the column-normalised ``Ba[u]`` and refresh ``wB``."""
        for u in range(self.Nu):
            self.B[u] = self.normdist(self.Ba[u])
        self.calc_wb()

    def calc_wb(self):
        """Set the novelty weights ``wB[u, i, s] = 1 / sum_j Ba[u, j, s] - 1 / Ba[u, i, s]``.

        The column totals are broadcast back over each column in one vectorised
        step instead of a Python loop over actions and states. For positive
        counts every entry is <= 0.
        """
        column_totals = np.sum(self.Ba, axis=1, keepdims=True)
        self.wB = 1.0 / column_totals - 1.0 / self.Ba

    def act(self):
        """Choose the action with the highest ``uQ``, breaking ties uniformly at random."""
        best = self.uQ.max()
        options = np.where(self.uQ == best)[0]
        self.action = int(np.random.choice(options))
        return self.action

    def bayesian_surprise(self, u, fs):
        """Return ``-(fs^T wB[u] sQ)`` as a shape-``(1,)`` array.

        Args:
            u: action index.
            fs: predicted next-state distribution, column vector ``(Ns, 1)``.

        Returns:
            The negated score. Since ``wB <= 0`` and ``fs``/``sQ`` are
            non-negative, it is non-negative.
        """
        return -np.dot(np.dot(fs.T, self.wB[u]), self.sQ)[0]

    def predict_obv(self, action, obv):
        """Predict next-observation distributions under the learned and the true model.

        The observation index ``obv`` is used directly as a smoothed one-hot
        *state* vector. Each model's ``B[action]`` is applied to it, then ``A``.
        ``true_B`` comes from ``get_true_model``, which is fixed at two states.

        Returns:
            ``(fo, tfo)`` column vectors, each normalised to sum to one.
        """
        _obv = np.zeros([self.Ns, 1]) + self.p0
        _obv[obv] = 1
        fs = np.dot(self.B[action], _obv)
        fo = np.dot(self.A, fs)
        fo = self.normdist(fo + self.p0)

        tfs = np.dot(self.true_B[action], _obv)
        tfo = np.dot(self.A, tfs)
        tfo = self.normdist(tfo + self.p0)
        return fo, tfo

    @staticmethod
    def entropy(fs):
        """Shannon entropy (nats) of a column-vector distribution ``fs``."""
        fs = fs[:, 0]
        return -np.sum(fs * np.log(fs), axis=0)

    @staticmethod
    def softmax(x):
        """Numerically stable softmax over all entries of ``x``."""
        shifted = np.exp(x - x.max())
        return shifted / np.sum(shifted)

    @staticmethod
    def normdist(x):
        """Return a copy of 2-D ``x`` with every column scaled to sum to one."""
        return x / np.sum(x, axis=0)

    @staticmethod
    def get_true_model():
        """Ground-truth 2-action, 2-state transitions.

        Action 0 moves to either state with probability 0.5; action 1 keeps the
        current state. ``exp(-16)`` is added to every entry.
        """
        b = np.zeros([2, 2, 2])
        b[0, :, :] = np.array([[0.5, 0.5], [0.5, 0.5]])
        b[1, :, :] = np.array([[1, 0], [0, 1]])
        b += np.exp(-16)
        return b


# --------------------------------------------------------------------------------