├── .floydexpt
├── .floydignore
├── .gitignore
├── .idea
│   ├── fun-with-dnc.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── .vscode
│   └── settings.json
├── README.md
├── __pycache__
│   ├── dnc.cpython-35.pyc
│   └── generators.cpython-35.pyc
├── arg.py
├── dnc_arity_list.py
├── floydrequirements.txt
├── images
│   ├── acc.png
│   └── training.png
├── losses.py
├── misc
│   ├── 78fa9003f6c0f735bc3250fc2116f6100463.pdf
│   ├── dnc_sizes.ods
│   └── train.csv
├── old
│   ├── dnc.py
│   ├── dnc_clean.py
│   ├── dnc_model.png
│   ├── dnc_stateful.py
│   ├── dnc_v1_safe.py
│   ├── generators.py
│   ├── notes.txt
│   └── train.py
├── problem
│   ├── __init__.py
│   ├── copy_squence.py
│   ├── generators_v2.py
│   ├── logic.py
│   ├── lp_utils.py
│   ├── my_air_cargo_problems.py
│   ├── my_planning_graph.py
│   ├── planning.py
│   ├── search.py
│   └── utils.py
├── run.py
├── setup.sh
├── tests.py
├── tf.conf
├── training.py
├── utils.py
└── visualize
    ├── logger.py
    └── wuddido.py
/.floydexpt:
--------------------------------------------------------------------------------
1 | {"family_id": "B7wQJjD7KMoyM7nAkQXfJj", "name": "dnc"}
--------------------------------------------------------------------------------
/.floydignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .git
3 | .idea
4 | .gitignore
5 | __pycache__
6 | runs
7 | models
8 | old
9 | misc
10 | images
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | __pycache__
3 | runs
4 |
5 | models
6 |
--------------------------------------------------------------------------------
/.idea/fun-with-dnc.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |     "python.linting.pylintEnabled": true
3 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fun-With-Dnc (Differentiable Neural Computer)
2 |
3 | PyTorch implementation of the DeepMind paper [Hybrid computing using a neural network with dynamic external memory](https://pdfs.semanticscholar.org/7635/78fa9003f6c0f735bc3250fc2116f6100463.pdf). The code is based on the TensorFlow implementation [here](https://github.com/deepmind/dnc).
4 |
5 | TODO: finish retraining and write a better writeup.
6 |
7 | ## Problems and Experiments
8 | There are a few tasks set up. One is the "Air Cargo Problem" from Artificial Intelligence (Russell & Norvig). The original code for the problem is based on the [Udacity implementation](https://github.com/udacity/AIND-Planning), and the full description is in the problem repo.
9 |
10 | The Air Cargo problem can be seen as structured prediction (every prediction step changes the state of the problem). The algorithms used to solve it in the book include Graphplan and A* search of the state space, putting it in the same family of problems as the Blocks Problem (SHRDLU) solved in the original paper.
11 |
12 | ## Installation
13 | Requires PyTorch (CUDA is not required).
14 |
15 |     conda install pytorch torchvision cuda80 -c soumith
16 |
17 | ## Tensorboard
18 | Run with the `--log 10` flag to send data to TensorBoard (the number is the logging frequency). Below is a screenshot taken during training. Losses are recorded separately for each entity and type, as well as for the action, for more detailed monitoring.
19 |
20 | 
21 | 
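`action_loss` in losses.py scores every slot of the action vector (the action, then a type and an entity for each argument) with the supplied criterion and averages the results, which is why each part can be logged on its own. Below is a dependency-free sketch of that idea. It assumes the criterion is softmax cross-entropy, and the helper names are hypothetical rather than taken from the repo.

```python
import math
from typing import List, Sequence, Tuple


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Softmax cross-entropy of a single categorical output head."""
    m = max(logits)  # subtract the max before exp() for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def per_part_losses(heads: Sequence[Sequence[float]],
                    action: Sequence[int]) -> Tuple[List[float], float]:
    """Return one loss per action slot plus their mean (the value that is optimized)."""
    if len(heads) != len(action):
        raise ValueError("need one output head per action slot")
    parts = [cross_entropy(h, t) for h, t in zip(heads, action)]
    return parts, sum(parts) / len(parts)


# Three equally likely actions and two equally likely types: the losses are log 3 and log 2.
parts, mean = per_part_losses([[0.0, 0.0, 0.0], [0.0, 0.0]], [1, 0])
```

Logging `parts` individually is what produces separate curves per action, type and entity, while only `mean` is backpropagated.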
22 |
23 | To run with TensorBoard:
24 | - `pip install tensorboardX` (for TensorBoard logging)
25 | - `pip install tensorflow` (for the TensorBoard web server)
26 |
27 | ## Training Scenarios
28 | ### Planning
29 | The code implements a curriculum training schedule as in the paper. Start small with the minimum-sized problem (2 entities of each kind):
30 |
31 |     python run.py --act plan --iters 1000 --ret_graph 1 --zero_at step --n_phases 20 --opt_at step
32 |     python run.py --act plan --iters 1000 --ret_graph 1 --zero_at step --n_phases 20 --opt_at step --save opt_zero_step
33 |     python run.py --act plan --iters 1000 --ret_graph 0 --opt_at problem --save opt_problem_plan --n_phases 20
34 |
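arg.py exposes the settings for this schedule (`--n_init_start`, `--max_ents`, `--passing`, `--n_phases`), but the loop that uses them lives in run.py/training.py, which are not shown here. The sketch below is an assumption about how such a curriculum could advance. After each phase, if the pass rate reaches `--passing`, the problem grows by one entity, up to a cap of `--max_ents`. The class name and the one-entity step are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Curriculum:
    """Hypothetical level tracker; the field defaults mirror the arg.py flags."""
    n_ents: int = 2        # --n_init_start: entities of each kind at the first level
    max_ents: int = 6      # --max_ents: largest problem size
    passing: float = 0.9   # --passing: pass rate needed to move up a level

    def update(self, n_passed: int, n_trials: int) -> bool:
        """Grow the problem by one entity if this level was passed; return True on promotion."""
        if n_trials == 0 or self.n_ents >= self.max_ents:
            return False
        if n_passed / n_trials >= self.passing:
            self.n_ents += 1
            return True
        return False


c = Curriculum()
c.update(8, 10)   # 0.8 < 0.9: stay at 2 entities
c.update(9, 10)   # 0.9 >= 0.9: move up to 3 entities
```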
35 |
36 | We humans would think about the problem in terms of actions and types, so I expected the first thing the DNC would get right to be the (Action, type1, type2, type3) 'tuple', since those must be correct in order to reliably get the instance correct. This was indeed the case, as can be seen in the 'accuracies' plots during training. By the semantics of the problem, the last 'type' is always Airport, so it reaches 100% accuracy immediately. Over the next third of training the type accuracies climb into the 0.9-1.0 range. Only then does the loss for the entities themselves start dropping consistently. Even then, ent1 and ent3 were coupled, which in the logic of the problem...
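old/generators.py (`action_expr_to_vec`) encodes an action such as `Load(C1, P1, A1)` as a 7-integer target. The target is the action id followed by a (type, entity) pair for each argument. The sketch below reproduces that layout with a hypothetical hand-written entity table (the real mapping is built per problem) and extracts the (Action, type1, type2, type3) tuple discussed above. In Load, Unload and Fly, the third argument is always an airport, so the last type never changes.

```python
from typing import Dict, List, Tuple

# Action ids as in old/generators.py.
ACTIONS: Dict[str, int] = {'Fly': 0, 'Load': 1, 'Unload': 2}

# Hypothetical entity table: name -> (type id, index within its type).
# Type ids assumed here: 0 = cargo, 1 = plane, 2 = airport.
ENTS: Dict[str, Tuple[int, int]] = {
    'C0': (0, 0), 'C1': (0, 1),
    'P0': (1, 0), 'P1': (1, 1),
    'A0': (2, 0), 'A1': (2, 1),
}


def encode_action(name: str, args: List[str]) -> List[int]:
    """Layout: [action, type_1, ent_1, type_2, ent_2, type_3, ent_3]."""
    vec = [ACTIONS[name]]
    for arg in args:
        e_type, ent = ENTS[arg]
        vec += [e_type, ent]
    return vec


def type_tuple(vec: List[int]) -> Tuple[int, ...]:
    """(Action, type1, type2, type3): the coarse part of the target."""
    return (vec[0],) + tuple(vec[1::2])


load = encode_action('Load', ['C1', 'P1', 'A1'])  # [1, 0, 1, 1, 1, 2, 1]
fly = encode_action('Fly', ['P1', 'A1', 'A0'])    # [0, 1, 1, 2, 1, 2, 0]
```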
37 |
38 | To show details at each step of what was predicted vs. the best moves, use the `-d`/`--show_details` flag. You will get something like this:
39 |
40 |     trial 978, step 19514 trial accy: 6/7, 0.86, pass total 296/978, running avg 0.7463, loss 0.0774
41 |     best Load ['C1', 'P1', 'A1'], Fly ['P1', 'A1', 'A0']
42 |     chosen: Load ['C1', 'P1', 'A1'], guided True, prob 0.25, T? True ---loss 0.2553
43 |     best Fly ['P1', 'A1', 'A0'], Unload ['C1', 'P1', 'A0']
44 |     chosen: Fly ['P1', 'A1', 'A0'], guided True, prob 0.33, T? True ---loss 0.0784
45 |     best Unload ['C1', 'P1', 'A0'], Fly ['P0', 'A0', 'A1']
46 |     chosen: Unload ['C1', 'P1', 'A0'], guided True, prob 0.33, T? True ---loss 0.0830
47 |     best Fly ['P0', 'A0', 'A1'], Load ['C0', 'P0', 'A1']
48 |     chosen: Fly ['C0', 'P0', 'A1'], guided True, prob 0.25, T? False ---loss 0.3716
49 |     best Load ['C0', 'P0', 'A1'], Fly ['P0', 'A1', 'A0']
50 |     chosen: Load ['C0', 'P0', 'A1'], guided True, prob 0.25, T? True ---loss 0.1288
51 |     best Fly ['P0', 'A1', 'A0'], Unload ['C0', 'P0', 'A0']
52 |     chosen: Fly ['P0', 'A1', 'A0'], guided False, prob 0.25, T? True ---loss 1.1554
53 |     best Fly ['P0', 'A1', 'A0'], Unload ['C0', 'P0', 'A0']
54 |     chosen: Unload ['C0', 'P1', 'A0'], guided True, prob 0.33, T? False ---loss 0.9901
55 |     best Unload ['C0', 'P0', 'A0']
56 |     chosen: Unload ['C0', 'P0', 'A0'], guided False, prob 0.33, T? True ---loss 0.8087
57 |     best Fly ['P0', 'A1', 'A0'], Unload ['C0', 'P0', 'A0']
58 |     chosen: Unload ['C0', 'P0', 'A0'], guided True, prob 0.33, T? True ---loss 0.7677
59 |
60 | The best actions are those determined by the problem heuristics (not always optimal, to save time). The chosen action is what the DNC ended up choosing. 'Guided' refers to Beta from the paper. 'Prob' is the chance of choosing that action (out of all legal actions), and the loss is shown as well.
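Two mechanisms show up in that log. `naive_loss` in losses.py picks, from several acceptable best actions, the one with the lowest loss as the training target. `--beta` (default 0.8, described in arg.py as the "mixture param from paper") controls whether the executed move comes from that target or from the DNC. The sketch below assumes that `guided` means "the target action was executed" and that beta is the probability of that happening. Both function names are hypothetical.

```python
import random
from typing import Sequence, Tuple

Action = Tuple[int, ...]


def pick_target(losses: Sequence[float], targets: Sequence[Action]) -> Action:
    """Several 'best' actions may be acceptable; train toward the one with the lowest loss."""
    best = min(range(len(targets)), key=lambda i: losses[i])  # compare losses, not indices
    return targets[best]


def next_action(predicted: Action, target: Action, beta: float,
                rng: random.Random) -> Tuple[Action, bool]:
    """Return (action to execute, guided flag); with probability beta, follow the target."""
    guided = rng.random() < beta
    return (target if guided else predicted), guided


rng = random.Random(0)
target = pick_target([0.9, 0.2], [(1, 0, 1, 1, 1, 2, 1), (0, 1, 1, 2, 1, 2, 0)])
action, guided = next_action((2, 0, 0, 1, 0, 2, 0), target, beta=0.8, rng=rng)
```

With beta near 1 the DNC is mostly led along the heuristic plan. Lower values make it recover from its own mistakes, as in the unguided steps of the log above.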
61 |
62 |
63 | ### Question Answering
64 | Another task I thought would be interesting was to give the DNC a problem (initial state and goal), then make some moves, and ask where a certain cargo is (which airport is it in? is it in a plane? which plane?). This did not work too well. See the `train_qa` function in run.py.
65 |
66 | python run.py --act qa --iters 1000 --n_phases 20
67 |
68 | ## Training Misc
69 | ### Other Problems
70 | In an initial pass I tested with the sequence memorization task from the DeepMind repo. I have not tested it recently and I doubt it works (see Todo). To run it, specify it as the problem.
71 |
72 | ### Other Setups
73 | The DNC was tested against vanilla LSTMs. The LSTM appears to get stuck on the Air Cargo problem at ~40%. To run training with only an LSTM, use the `--algo lstm` flag like so:
74 |
75 |     python run.py --act plan --algo lstm --iters 1000 --n_phases 20
76 |
77 | ### Misc
78 | Training at each 'level' took 20K steps, which is way more than reported in the paper. On my crappy home CPU this meant about a day, aka forever. Since I also lost my computer and had to retrain everything, I only got through the first level of training (2 airports, 2 cargos, 2 planes) before having to submit.
79 |
80 |
81 | ## Differences from Original
82 | There was some experimentation here, so there are a bunch of flags controlling when to optimize (`--opt_at`, `--zero_at`, `--ret_graph`). In the paper they calculated the loss at the end of each problem. This did not work for me, so I ended up running the optimizer after each response.
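The two schedules differ only in when the update is applied. The repo's training loop is not shown here, so the sketch below uses a toy scalar model with hand-computed gradients instead of the DNC. It assumes `--opt_at` takes the values `step` and `problem`, as in the commands above.

```python
from typing import List, Tuple


def run_problem(w: float, steps: List[Tuple[float, float]], lr: float,
                opt_at: str) -> Tuple[float, int]:
    """Fit a scalar w to (x, y) pairs with squared error; return (new w, number of updates).

    opt_at='step':    apply a gradient update after every response.
    opt_at='problem': accumulate gradients and update once at the end of the problem.
    """
    if opt_at not in ('step', 'problem'):
        raise ValueError("opt_at must be 'step' or 'problem'")
    grad_sum, updates = 0.0, 0
    for x, y in steps:
        grad = 2.0 * (w * x - y) * x  # d/dw of (w*x - y)**2
        if opt_at == 'step':
            w -= lr * grad
            updates += 1
        else:
            grad_sum += grad
    if opt_at == 'problem':
        w -= lr * grad_sum
        updates += 1
    return w, updates


print(run_problem(0.0, [(1.0, 2.0), (1.0, 2.0)], lr=0.25, opt_at='step'))     # (1.5, 2)
print(run_problem(0.0, [(1.0, 2.0), (1.0, 2.0)], lr=0.25, opt_at='problem'))  # (2.0, 1)
```

In step mode each later response is already scored with updated weights. In problem mode every response in the problem sees the same weights, and there is a single, larger update.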
83 |
84 | ## Loading Previous Run
85 |
86 |     python run.py --act plan --iters 1000 --n_phases 20 --load the_saved_name_or_path --save the_new
87 |
88 | ## Flags
89 |
90 |
91 | ## Running on floydhub
92 | Set the `--env` flag to `floyd`. When the job runs there, the script creates all of its directories in /output. TensorBoard for PyTorch does not appear to work there, for reasons I do not understand.
93 |
94 |     floyd run --env pytorch-0.2 --tensorboard "bash setup.sh && python run.py --act dag --iters 1000 --env floyd"
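arg.py builds the run folders from `--env` and `--save`. The function below restates that logic as a pure function so the resulting paths are easy to see. It only computes the paths (arg.py also calls `os.mkdir` and writes `params.txt`), and the function name is hypothetical. `start` stands in for the integer Unix timestamp that arg.py formats with `'{:0.0f}'.format(time.time())`.

```python
from typing import Tuple


def output_dirs(env: str, save: str, start: str) -> Tuple[str, str]:
    """Mirror of the path logic in arg.py: (run folder, checkpoint folder)."""
    prefix = '/output/' if env == 'floyd' else './'  # on FloydHub everything goes under /output
    base_dir = '{}{}{}_{}/'.format(prefix, 'models/', start, save)
    return base_dir, base_dir + 'checkpts/'


print(output_dirs('floyd', 'opt_problem_plan', '1510000000'))
# ('/output/models/1510000000_opt_problem_plan/', '/output/models/1510000000_opt_problem_plan/checkpts/')
```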
95 |
96 | ## Todo
97 | Upload best models
98 | Test the sequence memorization task. It probably does not work.
99 |
100 |
101 | ~~Implement with GPU.~~
102 |
103 |
104 | ~~Faster problem generator~~
105 |
106 |
107 | ~~fix tensorboard issues~~
108 |
109 |
110 | ~~gradient clipping~~
111 |
112 |
113 | visualization of what dnc is doing internally (per paper)
114 |
115 |
116 | penalty for bad actions when not using the beta coefficient for forcing
117 |
118 |
119 | losses by prediction (fast loss)
120 |
121 |
122 | run whole lstm on input and goal state?
123 |
124 |
125 | Document args in argparse
126 |
127 |
128 | Testing on moar problems.
129 |
130 |
131 |
132 |
--------------------------------------------------------------------------------
/__pycache__/dnc.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/__pycache__/dnc.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/generators.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/__pycache__/generators.cpython-35.pyc
--------------------------------------------------------------------------------
/arg.py:
--------------------------------------------------------------------------------
1 | import argparse, os, time, json
2 |
3 | parser = argparse.ArgumentParser(description='Hyperparams')
4 | ########### STANDARD ENV SPEC ##############
5 | parser.add_argument('-a', '--act', nargs='?', type=str, default='train', help='[]')
6 | parser.add_argument('-l', '--load', nargs='?', type=str, default='', help='load model and state')
7 | parser.add_argument('--start_epoch', nargs='?', type=int, default=0, help='load model and state checkpt start')
8 | parser.add_argument('-s', '--save', nargs='?', type=str, default='', help='save if true')
9 | parser.add_argument('-i', '--iters', nargs='?', type=int, default=21, help='number of iterations')
10 | parser.add_argument('--env', type=str, default='', help='')
11 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate 1e-5 in paper')
12 | parser.add_argument('--checkpoint_every', type=int, default=5, help='save models every n epochs')
13 | parser.add_argument('--log', type=int, default=0, help='send to tensorboard frequency. 0 to disable')
14 | parser.add_argument('--notes', type=str, default='', help='any notes')
15 | parser.add_argument('-d', '--show_details', type=int, default=20, help='show predicted vs. best actions at each step (see README)')
16 |
17 | ########### Algorithm and Optimizer ##############
18 | parser.add_argument('--opt', type=str, default='adam', help='optimizer to use. options are sgd and adam')
19 | parser.add_argument('--algo', type=str, default='dnc', help='dnc or lstm. default is dnc')
20 | parser.add_argument('-c', '--cuda', type=int, default=0, help='device to run - 1 if cuda, 0 if cpu')
21 | parser.add_argument('--clip', type=float, default=0, help='gradient clipping, if zero, no clip')
22 |
23 | ########### CONTROL FLOW ##############
24 | parser.add_argument('--feed_last', type=int, default=1, help='')
25 | parser.add_argument('--opt_at', type=str, default='step', help='')
26 | parser.add_argument('--zero_at', type=str, default='step', help='')
27 | parser.add_argument('--ret_graph', type=int, default=1, help='retain graph todo change to ')
28 | parser.add_argument('--rpkg_step', type=int, default=1, help='repackage input vars at each step')
29 |
30 | ########### PROBLEM SETUP ##############
31 | parser.add_argument('-p', '--n_phases', type=int, default=15, help='number of training phases')
32 | parser.add_argument('--n_cargo', type=int, default=2, help='number of cargo at starts')
33 | parser.add_argument('--n_plane', type=int, default=2, help='number of plane at starts')
34 | parser.add_argument('--n_airport', type=int, default=2, help='number of airports at starts')
35 | parser.add_argument('-n', '--n_init_start', type=int, default=2, help='number of entities to start with')
36 | parser.add_argument('--typed', type=int, default=1, help='1=use typed entity descriptions, 0=one hot each entity')
37 |
38 | ########### PROBLEM CONTROL ##############
39 | parser.add_argument('--passing', type=float, default=0.9, help='passing percentage for a run default 0.9')
40 | parser.add_argument('--num_tests', type=int, default=2, help='')
41 | parser.add_argument('--num_repeats', type=int, default=2, help='')
42 | parser.add_argument('--max_ents', type=int, default=6, help='maximum number of entities')
43 | parser.add_argument('--beta', type=float, default=0.8, help='mixture param from paper')
44 | parser.add_argument('--penalty', type=float, default=0.0, help='mixture param from paper')
45 | args = parser.parse_args()
46 | print('\n\n')
47 |
48 | args.repakge_each_step = True if args.rpkg_step == 1 else False
49 | args.ret_graph = True if args.ret_graph == 1 else False
50 | args.cuda = True if args.cuda == 1 else False
51 | args.clip = None if args.clip == 0 else args.clip
52 | args.prefix = '/output/' if args.env == 'floyd' else './'
53 | args.penalty = None if args.penalty == 0.0 else args.penalty
54 |
55 | if not os.path.exists(args.prefix + 'models'):
56 |     os.mkdir(args.prefix + 'models')
57 |
58 | start_timer = time.time()
59 | start = '{:0.0f}'.format(start_timer)
60 |
61 | if args.save != '':
62 |     args.base_dir = '{}{}{}_{}/'.format(args.prefix, 'models/', start, args.save)
63 |     os.mkdir(args.base_dir)
64 |     os.mkdir(args.base_dir + 'checkpts/')
65 |     argparse_dict = vars(args)
66 |     with open(args.base_dir + 'params.txt', 'w') as outfile:
67 |         json.dump(argparse_dict, outfile)
68 |     print('Saving in folder {}'.format(args.base_dir))
69 |
70 | writer = None
71 | if args.log > 0:
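72 |     from tensorboardX import SummaryWriter
73 |     # `writer` is already module-level, so no `global` statement is needed (Python 3.6+ rejects one after the assignment above)
74 |     writer = SummaryWriter()
75 |     from visualize import logger
76 |     logger.log_step += args.log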
77 |
78 | # test +new no log #
79 | # python run.py -a plan -n 3 -p 2 -d 5 --iters 10 --opt_at step --zero_at step
80 |
81 | # test +load #
82 | # python run.py -a plan -n 3 -p 2 -d 5 --load --iters 10 --opt_at step --zero_at step
83 |
84 | # test +new +log
85 | # python run.py -a plan -n 3 -p 2 -d 5 --iters 10 --log 2 --opt_at step --zero_at step
86 |
87 | # test +new +cuda #
88 | # python run.py -a plan -n 3 -p 2 -d 5 --iters 10 --cuda 1 --opt_at step --zero_at step
89 |
90 | # test +new +cuda +clip #
91 | # python run.py -a plan -n 3 -p 2 -d 5 --clip 20 --iters 10 --cuda 1 --opt_at step --zero_at step
92 |
93 | # test +new +load +log #
94 | # python run.py -a plan -n 2 --n_phases 2 --show_details 5 --clip 40 --log 2 --iters 10 --cuda 0 --opt_at step --zero_at step
95 |
96 |
97 |
98 |
99 | # floyd run --env pytorch-0.2 --tensorboard "bash setup.sh && python run.py --act run --opt_at problem --ret_graph 0 --env floyd --save _nopkg --n_phases 2 --iters 10000"
100 | # python run.py --act run --opt_at problem --ret_graph 0 --save _nopkg --n_phases 2 --iters 10000
101 |
102 | # QA training
103 | # floyd run --env pytorch-0.2 --tensorboard "bash setup.sh && python run.py --act dag --iters 1000 --save _new_hidden --ret_graph 1 --opt_at step --env floyd"
104 | # floyd run --cpu --env pytorch-0.2 --tensorboard 'bash setup.sh && python run.py --act dag --iters 1000 --save _new_hidden --ret_graph 1 --opt_at step --zero_at step --env floyd'
105 |
--------------------------------------------------------------------------------
/floydrequirements.txt:
--------------------------------------------------------------------------------
1 |
2 | tensorflow
3 | tensorboardX
--------------------------------------------------------------------------------
/images/acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/images/acc.png
--------------------------------------------------------------------------------
/images/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/images/training.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from visualize import logger as sl
3 | from utils import flat, repackage, _variable
4 |
5 |
6 | def action_loss(logits, action, criterion, log=None):
7 |     """
8 |     Mean of the losses of the one-hot vectors encoding an action
9 |     :param logits: network output vector of [action, [[type_i, ent_i], for i in ents]]
10 |     :param action: target vector size [7]
11 |     :param criterion: loss function
12 |     :return: mean loss over the action parts
13 |     """
14 |     losses = []
15 |     for idx, action_part in enumerate(flat(action)):
16 |         tgt = _variable(torch.LongTensor([action_part]))
17 |         losses.append(criterion(logits[idx], tgt))
18 |     loss = torch.stack(losses, 0).mean()
19 |     if log is not None:
20 |         sl.log_loss(losses, loss)
21 |     return loss
22 |
23 |
24 | def get_top_prediction(expanded_logits, idxs=None):
25 |     max_idxs = []
26 |     idxs = range(len(expanded_logits)) if idxs is None else idxs
27 |     for idx in idxs:
28 |         _, pidx = expanded_logits[idx].data.topk(1)
29 |         max_idxs.append(pidx.squeeze()[0])
30 |     return tuple(max_idxs)
31 |
32 |
33 | def combined_ent_loss(logits, action, criterion, log=None):
34 |     """
35 |     some hand tuning of penalties for illegal actions...
36 |     trying to force learning of types.
37 |
38 |     action type => type_e...
39 |     :param logits: network output vector of one_hot distributions
40 |         [action, [type_i, ent_i], for i in ents]
41 |     :param action: target vector size [7]
42 |     :param criterion: loss function
43 |     :return:
44 |     """
45 |     losses = []
46 |     for idx, action_part in enumerate(flat(action)):
47 |         tgt = _variable(torch.Tensor([action_part]).float())
48 |         losses.append(criterion(logits[idx], tgt))
49 |     lfs = [losses[0]]  # the action loss on its own
50 |     n = 2  # then one summed loss per (type_i, ent_i) pair
51 |     for pair in (losses[i:i + n] for i in range(1, len(losses), n)):
52 |         lfs.append(torch.stack(pair, 0).sum())
53 |     loss = torch.stack(lfs, 0).mean()
54 |     if log is not None:
55 |         sl.log_loss(losses, loss)
56 |     return loss
57 |
58 |
59 | def naive_loss(logits, targets, criterion, log=None):
60 |     """
61 |     Pick the target with the lowest loss (the closest best action) and return it with its loss
62 |
63 |     :param logits:
64 |     :param targets:
65 |     :param criterion:
66 |     :return: loss
67 |     """
68 |     # copy_logits = depackage(logits)
69 |     # final_action = closest_action(copy_logits, targets)
70 |     loss_idx, _ = min(enumerate([action_loss(repackage(logits), a, criterion) for a in targets]), key=lambda pair: float(pair[1].data.sum()))
71 |     final_action = targets[loss_idx]
72 |     return final_action, action_loss(logits, final_action, criterion, log=log)
73 |
74 |
75 |
76 |
--------------------------------------------------------------------------------
/misc/78fa9003f6c0f735bc3250fc2116f6100463.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/misc/78fa9003f6c0f735bc3250fc2116f6100463.pdf
--------------------------------------------------------------------------------
/misc/dnc_sizes.ods:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/misc/dnc_sizes.ods
--------------------------------------------------------------------------------
/old/dnc_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/old/dnc_model.png
--------------------------------------------------------------------------------
/old/generators.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from torch.utils.data import Dataset
4 | import numpy as np
5 | from timeit import default_timer as timer
6 | import torch
7 | parent = os.path.dirname(os.path.realpath(__file__))
8 | sys.path.insert(0, os.path.dirname(parent))
9 | import aimacode
10 | import my_air_cargo_problems as mac
11 | from aimacode.search import *
12 | from aimacode.utils import Expr
13 |
14 | props = {'At':0, 'In':1}
15 | actions = {'Fly':0, 'Load':1, 'Unload':2}
16 | actions_1h = {0:'Fly', 1:'Load', 2:'Unload'}
17 |
18 | exprs = ['Fly', 'Load', 'Unload', 'At', 'In']
19 |
20 | phases = ['State', 'Goal', 'Plan', 'Solve']
21 |
22 | phase_to_ix = {word: i for i, word in enumerate(phases)}
23 | ix_to_phase = {i: word for i, word in enumerate(phases)}
24 | exprs_to_ix = {exxp: i for i, exxp in enumerate(exprs)}
25 | ix_to_exprs = {i: exxp for i, exxp in enumerate(exprs)}
26 |
27 | def encoding_():
28 |     pass
29 |
30 | def swap_fly(fly_action):
31 |     """
32 |
33 |     :param fly_action: Fly(P1, A1, A0)
34 |     :return: Fly(P1 _ A0)
35 |     """
36 |
37 |     pass
38 |
39 | class Encoded_Expr():
40 |     def __init__(self, op, args):
41 |         self.op = str(op)
42 |         self.args = args
43 |         self.one_hot = []
44 |
45 |     def vec_to_expr(self):
46 |         pass
47 |
48 |
49 |
50 |
51 | class EncodedAirCargoProblem(mac.AirCargoProblem):
52 |     def __init__(self, problem, init_vec, goals_vec, mp_merged, one_hot_size):
53 |         self.problem = problem
54 |         self.succs = self.goal_tests = self.states = 0
55 |         self.found = None
56 |         self.init_state = init_vec
57 |
58 |         self.goal_state = goals_vec
59 |         # dictionary of mappings of ents to one_hot
60 |         self.problem_ents = mp_merged
61 |         # dictionary of mappings of one_hot to ents
62 |         self.ent_to_vec = self.flip_encoding(mp_merged)
63 |         print(self.problem_ents)
64 |         print(self.ent_to_vec)
65 |         self.one_hot_size = one_hot_size
66 |         self.solution_node = None
67 |         self.entity_o_h = torch.eye(one_hot_size).long()
68 |         self.action_o_h = torch.eye(3).long()
69 |         self.types_o_h = torch.eye(3).long()
70 |         self.phases_o_h = torch.LongTensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]])
71 |
72 |     def flip_encoding(self, ents):
73 |         encoding = {}
74 |         for key, value in ents.items():
75 |             encoding[str(list(value))] = key
76 |         return encoding
77 |
78 |     def reverse_lookup(self, one_hot):
79 |         return next(key for key, value in self.problem_ents.items() if value == one_hot)
80 |
81 |     def action_expr_to_vec(self, action_expr):
82 |         """
83 |
84 |         :param action_expr: Action Expr object Fly(P0, A1, A2)
85 |         :return: action vec [0,
86 |         """
87 |         action_vec = [actions[action_expr.name]]
88 |         for arg in action_expr.args:
89 |             e_type, ent = self.problem_ents[str(arg)]
90 |             action_vec.append(e_type)
91 |             action_vec.append(ent)
92 |         print("action_vec", action_vec)
93 |         return np.asarray(action_vec, dtype=int)
94 |
95 |     def get_best_action_vecs(self, solution_node):
96 |         """
97 |
98 |         :param solution_node: Graph Node representing solution
99 |         :return: vectors for each expr in solution
100 |         """
101 |         action_vecs = []
102 |         for action_expr in solution_node.solution()[0:len(self.problem.planes)]:
103 |             action_vec = self.action_expr_to_vec(action_expr)
104 |             action_vecs.append(torch.from_numpy(action_vec))
105 |         return action_vecs
106 |
107 |     def get_all_actions(self, state):
108 |         zz = [torch.from_numpy(self.action_expr_to_vec(a)) for a in self.problem.actions(state)]
109 |         return zz
110 |
111 |     def encode_solution(self, solution):
112 |         zz = [torch.from_numpy(self.action_expr_to_vec(a)) for a in solution]
113 |         return zz
114 |
115 |     def decode_action(self, coded_action):
116 |         sym = actions_1h[coded_action[0]]
117 |         ent1 = str(list(coded_action[1:3]))
118 |         ent2 = str(list(coded_action[3:5]))
119 |         ent3 = str(list(coded_action[5:7]))
120 |         ex1 = self.ent_to_vec[ent1]
121 |         ex2 = self.ent_to_vec[ent2]
122 |         ex3 = self.ent_to_vec[ent3]
123 |         return sym, [ex1, ex2, ex3]
124 |
125 |     def send_action(self, state, coded_action):
126 |         sym, args = self.decode_action(coded_action)
127 |         actions_ = self.actions(state)
128 |         print(actions_)
129 |         final_act = []
130 |         for a in actions_:
131 |             if a.name == sym and all((str(ar) == at) for ar, at in zip(a.args, args)):
132 |                 final_act = a
133 |                 break
134 |         assert final_act != []
135 |         print(final_act)
136 |         result_state = self.problem.result(state, final_act)
137 |         return result_state
138 |
139 |     def actions(self, state):
140 |         self.succs += 1
141 |         return self.problem.actions(state)
142 |
143 |     def result(self, state, action):
144 |         self.states += 1
145 |         return self.problem.result(state, action)
146 |
147 |     def goal_test(self, state):
148 |         self.goal_tests += 1
149 |         result = self.problem.goal_test(state)
150 |         if result:
151 |             self.found = state
152 |         return result
153 |
154 |     def run_search(self, search_fn, parameter=None):
155 |         if parameter is not None:
156 |             prm = getattr(self.problem, parameter)
157 |             node = search_fn(self.problem, prm)
158 |         else:
159 |             node = search_fn(self.problem)
160 |         return node
161 |
162 |     def path_cost(self, c, state1, action, state2):
163 |         return self.problem.path_cost(c, state1, action, state2)
164 |
165 |     def value(self, state):
166 |         return self.problem.value(state)
167 |
168 |     def __getattr__(self, attr):
169 |         return getattr(self.problem, attr)
170 |
171 |
172 | class AirCargoData():
173 |     """
174 |     Flags for
175 |     """
176 |     def __init__(self,
177 |                  num_plane=10, num_cargo=6, batch_size=6,
178 |                  num_airport=1000, one_hot_size=10, mode='loose',
179 |                  search_function=astar_search):
180 |         self.num_plane = num_plane
181 |         self.num_cargo = num_cargo
182 |         self.num_airport = num_airport
183 |         self.batch_size = batch_size
184 |         self._actions_mode = mode
185 |         self.one_hot_size = one_hot_size
186 |         self.search_fn = search_function
187 |         self.STATE = ''
188 |         self.search_param = 'h_ignore_preconditions'
189 |         self.problem = None
190 |         self.current_problem = None
191 |         self.current_index = 0
192 |         #self.make_new_problem()
193 |
194 |     def print_solution(self, node):
195 |         for action in node.solution():
196 |             print("{}{}".format(action.name, action.args))
197 |
198 |     def phase_vec(self, tensor_, add_channel):
199 |         chans = torch.stack([torch.LongTensor(add_channel) for _ in range(tensor_.size(0))], dim=0)
200 |         return torch.cat([chans, tensor_.long()], dim=-1)
201 |
202 |     def get_actions(self, mode='strict'):
203 |         self.problem.problem.initial = self.STATE
204 |         if mode == 'all':
205 |             return self.problem.get_all_actions(self.STATE)
206 |         else:
207 |             solution = self.problem.run_search(self.search_fn, self.search_param)
208 |             return self.problem.get_best_action_vecs(solution)
209 |
210 |     def encode_action(self, action_obj):
211 |         return torch.from_numpy(self.problem.encode_action(action_obj)).long()
212 |
213 |     def send_action(self, coded_action):
214 |         self.problem.problem.initial = self.STATE
215 |         self.STATE = self.problem.send_action(self.STATE, coded_action)
216 |         return True
217 |
218 |     def vec_to_one_hot(self, coded_ent):
219 |         ents = []
220 |         for idx in range(1, 7, 2):
221 |             e_type = coded_ent[idx]
222 |             ent = coded_ent[idx + 1]
223 |             if e_type == 0 and ent == 0:
224 |                 ents.append(torch.zeros(3 + self.one_hot_size).long())
225 |             else:
226 |                 ents.append(torch.cat([self.problem.types_o_h[e_type], self.problem.entity_o_h[ent]], 0))
227 |         return torch.cat(ents, 0)
228 |
229 |     def expand_state_vec(self, coded_state):
230 |         """Input target vec representing cross entropy loss target [1 0 2 0 0 0 0]
231 |         Returns a one hot version of it as training input [01 00, 100, 000, 000, 000]"""
232 |         ents = self.vec_to_one_hot(coded_state)
233 |         phase = self.problem.phases_o_h[coded_state[0]]
234 |         return torch.cat([phase, ents], 0)
235 |
236 |     def expand_action_vec(self, coded_action):
237 |         """Input target vec representing action
238 |         [1 0 2 0 0 0 0]
239 |         Returns a one hot version of it as training input
240 |         [01 00, 100, 000, 000, 000]
241 |         """
242 |         ents = self.vec_to_one_hot(coded_action)
243 |         action = self.problem.action_o_h[coded_action[0]]
244 |         return torch.cat([action, ents], 0).unsqueeze(0).float()
245 |
246 |     def make_new_problem(self):
247 |         acp, i, g, m = mac.arbitrary_ACP(self.num_airport, self.num_plane,
248 |                                          self.num_cargo, one_hot_size=self.one_hot_size)
249 |         problem = EncodedAirCargoProblem(acp, i, g, m, self.one_hot_size)
250 |
251 |         # print(problem)
252 |         # run the solution to determine how long to give dnc
253 |         solution_node = problem.run_search(self.search_fn, self.search_param)
254 |
255 |         word_len = len(problem.init_state[0])
256 |         len_init_phase = len(problem.init_state)
257 |         len_goal_phase = len(problem.goal_state)
258 |         len_plan_phase = 10
259 |         len_resp_phase = len(solution_node.solution()) + 6
260 |
261 |         # determine the number of steps the input runs for
262 |         mask_zero = torch.zeros(len_init_phase + len_goal_phase + len_plan_phase)
263 |         mask_ones = torch.ones(len_resp_phase)
264 |         masks = torch.cat([mask_zero, mask_ones], 0)
265 |
266 |         init_phs_data = self.phase_vec(torch.from_numpy(problem.init_state), [0])
267 |         goal_phs_data = self.phase_vec(torch.from_numpy(problem.goal_state), [1])
268 |         # during the planning and response phases, there are no inputs.
269 |         plan_phs_data = self.phase_vec(torch.zeros(len_plan_phase, word_len), [2])
270 |         resp_phs_data = self.phase_vec(torch.zeros(len_resp_phase, word_len), [3])
271 |
272 |         inputs = torch.cat([init_phs_data, goal_phs_data, plan_phs_data, resp_phs_data], 0)
273 |         self.current_problem = [inputs, masks]
274 |         self.problem = problem
275 |         self.STATE = problem.initial
276 |         self.current_index = 0
277 |         return problem.problem, solution_node, masks
278 |
279 |     def len__(self):
280 |         if self.current_index >= self.current_problem[0].size(0):
281 |             self.make_new_problem()
282 |         return len(self.current_problem[1])
283 |
284 |     def getitem(self, batch=1):
285 |         """Returns a problem, [initial-state, goals]
286 |         and a runnable solution object [problem, solution_node]
287 |
288 |         Otherwise take the target one_hot class mask in form of
289 |         [ent1-type, ent1 ....entN, channel]
290 |
291 |         """
292 |         if self.current_index >= self.current_problem[0].size(0):
293 |             self.make_new_problem()
294 |
295 |         masks = self.current_problem[1][self.current_index:self.current_index+batch]
296 |         inputs = self.current_problem[0][self.current_index:self.current_index + batch]
297 |         inputs = torch.stack([self.expand_state_vec(i).float() for i in inputs], 0)
298 |         self.current_index += batch
299 |
300 |         return inputs, masks
301 |
302 |
303 | class RandomData(Dataset):
304 | def __init__(self,
305 | num_seq=10,
306 | seq_len=6,
307 | iters=1000,
308 | seq_width=4):
309 | self.seq_width = seq_width
310 | self.num_seq = num_seq
311 | self.seq_len = seq_len
312 | self.iters = iters
313 |
314 | def __getitem__(self, index):
315 | con = np.random.randint(0, self.seq_width, size=self.seq_len)
316 | seq = np.zeros((self.seq_len, self.seq_width))
317 | seq[np.arange(self.seq_len), con] = 1
318 | end = torch.from_numpy(np.asarray([[-1] * self.seq_width])).float()
319 | zer = np.zeros((self.seq_len, self.seq_width))
320 | return seq, zer
321 |
322 | def __len__(self):
323 | return self.iters
324 |
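`RandomData.__getitem__` draws one random class per time step and writes a 1 into that column (`seq[np.arange(self.seq_len), con] = 1`). The sketch below does the same in plain Python. `random_one_hot_sequence` and its `seed` argument are illustrative additions, not part of the class.

```python
import random


def random_one_hot_sequence(seq_len, seq_width, seed=None):
    """Return a seq_len x seq_width list of one-hot rows.

    Equivalent in spirit to:
        con = np.random.randint(0, seq_width, size=seq_len)
        seq[np.arange(seq_len), con] = 1
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(seq_len):
        hot = rng.randrange(seq_width)  # class index for this time step
        rows.append([1 if j == hot else 0 for j in range(seq_width)])
    return rows


seq = random_one_hot_sequence(6, 4, seed=0)
assert all(sum(row) == 1 for row in seq)  # exactly one active channel per step
```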
325 | class GraphData(Dataset):
326 | def __init__(self,
327 | num_seq=10,
328 | seq_len=6,
329 | iters=1000,
330 | domain=None,
331 | actions=None,
332 | start_state=None,
333 | seq_width=4):
334 | """
335 | Each vector encoded a triple consisting of a source label,
336 | an edge label and a destination label.
337 | All labels were represented as numbers between 0 and 999,
338 | with each digit represented as a 10-way one-hot encoding.
339 | We reserved a special ‘blank’ label, represented by the all-zero vector
340 | for the three digits, to indicate an unspecified label.
341 | Each label required 30 input elements, and each triple required 90.
342 | The sequences were divided into multiple phases:
343 | 1) first a graph description phase,
344 | then a series of 2) query (Q) and 3) answer (A) phases;
345 | in some cases the Q and A were separated by an additional planning phase
346 | with no input, during which the network was given time to compute the answer.
347 | During the graph description phase, the triples defining the input graph were
348 | presented in random order.
349 | Target vectors were present during only the answer phases.
350 |
351 | Params from Paper:
352 | GRAPH
353 | input vectors were size 92
354 | 90 info
355 | binary chan for phase transition
356 | binary chan for when prediction is needed
357 | target vectors 90
358 | BABI
359 | input vector of 159 one-hot-vector of words (156 unique words + 3 tokens)
360 |
361 | Propositions:
362 | :: one-hot 0-9
363 | :: [At ( C1 , SFO )]
364 | :: [{0..1} ( {0..1} , {0..1} )] + [True/False, end-exp, pred-required?]
365 | Load({}, {}, {})
366 | special tokens [ '(', ')', ',', ]
367 |
368 | Actions vs Propositions
369 | Action => [precond_pos, precond_neg, effects_pos, effects_neg]
370 | Eat(Cake) => [[Have(Cake)], [], [Eaten(Cake)], [Not Have(Cake)]]
371 | +---------------------------------------------------------------------------+
372 | | 1 9 3 9 8 4 9 2 9 3 5 9 2 9 3 7 4 9 2 9 |
373 | | 1 |
374 | +---------------------------------------------------------------------------+
375 |
376 | input: T? Op ( Pred )
377 | +----------------------------+
378 | start-seq | 1|
379 | Eat(Cake) | 00 0001 1001 0011 1001 0 0| Action Name | h * h * args
380 | Have(Cake) | 11 0100 1001 0011 1001 1 0| } Pre-Conditions
381 | . | 0 0|
382 | . | 0 0|
383 | Eaten(Cake) | 11 0101 1001 0011 1001 1 0| } Post-Conditions
384 | ¬ Have(Cake) | 01 0100 1001 0011 1001 0 1|
385 | +----------------------------+
386 |
387 | Final input is concat of all statements
388 |
389 | Goal -> At(C1, Place )
390 |
391 | Paper (first block, adjacency relation, second block)
392 | (100000, 1000, 010000)
393 | “block 1 above block 2”
394 |
395 | let the goals be 1 of 26 possible letters designated by one-hot encodings;
396 | that is, A = (1, 0, ..., 0), Z = (0, 0, ..., 1) and so on
397 |
398 | -The board is represented as a set of place-coded representations, one for each square.
399 | Therefore, (000000, 100000, ...) designates that the bottom,
400 | left-hand square is empty, block 1 is in the bottom centre square, and so on
401 |
402 | The network also sees a binary flag that represents a ‘go cue’.
403 | While the go cue is active, a goal is selected from the list of goals that have
404 | been shown to the network, its label is retransmitted to the network for one
405 | time-step, and the network can begin to move the blocks on the board.
406 |
407 | All told, the policy observes at each time-step a vector with features
408 |
409 | Constraints 16
410 | (goal name, first block, adjacency relation, second block, go cue, board state).
411 | [26 ... (6 4 6 )x6 1 63] -> 186 ~state
412 |
413 | There are 7 possible actions, so the output is mapped to a size-7 one-hot vector.
414 |
415 | 10 goals -> 250,
416 |
417 | Up to 10 goals with 6 constraints each can be sent to the network before action begins.
418 |
419 | Once the go cue arrives, it is possible for the policy network to move a block
420 | from one column to another or to pass at each turn.
421 | We parameterize these actions using another one-hot encoding so that,
422 | for a 3 × 3 board, a move can be made from any column to any other;
423 | with the pass move, there are therefore 7 moves.
424 |
425 |
426 | 8 Airports , 4 Cargos, 4 Airplanes
427 | 0001 , 0000 , 0100
428 |
429 | Network - [2 x 250]
430 |
431 |
432 |
433 | Input at t: prev
434 |
435 |
436 | Types of tasks ->
437 | 1) given True Statements, we may want to generate negative statements
438 | 2)
439 |
440 | """
441 | self.seq_width = seq_width
442 | self.num_seq = num_seq
443 | self.seq_len = seq_len
444 | self.iters = iters
445 |
446 | def __getitem__(self, index):
447 | #con = np.random.randint(0, self.seq_width, size=self.seq_len)
448 |
449 |
450 |
451 | return None #seq, zer
452 |
453 | def __len__(self):
454 | return self.iters
455 |
456 |
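The paper text quoted in the `GraphData` docstring encodes each label from 0 to 999 as three digits. Each digit is a 10-way one-hot, so a label takes 30 elements, and the all-zero vector is reserved as the "blank" label. A (source, edge, destination) triple therefore needs 90 elements. The sketch below shows that encoding, assuming `None` marks the blank label. `encode_label` and `encode_triple` are hypothetical names; `GraphData` itself does not implement this yet.

```python
def encode_label(label):
    """Encode a label in 0..999 as 3 concatenated 10-way one-hots (30 values).

    None stands for the reserved blank label, encoded as all zeros.
    """
    if label is None:
        return [0] * 30
    if not 0 <= label <= 999:
        raise ValueError("label must be in 0..999")
    vec = []
    for digit in "{:03d}".format(label):  # e.g. 42 -> "042"
        one_hot = [0] * 10
        one_hot[int(digit)] = 1
        vec.extend(one_hot)
    return vec


def encode_triple(source, edge, destination):
    """Concatenate source, edge and destination labels into 90 values."""
    return encode_label(source) + encode_label(edge) + encode_label(destination)


v = encode_triple(42, 7, None)
assert len(v) == 90
```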
--------------------------------------------------------------------------------
/old/notes.txt:
--------------------------------------------------------------------------------
1 |
2 | >>> rnn = nn.LSTM(10, 20, 2)
3 | input_size – The number of expected features in the input x
4 | hidden_size – The number of features in the hidden state h
5 | num_layers – Number of recurrent layers.
6 |
7 | input (seq_len, batch, input_size): tensor containing the features of the input sequence. The input can also be a packed variable length sequence. See torch.nn.utils.rnn.pack_padded_sequence() for details.
8 | h_0 (num_layers * num_directions, batch, hidden_size): tensor containing the initial hidden state for each element in the batch.
9 | c_0 (num_layers * num_directions, batch, hidden_size): tensor containing the initial cell state for each element in the batch.
10 |
11 |
12 | >>> input = Variable(torch.randn(5, 3, 10))
13 | >>> h0 = Variable(torch.randn(2, 3, 20))
14 | >>> c0 = Variable(torch.randn(2, 3, 20))
15 | >>> output, hn = rnn(input, (h0, c0))
16 |
17 | rnn = nn.LSTM(word_size, hidden_size, num_layers)
18 | h0 = Variable(torch.randn(num_layers, z, hidden_size))
19 | c0 = Variable(torch.randn(q, z, y))
20 |
21 | input = Variable(torch.randn(5, z, x))
22 |
23 | output, hn = rnn(input, (h0, c0))
24 |
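25 | nn  (x, y, q)  -> nn.LSTM(input_size=x, hidden_size=y, num_layers=q)
26 | in  (a, z, x)  -> input: (seq_len=a, batch=z, input_size=x)
27 | h0  (q, z, y)  -> (num_layers, batch, hidden_size)
28 | c0  (q, z, y)  -> same shape as h0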
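The shape bookkeeping above can be checked with a small helper. This sketch only computes the shapes documented for `torch.nn.LSTM` with the default `batch_first=False` and no `proj_size`. `lstm_shapes` is a hypothetical name.

```python
def lstm_shapes(seq_len, batch, input_size, hidden_size,
                num_layers=1, bidirectional=False):
    """Expected tensor shapes for nn.LSTM(input_size, hidden_size, num_layers).

    Assumes batch_first=False, so tensors are (seq_len, batch, features).
    """
    num_directions = 2 if bidirectional else 1
    state = (num_layers * num_directions, batch, hidden_size)
    return {
        "input": (seq_len, batch, input_size),
        "h0": state,
        "c0": state,
        "output": (seq_len, batch, num_directions * hidden_size),
        "hn": state,
        "cn": state,
    }


# The example above: nn.LSTM(10, 20, 2) with input (5, 3, 10).
shapes = lstm_shapes(5, 3, 10, 20, num_layers=2)
print(shapes["output"])  # (5, 3, 20)
print(shapes["h0"])      # (2, 3, 20)
```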
--------------------------------------------------------------------------------
/old/train.py:
--------------------------------------------------------------------------------
1 |
2 | def train_qa(args, num_problems, data, Dnc, optimizer, dnc_state, save_=False):
3 | """
4 | I am jacks liver. This is a sanity test
5 |
6 | 0 - describe state.
7 | 1 - describe goal.
8 | 2 - do actions.
9 | 3 - ask some questions
10 | :param args:
11 | :return:
12 | """
13 | sl.log_step += args.log
14 |
15 | print(Dnc)
16 | criterion = nn.CrossEntropyLoss()
17 |
18 | cum_correct = []
19 | cum_total_move = []
20 | num_tests = 2
21 | num_repeats = 1
22 | for n in range(num_problems):
23 | masks = data.make_new_problem()
24 | num_correct = 0
25 | total_moves = 0
26 | # prev_action = None
27 |
28 | if args.repakge_each_step is False:
29 | dnc_state = dnc.repackage(dnc_state)
30 | optimizer.zero_grad()
31 | # repackage the state to zero grad at start of each problem
32 |
33 | for idx, mask in enumerate(masks):
34 | sl.global_step += 1
35 | inputs, mask_ = data.getitem()
36 |
37 | if mask == 0 or mask == 1:
38 | for nz in range(num_repeats):
39 | inputs1 = Variable(torch.cat([mask_, inputs], 1))
40 | logits, dnc_state = Dnc(inputs1, dnc_state)
41 |
42 | # sl.log_state(dnc_state)
43 | else:
44 | targets_star = data.get_actions(mode='one')
45 | if targets_star == []:
46 | break
47 | final_move = targets_star[0]
48 |
49 | if args.repakge_each_step is True:
50 | dnc_state = dnc.repackage(dnc_state)
51 | optimizer.zero_grad()
52 |
53 | for nq in range(num_repeats):
54 | inputs2 = Variable(torch.cat([data.phase_oh[2].unsqueeze(0), data.vec_to_ix(final_move)], 1))
55 | logits, dnc_state = Dnc(inputs2, dnc_state)
56 |
57 | data.send_action(final_move)
58 | # sl.log_state(dnc_state)
59 |
60 | for _ in range(num_tests):
61 | total_moves += 1
62 | state_expr = random.choice(data.pull_state())
63 | state_vec = data.expr_to_vec(state_expr)
64 | mask_idx = 2 # random.randint(1, 2)
65 | mask_chunk = state_vec[mask_idx]
66 |
67 | zeros = 0 if type(mask_chunk) == int else tuple([0] * len(mask_chunk))
68 | masked_state_vec = state_vec.copy()
69 | # masked_state_vec[mask_idx] = zeros
70 | masked_state_vec[2] = zeros
71 |
72 | inputs3 = Variable(torch.cat([data.phase_oh[3].unsqueeze(0), data.vec_to_ix(masked_state_vec)], 1))
73 | logits, dnc_state = Dnc(inputs3, dnc_state)
74 |
75 | expanded_logits = data.ix_input_to_ixs(logits)
76 | idx1, idx2 = mask_idx * 2 - 1, mask_idx * 2
77 |
78 | target_chunk1 = Variable(torch.LongTensor([mask_chunk[0]]))
79 | target_chunk2 = Variable(torch.LongTensor([mask_chunk[1]]))
80 |
81 | loss1 = criterion(expanded_logits[idx1], target_chunk1)
82 | loss2 = criterion(expanded_logits[idx2], target_chunk2)
83 | loss = loss1 + loss2
84 | sl.log_state(dnc_state)
85 |
86 | loss.backward(retain_graph=args.ret_graph)
87 | optimizer.step()
88 | sl.log_state(dnc_state)
89 |
90 | resp1, pidx1 = expanded_logits[idx1].data.topk(1)
91 | resp2, pidx2 = expanded_logits[idx2].data.topk(1)
92 | pred_tuple = pidx1.squeeze()[0], pidx2.squeeze()[0]
93 | # pred_expr = data.lookup_ix_to_expr(pred_tuple)
94 | correct_ = mask_chunk == pred_tuple
95 | num_correct += 1 if mask_chunk == pred_tuple else 0
96 | # print("step {}.{}, loss: {:0.2f}, state {} actual {} pred: {}, {}".format(
97 | # n, idx, loss.data[0], state_expr, mask_chunk, pred_tuple, correct_))
98 | sl.log_loss_qa(loss1, loss2, loss)
99 |
100 | cum_total_move.append(total_moves)
101 | cum_correct.append(num_correct)
102 | trial_acc = num_correct / total_moves if total_moves else 0.0
103 | sl.writer.add_scalar('recall.pct_correct', trial_acc, sl.global_step)
104 | sl.log_state(dnc_state)
105 | sl.log_model(Dnc)
106 |
107 | print("trial: {} accy {:0.4f}, cum_score {:0.4f}".format(n, trial_acc, sum(cum_correct[-100:]) / sum(cum_total_move[-100:])))
108 | if save_ is not False:
109 | save(Dnc, dnc_state, start, args.save, sl.global_step)
110 |
111 | score = sum(cum_correct[-100:]) / sum(cum_total_move[-100:])
112 | return Dnc, optimizer, dnc_state, score
113 |
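`train_qa` reports `sum(cum_correct[-100:]) / sum(cum_total_move[-100:])`, which is a pooled accuracy over the last 100 problems. The counts are summed before dividing, so problems with more questions weigh more than they would in a mean of per-problem accuracies. The sketch below adds a guard for the empty case. `windowed_accuracy` is a hypothetical name.

```python
def windowed_accuracy(correct_counts, total_counts, window=100):
    """Pooled accuracy over the last `window` problems.

    Same formula as sum(cum_correct[-100:]) / sum(cum_total_move[-100:]):
    counts are summed first, then divided once.
    """
    total = sum(total_counts[-window:])
    if total == 0:
        return 0.0  # no questions asked yet; avoid ZeroDivisionError
    return sum(correct_counts[-window:]) / total


# Two problems: 1 of 2 and 2 of 2 questions answered correctly.
print(windowed_accuracy([1, 2], [2, 2]))  # 0.75
```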
114 | def setupDNC(args):
115 | if args.algo == 'lstm':
116 | return setupLSTM(args)
117 | data = gen.AirCargoData(**generate_data_spec(args))
118 | dnc_args['output_size'] = data.nn_in_size # output has no phase component
119 | dnc_args['word_len'] = data.nn_out_size
120 | print(dnc_args)
121 | if args.load == '':
122 | Dnc = dnc.DNC(**dnc_args)
123 | o, m, r, w, l, lw, u, (ho, hc) = Dnc.init_state()
124 | else:
125 | model_path = './models/dnc_model_' + args.load
126 | state_path = './models/dnc_state_' + args.load
127 | print('loading', model_path, state_path)
128 | Dnc = dnc.DNC(**dnc_args)
129 | Dnc.load_state_dict(torch.load(model_path))
130 | o, m, r, w, l, lw, u, (ho, hc) = torch.load(state_path)
131 | print(dnc_checksum([o, m, r, w, l, lw, u]))
132 |
133 | lr = 5e-5 if args.lr is None else args.lr
134 | if args.opt == 'adam':
135 | optimizer = optim.Adam([{'params': Dnc.parameters()}, {'params': o}, {'params': m}, {'params': r}, {'params': w},
136 | {'params': l}, {'params': lw}, {'params': u}, {'params': ho}, {'params': hc}],
137 | lr=lr)
138 | else:
139 | optimizer = optim.SGD([{'params': Dnc.parameters()}, {'params': o}, {'params': m}, {'params': r}, {'params': w},
140 | {'params': l}, {'params': lw}, {'params': u}, {'params': ho}, {'params': hc}],
141 | lr=lr)
142 | dnc_state = (o, m, r, w, l, lw, u, (ho, hc))
143 | return data, Dnc, optimizer, dnc_state
144 |
145 |
146 | """
147 | #target_chunk1 = Variable(torch.LongTensor([mask_chunk[0]]))
148 | #target_chunk2 = Variable(torch.LongTensor([mask_chunk[1]]))
149 | #loss1 = criterion(expanded_logits[idx1], target_chunk1)
150 | #loss2 = criterion(expanded_logits[idx2], target_chunk2)
151 | #lstep = loss1 + loss2
152 | # action_loss(logits, expanded_logits, criterion)
153 | """
--------------------------------------------------------------------------------
/problem/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/problem/__init__.py
--------------------------------------------------------------------------------
/problem/copy_squence.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/problem/copy_squence.py
--------------------------------------------------------------------------------
/problem/generators_v2.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import torch
4 | from utils import flat
5 | from . import my_air_cargo_problems as mac
6 | from .search import *
7 | from .planning import Action
8 | import random
9 |
10 |
11 | EXPRESSIONS = ['Fly', 'Load', 'Unload', 'At', 'In']
12 | PHASES = ['State', 'Goal', 'Plan', 'Solve']
13 |
14 | phase_to_ix = {word: i for i, word in enumerate(PHASES)}
15 | ix_to_phase = {i: word for i, word in enumerate(PHASES)}
16 | exprs_to_ix = {exxp: i for i, exxp in enumerate(EXPRESSIONS)}
17 | ix_to_exprs = {i: exxp for i, exxp in enumerate(EXPRESSIONS)}
18 |
19 |
20 | default_encoding = {'expr', 'action', 'ent1-type', 'ent1'}
21 | ix_size = [9, 9, 9]
22 | ix2_size = [3, 6, 3, 6, 3, 6]
23 | split1 = [[0, 1], [0, 2], [0, 3]]
24 | split2 = [[0, 1], [1, 2], [0, 3], [1, 4], [0, 5], [1, 6]]
25 |
26 | permutations = {'At': {'insert': [0, 0], 'tst': [0, 'P'], 'idx': [2, 2], 'permute': [4, 5, 0, 1, 2, 3]},
27 | 'Fly': {'insert': [0, 0], 'idx':0, 'permute': [0, 1, 2, 3, 4, 5,]},
28 | 'Load': {'insert': [0, 0], 'idx':0, 'permute': [0, 1, 2, 3, 4, 5,]},
29 | 'In': {'insert': [0, 0], 'idx':0, 'permute': [0, 1, 2, 3, 4, 5,]},
30 | 'Unload': {'insert': [0, 0], 'idx':0, 'permute': [0, 1, 2, 3, 4, 5,]},
31 | }
32 |
33 | perm_one = {'At': {'insert': [0], 'tst': [0, 'P'], 'idx': [2, 2], 'permute': [0, 1, 2]},
34 | 'Fly': {'insert': [0], 'idx':0, 'permute': [0, 1, 2]},
35 | 'Load': {'insert': [0], 'idx':0, 'permute': [0, 1, 2]},
36 | 'In': {'insert': [0], 'idx':0, 'permute': [0, 1, 2]},
37 | 'Unload': {'insert': [0], 'idx':0, 'permute': [0, 1, 2]},
38 | }
39 |
40 |
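`vec_to_ix` below applies a mapping's `permute` list as `merged = np.asarray(merged)[np.argsort(permute)]`. Indexing by the argsort of a permutation applies its inverse, so the element at position `i` ends up at position `permute[i]`. The following is a pure-Python sketch of that step. `argsort` and `apply_permute` are hypothetical helpers, with `argsort` standing in for `np.argsort` on distinct values.

```python
def argsort(values):
    """Indices that would sort `values` (like np.argsort for distinct values)."""
    return sorted(range(len(values)), key=lambda i: values[i])


def apply_permute(merged, permute):
    """Reorder `merged` the way vec_to_ix does: merged[argsort(permute)].

    Because argsort inverts a permutation, merged[i] lands at index permute[i].
    """
    return [merged[i] for i in argsort(permute)]


# 'At' mapping above: permute = [4, 5, 0, 1, 2, 3]
print(apply_permute(["a", "b", "c", "d", "e", "f"], [4, 5, 0, 1, 2, 3]))
# ['c', 'd', 'e', 'f', 'a', 'b']
```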
41 | class BitData():
42 | def __init__(self,
43 | num_bits=6,
44 | batch_size=1,
45 | min_length=1,
46 | max_length=1,
47 | min_repeats=1,
48 | max_repeats=2,
49 | norm_max=10,
50 | log_prob_in_bits=False,
51 | time_average_cost=False,
52 | name='repeat_copy',):
53 | self._batch_size = batch_size
54 | self._num_bits = num_bits
55 | self._min_length = min_length
56 | self._max_length = max_length
57 | self._min_repeats = min_repeats
58 | self._max_repeats = max_repeats
59 | self._norm_max = norm_max
60 | self._log_prob_in_bits = log_prob_in_bits
61 | self._time_average_cost = time_average_cost
62 |
63 | def _normilize(self, val):
64 | return val / self._norm_max
65 |
66 | def _unnormalize(self, val):
67 | return val * self._norm_max
68 |
69 | @property
70 | def time_average_cost(self):
71 | return self._time_average_cost
72 |
73 | @property
74 | def log_prob_in_bits(self):
75 | return self._log_prob_in_bits
76 |
77 | @property
78 | def num_bits(self):
79 | """The dimensionality of each random binary vector in a pattern."""
80 | return self._num_bits
81 |
82 | @property
83 | def target_size(self):
84 | """The dimensionality of the target tensor."""
85 | return self._num_bits + 1
86 |
87 | @property
88 | def batch_size(self):
89 | return self._batch_size
90 |
91 | def get_item(self):
92 | min_length, max_length = self._min_length, self._max_length
93 | min_reps, max_reps = self._min_repeats, self._max_repeats
94 | num_bits = self.num_bits
95 | batch_size = self.batch_size
96 |
97 | # We reserve one dimension for the num-repeats and one for the start-marker.
98 | full_obs_size = num_bits + 2
99 | # We reserve one target dimension for the end-marker.
100 | full_targ_size = num_bits + 1
101 | start_end_flag_idx = full_obs_size - 2
102 | num_repeats_channel_idx = full_obs_size - 1
103 |
104 | # Samples each batch index's sequence length and the number of repeats
105 | sub_seq_length_batch = torch.from_numpy(
106 | np.random.randint(min_length, max_length + 1, batch_size)).long()
107 | num_repeats_batch = torch.from_numpy(
108 | np.random.randint(min_reps, max_reps + 1, batch_size)).long()
109 |
110 | # Pads all the batches to have the same total sequence length.
111 | total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3
112 | max_length_batch = int(total_length_batch.max())
113 | residual_length_batch = max_length_batch - total_length_batch
114 |
115 | obs_batch_shape = [max_length_batch, batch_size, full_obs_size]
116 | targ_batch_shape = [max_length_batch, batch_size, full_targ_size]
117 | mask_batch_trans_shape = [batch_size, max_length_batch]
118 |
119 | obs_tensors = []
120 | targ_tensors = []
121 | mask_tensors = []
122 |
123 | # Generates patterns for each batch element independently.
124 | for batch_index in range(batch_size):
125 | sub_seq_len = int(sub_seq_length_batch[batch_index])
126 | num_reps = int(num_repeats_batch[batch_index])
127 |
128 | # The observation pattern is a sequence of random binary vectors.
129 | obs_pattern_shape = [sub_seq_len, num_bits]
130 | # obs_pattern = torch.LongTensor(obs_pattern_shape).uniform_(0, 2).float()
131 | obs_pattern = torch.from_numpy(np.random.randint(0, 2, obs_pattern_shape)).float()
132 | print(obs_pattern.size())
133 | # The target pattern is the observation pattern repeated n times.
134 | # Some reshaping is required to accomplish the tiling.
135 | targ_pattern_shape = [sub_seq_len * num_reps, num_bits]
136 | flat_obs_pattern = obs_pattern.view(-1)
137 | # tile the flattened pattern num_reps times (expand cannot grow a non-singleton dim)
138 | flat_targ_pattern = flat_obs_pattern.repeat(num_reps)
139 | targ_pattern = torch.reshape(flat_targ_pattern, targ_pattern_shape)
140 |
141 | # Expand the obs_pattern to have two extra channels for flags.
142 | # Concatenate start flag and num_reps flag to the sequence.
143 | obs_flag_channel_pad = torch.zeros([sub_seq_len, 2])
144 | obs_start_flag = torch.eye(full_obs_size)[start_end_flag_idx].unsqueeze(0)
145 | num_reps_flag = torch.eye(full_obs_size)[num_repeats_channel_idx].unsqueeze(0) * self._normilize(float(num_reps))
146 |
147 | # note the concatenation dimensions.
148 | obs = torch.cat([obs_pattern, obs_flag_channel_pad], 1)
149 | obs = torch.cat([obs_start_flag, obs], 0)
150 | obs = torch.cat([obs, num_reps_flag], 0)
151 |
152 | # Now do the same for the targ_pattern (it only has one extra channel).
153 | targ_flag_channel_pad = torch.zeros([sub_seq_len * num_reps, 1])
154 | targ_end_flag = torch.eye(full_targ_size)[start_end_flag_idx].unsqueeze(0)
155 | targ = torch.cat([targ_pattern, targ_flag_channel_pad], 1)
156 | targ = torch.cat([targ, targ_end_flag], 0)
157 |
158 | # Concatenate zeros at end of obs and beginning of targ.
159 | # This aligns them s.t. the target begins as soon as the obs ends.
160 | obs_end_pad = torch.zeros([sub_seq_len * num_reps + 1, full_obs_size])
161 | targ_start_pad = torch.zeros([sub_seq_len + 2, full_targ_size])
162 |
163 | # The mask is zero during the obs and one during the targ.
164 | mask_off = torch.zeros([sub_seq_len + 2])
165 | mask_on = torch.ones([sub_seq_len * num_reps + 1])
166 |
167 | obs = torch.cat([obs, obs_end_pad], 0)
168 | targ = torch.cat([targ_start_pad, targ], 0)
169 | mask = torch.cat([mask_off, mask_on], 0)
170 |
171 | obs_tensors.append(obs)
172 | targ_tensors.append(targ)
173 | mask_tensors.append(mask)
174 |
175 | # End the loop over batch index.
176 | # Compute how much zero padding is needed to make tensors sequences
177 | # the same length for all batch elements.
178 | residual_obs_pad = [
179 | torch.zeros(int(residual_length_batch[i]), full_obs_size) for i in range(batch_size)]
180 | residual_targ_pad = [
181 | torch.zeros(int(residual_length_batch[i]), full_targ_size) for i in range(batch_size)]
182 |
183 | residual_mask_pad = [torch.zeros(int(residual_length_batch[i])) for i in range(batch_size)]
184 |
185 | # Concatenate the pad to each batch element.
186 | obs_tensors = [
187 | torch.cat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad)
188 | ]
189 | targ_tensors = [
190 | torch.cat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad)
191 | ]
192 | mask_tensors = [
193 | torch.cat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad)
194 | ]
195 |
196 | # Concatenate each batch element into a single tensor.
197 | obs = torch.cat(obs_tensors, 1).view(obs_batch_shape)
198 | targ = torch.cat(targ_tensors, 1).view(targ_batch_shape)
199 | mask = torch.cat(mask_tensors, 0).reshape(mask_batch_trans_shape).t()
200 | return obs, targ, mask
201 |
202 |
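`BitData.get_item` above lays out each repeat-copy example as follows. The observation is a start flag, the pattern, a repeat-count flag, and then `sub_seq_len * num_reps + 1` silent steps. The target is `sub_seq_len + 2` silent steps, then the pattern repeated `num_reps` times, then an end flag. Both therefore have length `sub_seq_len * (num_reps + 1) + 3`. The sketch below builds one unpadded example in plain Python, following that layout. `repeat_copy_sample` is a hypothetical name. The repeat count is scaled by `norm_max` as `_normilize` does.

```python
def repeat_copy_sample(pattern, num_reps, norm_max=10):
    """Hypothetical pure-Python version of one BitData.get_item element (no batch padding).

    pattern: list of sub_seq_len rows, each a list of num_bits 0/1 values.
    Returns (obs, targ, mask) as lists of rows / values.
    """
    sub_seq_len, num_bits = len(pattern), len(pattern[0])
    obs_size, targ_size = num_bits + 2, num_bits + 1
    start_end_idx, reps_idx = num_bits, num_bits + 1

    def row(size, idx=None, value=1.0):
        r = [0.0] * size
        if idx is not None:
            r[idx] = value
        return r

    # Observation: start flag, the pattern (flag channels zeroed), the
    # normalised repeat count, then silence while the answer is produced.
    obs = [row(obs_size, start_end_idx)]
    obs += [list(map(float, bits)) + [0.0, 0.0] for bits in pattern]
    obs.append(row(obs_size, reps_idx, num_reps / norm_max))
    obs += [row(obs_size) for _ in range(sub_seq_len * num_reps + 1)]

    # Target: silence while the input is shown, the repeated pattern,
    # then an end-of-sequence flag.
    targ = [row(targ_size) for _ in range(sub_seq_len + 2)]
    for _ in range(num_reps):
        targ += [list(map(float, bits)) + [0.0] for bits in pattern]
    targ.append(row(targ_size, start_end_idx))

    # The loss is only applied while the target is being produced.
    mask = [0.0] * (sub_seq_len + 2) + [1.0] * (sub_seq_len * num_reps + 1)
    return obs, targ, mask


obs, targ, mask = repeat_copy_sample([[1, 0, 1], [0, 1, 1]], num_reps=2)
print(len(obs), len(targ), len(mask))  # 9 9 9
```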
203 | class AirCargoData():
204 | """
205 | Flags for
206 | """
207 | def __init__(self, num_plane=10, num_cargo=6, batch_size=6,
208 | num_airport=1000, plan_phase=1, cuda=False,
209 | one_hot_size=10, encoding=2, mapping=None,
210 | search_function=astar_search, solve=True):
211 | self.n_plane, self.n_cargo, self.n_airport = num_plane, num_cargo, num_airport
212 | self.plan_len = plan_phase
213 | self.batch_size = batch_size
214 | self.mapping = mapping
215 | self.solve = solve
216 | self.encoding = encoding
217 | self.one_hot_size = [one_hot_size] if type(one_hot_size) == int else one_hot_size
218 | self.search_fn = search_function
219 |
220 | self.search_param = 'h_ignore_preconditions'
221 | self.ents_to_ix, self.ix_to_ents = None, None
222 |
223 | self.STATE, self.INIT_STATE = '', ''
224 | self.current_index, self.cuda = 0, cuda
225 | self.current_problem, self.goals, self.state = None, None, None
226 |
227 | self.phase_oh = torch.eye(len(PHASES))
228 | self.blnk_vec = torch.zeros(len(EXPRESSIONS) * 2).float() #
229 | self.expr_o_h = torch.cat([torch.zeros([1, len(EXPRESSIONS)]), torch.eye(len(EXPRESSIONS))], 0).float()
230 | self.ents_o_h = None
231 | self.masks = []
232 |
233 | # new indices
234 | self.goals_idx, self.cargo_in_idx = {}, {}
235 | self.cache, self.encodings = {}, {}
236 | self.make_new_problem()
237 | print(self.plan_len)
238 |
239 | @property
240 | def nn_in_size(self):
241 | return self.blnk_vec.size(-1) + self.phase_oh.size(-1)
242 |
243 | @property
244 | def nn_out_size(self):
245 | return self.blnk_vec.size(-1)
246 |
247 | def lookup_expr_to_ix(self, _expr):
248 | return self.ents_to_ix[str(_expr)]
249 |
250 | def lookup_ix_to_expr(self, ix) -> str:
251 | if ix in self.ix_to_ents:
252 | return self.ix_to_ents[ix]
253 | else:
254 | return "NA"
255 |
256 | def masked_input(self):
257 | state_expr = random.choice(self.pull_state())
258 | state_vec = self.expr_to_vec(state_expr)
259 | mask_idx = 2
260 | mask_chunk = state_vec[mask_idx]
261 |
262 | zeros = 0 if type(mask_chunk) == int else tuple([0] * len(mask_chunk))
263 | masked_state_vec = state_vec.copy()
264 | masked_state_vec[2] = zeros
265 | inputs = torch.cat([self.phase_oh[3].unsqueeze(0), self.vec_to_ix(masked_state_vec)], 1)
266 | return inputs, mask_chunk, state_vec
267 |
268 | def generate_encodings(self):
269 | """
270 |
271 | :param ix_to_ents:
272 | :return:
273 | """
274 | noops = [torch.zeros(1, i) for i in self.one_hot_size]
275 | self.ents_o_h = []
276 | for noop, ix_size in zip(noops, self.one_hot_size):
277 | ix_enc = torch.cat([noop, torch.eye(ix_size)], 0).float()
278 | self.ents_o_h.append(ix_enc)
279 | self.blnk_vec = torch.cat([torch.zeros([1, len(EXPRESSIONS)]), torch.cat(noops, 1)], 1)
280 |
281 | def print_solution(self, node):
282 | for action in node.solution():
283 | print("{}{}".format(action.name, action.args))
284 |
285 | def best_logic(self, action_exprs):
286 | """
287 | A* search is too slow for this problem, so the
288 | best option is to hand-code the choice of operations.
289 | :param action_exprs:
290 | :return:
291 | """
292 | if self.goals_idx == {}: # the problem has no remaining goals
293 | return []
294 | best_actions, at_goal = [], []
295 | for action in action_exprs:
296 | op = action.name
297 | if op == 'Unload':
298 | cargo = action.args[0]
299 | if cargo in self.goals_idx and self.goals_idx[cargo] == action.args[2]:
300 | at_goal.append(action)
301 | continue
302 | elif op == 'Load':
303 | cargo = action.args[0]
304 | # load cargo that still has an outstanding goal
305 | if cargo in self.goals_idx:
306 | at_goal.append(action)
307 | continue
308 | elif op == 'Fly':
309 | plane, airpt, dest_n = action.args
310 | cargos_in_plane = [cargo for cargo, _in in self.cargo_in_idx.items() if _in == plane]
311 | # if there is a cargo in the plane
312 | # and the destination is the cargo's goal
313 | if cargos_in_plane != []:
314 | cargo = cargos_in_plane[0] # assumes one cargo
315 | if cargo in self.goals_idx and self.goals_idx[cargo] == dest_n:
316 | at_goal.append(action)
317 | continue
318 | # if no cargo is at this airport, fly to a destination where cargo is waiting
319 | cargos_in_airpt = [cargo for cargo, _in in self.cargo_in_idx.items() if _in == airpt]
320 | if cargos_in_airpt == []:
321 | cargos_in_dest = [cargo for cargo, _in in self.cargo_in_idx.items() if _in == dest_n]
322 | if cargos_in_dest != []:
323 | best_actions.append(action)
324 | if not at_goal:
325 | return best_actions
326 | else:
327 | return at_goal
328 |
329 | def get_raw_actions(self, mode='best'):
330 | self.current_problem.initial = self.STATE
331 | if mode == 'all':
332 | actions_exprs = self.current_problem.actions(self.STATE)
333 | elif mode == 'one':
334 | actions_exprs = [self.current_problem.one_action(self.STATE)]
335 | elif mode == 'both':
336 | all_actions = self.current_problem.actions(self.STATE)
337 | best_actions = self.best_logic(all_actions)
338 | return best_actions, all_actions
339 | else:
340 | if self.search_param is not None:
341 | prm = getattr(self.current_problem, self.search_param)
342 | solution = self.search_fn(self.current_problem, prm)
343 | actions_exprs = solution.solution()[0:len(self.current_problem.planes)]
344 | else:
345 | solution = self.search_fn(self.current_problem)
346 | actions_exprs = solution.solution()[0:len(self.current_problem.planes)]
347 | return actions_exprs
348 |
349 | def get_actions(self, mode='best'):
350 | if mode == 'both':
351 | best, all_actions = self.get_raw_actions(mode=mode)
352 | return [self.expr_to_vec(a) for a in best], [self.expr_to_vec(a) for a in all_actions]
353 | else:
354 | return [self.expr_to_vec(a) for a in self.get_raw_actions(mode=mode)]
355 |
356 | def encode_action(self, action_obj):
357 | return torch.from_numpy(self.current_problem.encode_action(action_obj)).long()
358 |
359 | def send_action(self, action_vec):
360 | """
361 | Convert an action vector into the matching Action, apply it to the current state, and update the cargo/goal indices.
362 | """
363 | self.current_problem.initial = self.STATE
364 | sym, args = self.vec_to_expr(action_vec)
365 | actions_ = self.current_problem.actions(self.STATE)
366 | final_act = []
367 | for a in actions_:
368 | if a.name == sym and all((str(ar) == at) for ar, at in zip(a.args, args)):
369 | final_act = a
370 | break
371 | assert final_act != []
372 | if final_act.name == 'Load':
373 | cargo, plane, _ = final_act.args
374 | self.cargo_in_idx[cargo] = plane
375 | elif final_act.name == 'Unload':
376 | cargo, _, airpt = final_act.args
377 | if cargo in self.goals_idx and self.goals_idx[cargo] == airpt:
378 | # if this is the goal state, remove from running index
379 | del self.goals_idx[cargo]
380 | del self.cargo_in_idx[cargo]
381 | else: # else add to new airport
382 | self.cargo_in_idx[cargo] = airpt
383 | self.STATE = self.current_problem.result(self.STATE, final_act)
384 | return self.STATE, final_act
385 |
386 | def vec_to_expr(self, _vec):
387 | action_str = EXPRESSIONS[_vec[0]]
388 | args_vec = []
389 | for idx, value in enumerate(_vec[1:]):
390 | args_vec.append(self.lookup_ix_to_expr(value))
391 | return action_str, args_vec
392 |
393 | def pull_state(self):
394 | res = []
395 | for idx, char in enumerate(self.STATE):
396 | if char == 'T':
397 | res.append(self.current_problem.state_map[idx])
398 | return res
399 |
400 | def vec_to_ix(self, _vec):
401 | """
402 | Input target vec representing cross entropy loss target [1 0 2 0 0 0 0]
403 | Returns a one hot version of it as training input [01 00, 100, 000, 000, 000]
404 | :param _vec:
405 | :return:
406 | """
407 | merged = flat(_vec)
408 | action = self.expr_o_h[merged[0]].unsqueeze(0)
409 | ix_ent = []
410 | merged = merged[1:]
411 | if self.mapping is not None:
412 | expr_str = EXPRESSIONS[_vec[0]]
413 | # print(expr_str)
414 | mp = self.mapping[expr_str]
415 | permute = mp['permute'] if 'permute' in mp else None
416 | insert_ = mp['insert'] if 'insert' in mp else None
417 | test = mp['tst'] if 'tst' in mp else None
418 | idx = mp['idx'] if 'idx' in mp else None
419 | if insert_ is not None and test is not None and idx is not None:
420 | text_ix = _vec[1:][test[0]]
421 | for i, insrt_ent in zip(idx, insert_):
422 | if test[1] in self.lookup_ix_to_expr(text_ix):
423 | merged.insert(i, insrt_ent)
424 | else:
425 | merged.append(insrt_ent)
426 | elif insert_ is not None and idx is None:
427 | for insrt_ent in insert_:
428 | merged.append(insrt_ent)
429 |
430 | if permute is not None:
431 | nperm = np.argsort(permute)
432 | merged = np.asarray(merged)[nperm]
433 | else:
434 | # print(len(self.ents_o_h))
435 | for idx in range(len(self.ents_o_h) - len(merged)):
436 | merged.append(0)
437 |
438 | for idx, value in enumerate(merged):
439 | ix_ent.append(self.ents_o_h[idx][value].unsqueeze(0))
440 | return torch.cat([action, torch.cat(ix_ent, 1)], 1)
441 |
442 | def ix_to_vec(self, ix_ent):
443 | action_l = self.expr_o_h.size(-1)
444 | ent_vec = [ix_ent[0:action_l].index(1) + 1]
445 | start = action_l
446 | for idx, ix_size in enumerate(self.one_hot_size):
447 | expr_ix = ix_ent[start:start+ix_size]
448 | if 1 in expr_ix:
449 | ent_vec.append(expr_ix.index(1) + 1)
450 | else:
451 | ent_vec.append(0)
452 | start += ix_size
453 | return ent_vec
454 |
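`expr_o_h` and `ents_o_h` prepend an all-zero row to an identity matrix. As a result, index 0 encodes "blank" and index `k >= 1` sets position `k - 1`. `ix_to_vec` inverts this with `.index(1) + 1`. Note that `exprs_to_ix` is 0-based as written, so `Fly` (index 0) appears to select the blank row of `expr_o_h`. The following is a pure-Python sketch of the convention. `encode_ix`, `decode_ix`, `encode_vec` and `decode_vec` are hypothetical names.

```python
def encode_ix(index, size):
    """One-hot with a reserved blank: 0 -> all zeros, k -> 1 at position k - 1.

    Matches indexing torch.cat([zeros(1, size), eye(size)])[index].
    """
    if not 0 <= index <= size:
        raise ValueError("index out of range")
    vec = [0] * size
    if index > 0:
        vec[index - 1] = 1
    return vec


def decode_ix(vec):
    """Inverse of encode_ix: all zeros -> 0, otherwise position of the 1 plus one."""
    return vec.index(1) + 1 if 1 in vec else 0


def encode_vec(ixs, sizes):
    """Concatenate one encoded block per (index, size) pair, like vec_to_ix."""
    out = []
    for ix, size in zip(ixs, sizes):
        out.extend(encode_ix(ix, size))
    return out


def decode_vec(flat, sizes):
    """Split a concatenated encoding back into indices, like ix_to_vec."""
    ixs, start = [], 0
    for size in sizes:
        ixs.append(decode_ix(flat[start:start + size]))
        start += size
    return ixs


print(decode_vec(encode_vec([1, 0, 3], [5, 4, 4]), [5, 4, 4]))  # [1, 0, 3]
```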
455 | def ix_to_ixs(self, ix_ent, grouping=None):
456 | action_l = self.expr_o_h.size(-1)
457 | ixs_vec = [ix_ent[:, 0:action_l]]
458 | start = action_l
459 | if grouping is None:
460 | grouping = self.one_hot_size
461 | for idx, ix_size in enumerate(grouping):
462 | ixs_vec.append(ix_ent[:, start:start + ix_size])
463 | start += ix_size
464 | return ixs_vec
465 |
466 | def strip_ix_mask(self, ix_input_vec):
467 | phase_size = self.phase_oh.size(-1)
468 | ixs_vec = ix_input_vec[:, phase_size:]
469 | phase_vec = ix_input_vec[:, :phase_size]
470 | return phase_vec, ixs_vec
471 |
472 | def ix_input_to_ixs(self, ix_input_vec, grouping=None):
473 | """
474 |
475 | :param ix_input_vec:
476 | :param grouping:
477 | :return:
478 | """
479 | phase_size = self.phase_oh.size(-1)
480 | ixs_vec = ix_input_vec[:, phase_size:]
481 | return self.ix_to_ixs(ixs_vec, grouping)
482 |
483 | def ix_to_expr(self, ix_input_vec):
484 | vec = self.ix_to_vec(ix_input_vec)
485 | return self.vec_to_expr(vec)
486 |
487 | def expr_to_vec(self, expr_obj):
488 | """
489 | :param expr_obj: Action or Expr object, e.g. Fly(P0, A1, A2)
490 | :return: action vec, e.g. [0, 0, 1, 2] (action index followed by argument indices)
491 | """
492 | if type(expr_obj) == Action:
493 | exp_name = expr_obj.name
494 | else:
495 | exp_name = expr_obj.op
496 | ent_vec = [exprs_to_ix[exp_name]]
497 | for arg in expr_obj.args:
498 | ent_vec.append(self.lookup_expr_to_ix(arg))
499 |
500 | return ent_vec
501 |
502 | def gen_state_vec(self, index):
503 | state_expr = self.state[index]
504 | return self.expr_to_vec(state_expr)
505 |
506 | def gen_input_ix(self, _exprs, index):
507 | """
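508 | Generate a one_hot vector for the expression at a given index
509 | :param _exprs: list of expressions (e.g. self.state or self.goals)
510 | :param index: position of the expression in _exprs
511 | :return: concatenated one_hot tensor produced by vec_to_ix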
512 | """
513 | _expr = _exprs[index]
514 | ent_vec = self.expr_to_vec(_expr)
515 | return self.vec_to_ix(ent_vec)
516 |
517 | def human_readable(self, inputs, mask=None) -> str:
518 | if mask is None:
519 | phase = PHASES[self.phase_oh[self.masks[self.current_index]]]
520 | elif mask is False:
521 | phase = ''
522 | elif mask is True:
523 | phase_size = self.phase_oh.size(-1)
524 | phase_ix = inputs[:, :phase_size].squeeze()
525 |
526 | inputs = inputs[:, phase_size:]
527 | phase = ''
528 |
529 | elif isinstance(mask, int):
530 | phase = PHASES[mask]
531 | else:
532 | phase = PHASES[mask.squeeze()[0]]
533 | args = []
534 | txt = ''
535 | if phase:
536 | args.append(phase)
537 | txt += 'Phase {}, '
538 |
539 | if isinstance(inputs, torch.Tensor):
540 | expr = self.ix_to_expr(inputs)
541 | else:
542 | expr = self.vec_to_expr(inputs)
543 | args.append(expr)
544 | txt += 'expr {}'
545 | return txt.format(*args)
546 |
547 | def make_new_problem(self):
548 | """
549 | Set up new problem object
550 | :return:
551 | """
552 | problem, (e_ix, ix_e), (s, g) = \
553 | mac.arbitrary_ACP2(self.n_airport, self.n_cargo, self.n_plane, encoding=self.encoding)
554 | self.current_problem = problem
555 |
556 | self.STATE = problem.initial
557 | self.INIT_STATE = problem.initial
558 | self.generate_encodings()
559 | self.state, self.goals = s, g
560 | self.goals_idx = {goal.args[0]: goal.args[1] for goal in self.goals}
561 | self.cargo_in_idx = {st.args[0]: st.args[1] for st in s if st.args[0].op.startswith('C')}
562 |
563 | # if self.solve is True:
564 | # todo something about this maybe deepcopy self and fastsolve?
565 | # prm = getattr(problem, self.search_param)
566 | # solution_node = len(self.search_fn(problem, prm).solution())
567 | # else:
568 | # solution_node = self.solve
569 | # len = num_ents * 2 + num_ents + num_ents * 2 + num_ents * 4
570 | # 9 * 6
571 |
572 | state = torch.zeros(len(self.state))
573 | goal = torch.ones(len(self.goals))
574 | plan = torch.ones(self.plan_len) * 2
575 | resp = torch.ones(self.n_cargo * 4) * 3
576 |
577 | self.masks = torch.cat([state, goal, plan, resp], 0).long()
578 | self.ents_to_ix, self.ix_to_ents = e_ix, ix_e
579 | self.current_index = 0
580 | return self.masks
581 |
582 | def len__(self):
583 | if self.current_index >= self.current_problem[0].size(0):
584 | self.make_new_problem()
585 | return len(self.current_problem[1])
586 |
587 | def getitem(self, batch=1):
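588 | """Return (inputs, mask) for the current position in the problem.
589 | inputs is the encoded state or goal expression, or a blank vector during
590 | the plan and response phases; mask is the one-hot phase of that position.
591 | Advances current_index by `batch` and starts a new problem once the
592 | current one is exhausted.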
593 |
594 | """
595 | if self.current_index >= len(self.masks):
596 | self.make_new_problem()
597 |
598 | phase = self.masks[self.current_index]
599 | if phase == 0:
600 | inputs = self.gen_input_ix(self.state, self.current_index)
601 | elif phase == 1:
602 | inputs = self.gen_input_ix(self.goals, self.current_index - len(self.state))
603 | elif phase == 2:
604 | inputs = self.blnk_vec
605 | else:
606 | inputs = self.blnk_vec
607 |
608 | self.current_index += batch
609 | mask = self.phase_oh[phase].unsqueeze(0)
610 | return inputs, mask
611 |
612 | def getmask(self, batch=1):
613 | if self.current_index >= len(self.masks):
614 | self.make_new_problem()
615 | phase = self.masks[self.current_index]
616 | self.current_index += batch
617 | mask = self.phase_oh[phase].unsqueeze(0)
618 | return mask.cuda() if self.cuda is True else mask
619 |
620 | def getitem_combined(self, batch=1):
621 | inputs, mask = self.getitem(batch)
622 | combined = torch.cat([mask, inputs], 1)
623 | return combined.cuda() if self.cuda is True else combined
624 |
625 |
626 |
627 |
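The `vec_to_ix` / `ix_to_vec` pair above packs an expression vector (an action id followed by argument ids, with 0 meaning an empty slot) into concatenated one-hot chunks and reads it back. Below is a minimal, framework-free sketch of that round trip using plain Python lists instead of torch tensors. `encode_vec`, `decode_vec`, `action_size` and `slot_sizes` are hypothetical names, and the convention that id `k` sets bit `k - 1` while 0 yields an all-zero chunk is inferred from `ix_to_vec`, not taken from the class itself.

```python
def encode_vec(ent_vec, action_size, slot_sizes):
    """Pack [action_id, arg_id, ...] into one flat one-hot list.

    Ids are 1-based; an argument id of 0 (or a missing trailing argument)
    becomes an all-zero chunk, like the zero padding in vec_to_ix.
    """
    args = list(ent_vec[1:])
    args += [0] * (len(slot_sizes) - len(args))  # pad short expressions
    out = [0] * action_size
    out[ent_vec[0] - 1] = 1
    for value, size in zip(args, slot_sizes):
        chunk = [0] * size
        if value > 0:
            chunk[value - 1] = 1
        out.extend(chunk)
    return out


def decode_vec(ix_ent, action_size, slot_sizes):
    """Inverse of encode_vec, following the same steps as ix_to_vec."""
    ent_vec = [ix_ent[:action_size].index(1) + 1]
    start = action_size
    for size in slot_sizes:
        chunk = ix_ent[start:start + size]
        ent_vec.append(chunk.index(1) + 1 if 1 in chunk else 0)
        start += size
    return ent_vec


flat = encode_vec([2, 1, 3], action_size=3, slot_sizes=[4, 4, 4])
print(decode_vec(flat, 3, [4, 4, 4]))  # [2, 1, 3, 0]
```

Decoding returns an explicit 0 for the padded third slot, so a two-argument expression comes back as `[2, 1, 3, 0]` rather than `[2, 1, 3]`.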
--------------------------------------------------------------------------------
/problem/lp_utils.py:
--------------------------------------------------------------------------------
1 | from .logic import associate
2 | from .utils import expr
3 |
4 |
5 | class FluentState():
6 | """ state object for planning problems as positive and negative fluents
7 |
8 | """
9 |
10 | def __init__(self, pos_list, neg_list):
11 | self.pos = pos_list
12 | self.neg = neg_list
13 |
14 | def sentence(self):
15 | return expr(conjunctive_sentence(self.pos, self.neg))
16 |
17 | def pos_sentence(self):
18 | return expr(conjunctive_sentence(self.pos, []))
19 |
20 |
21 | def conjunctive_sentence(pos_list, neg_list):
22 | """ returns expr conjunctive sentence given positive and negative fluent lists
23 |
24 | :param pos_list: list of fluents
25 | :param neg_list: list of fluents
26 | :return: expr sentence of fluent conjunction
27 | e.g. "At(C1, SFO) ∧ ~At(P1, SFO)"
28 | """
29 | clauses = []
30 | for f in pos_list:
31 | clauses.append(expr("{}".format(f)))
32 | for f in neg_list:
33 | clauses.append(expr("~{}".format(f)))
34 | return associate('&', clauses)
35 |
36 |
37 | def encode_state(fs: FluentState, fluent_map: list) -> str:
38 | """ encode fluents to a string of T/F using mapping
39 |
40 | :param fs: FluentState object
41 | :param fluent_map: ordered list of possible fluents for the problem
42 | :return: str eg. "TFFTFT" string of mapped positive and negative fluents
43 | """
44 | state_tf = []
45 | for fluent in fluent_map:
46 | if fluent in fs.pos:
47 | state_tf.append('T')
48 | else:
49 | state_tf.append('F')
50 | return "".join(state_tf)
51 |
52 |
53 | def decode_state(state: str, fluent_map: list) -> FluentState:
54 | """ decode string of T/F as fluent per mapping
55 |
56 | :param state: str eg. "TFFTFT" string of mapped positive and negative fluents
57 | :param fluent_map: ordered list of possible fluents for the problem
58 | :return: fs: FluentState object
59 |
60 | lengths of state string and fluent_map list must be the same
61 | """
62 | fs = FluentState([], [])
63 | for idx, char in enumerate(state):
64 | if char == 'T':
65 | fs.pos.append(fluent_map[idx])
66 | else:
67 | fs.neg.append(fluent_map[idx])
68 | return fs
69 |
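`encode_state` and `decode_state` turn a `FluentState` into a T/F string over a fixed fluent ordering and back. The stand-alone sketch below shows the same idea, assuming plain strings in place of `expr` objects and a hypothetical `Fluents` dataclass in place of `FluentState`, so it runs without the package's `logic` and `utils` modules.

```python
from dataclasses import dataclass, field


@dataclass
class Fluents:
    """Hypothetical stand-in for FluentState, holding plain strings."""
    pos: list = field(default_factory=list)
    neg: list = field(default_factory=list)


def encode(fs, fluent_map):
    # One character per fluent in fluent_map: 'T' if it holds, else 'F'
    return "".join("T" if f in fs.pos else "F" for f in fluent_map)


def decode(state, fluent_map):
    # Every fluent lands in exactly one of pos/neg
    fs = Fluents()
    for char, fluent in zip(state, fluent_map):
        (fs.pos if char == "T" else fs.neg).append(fluent)
    return fs


init = Fluents(pos=["At(C1, SFO)", "At(P1, SFO)"], neg=["In(C1, P1)"])
fluent_map = init.pos + init.neg        # same ordering AirCargoProblem uses
print(encode(init, fluent_map))         # TTF

after_load = decode("FTT", fluent_map)  # state after Load(C1, P1, SFO)
print(after_load.pos)                   # ['At(P1, SFO)', 'In(C1, P1)']
```

Because `AirCargoProblem` builds `state_map` as `initial.pos + initial.neg`, the initial state always encodes as a block of `T`s followed by a block of `F`s; states reached by later actions are arbitrary mixes.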
--------------------------------------------------------------------------------
/problem/my_air_cargo_problems.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from .logic import PropKB
4 | from .planning import Action
5 | from .search import (
6 | Node, Problem
7 | )
8 | from .utils import expr
9 | from .lp_utils import (
10 | FluentState, encode_state, decode_state,
11 | )
12 | from .my_planning_graph import PlanningGraph
13 | from functools import lru_cache
14 | from random import shuffle
15 |
16 |
17 | class AirCargoProblem(Problem):
18 | def __init__(self, cargos, planes, airports, initial: FluentState, goal: list):
19 | """
20 |
21 | :param cargos: list of str
22 | cargos in the problem
23 | :param planes: list of str
24 | planes in the problem
25 | :param airports: list of str
26 | airports in the problem
27 | :param initial: FluentState object
28 | positive and negative literal fluents (as expr) describing initial state
29 | :param goal: list of expr
30 | literal fluents required for goal test
31 | """
32 | self.state_map = initial.pos + initial.neg
33 | self.initial_state_TF = encode_state(initial, self.state_map)
34 | Problem.__init__(self, self.initial_state_TF, goal=goal)
35 | self.cargos = cargos
36 | self.planes = planes
37 | self.airports = airports
38 | self.actions_list = self.get_actions()
39 |
40 | def get_actions(self):
41 | """
42 | This method creates concrete actions (no variables) for all actions in the problem
43 | domain action schema and turns them into complete Action objects as defined in the
44 | aimacode.planning module. It is computationally expensive to call this method directly;
45 | however, it is called in the constructor and the results cached in the `actions_list` property.
46 |
47 | Returns:
48 | ----------
49 | list
50 | list of Action objects
51 | """
52 | def load_actions():
53 | """Create all concrete Load actions and return a list
54 |
55 | :return: list of Action objects
56 | """
57 | loads = []
58 | for p in self.planes:
59 | for a in self.airports:
60 | for c in self.cargos:
61 | precond_pos = [expr("At({}, {})".format(c, a)),
62 | expr("At({}, {})".format(p, a))]
63 | # remove this for tests where a plane can carry more than one cargo
64 | precond_neg = [expr("In({}, {})".format(c1, p)) for c1 in self.cargos]
65 | effect_add = [expr("In({}, {})".format(c, p))]
66 | effect_rem = [expr("At({}, {})".format(c, a))]
67 | act = Action(expr("Load({}, {}, {})".format(c, p, a)),
68 | [precond_pos, precond_neg],
69 | [effect_add, effect_rem])
70 | loads.append(act)
71 |
72 | return loads
73 |
74 | def unload_actions():
75 | """Create all concrete Unload actions and return a list
76 |
77 | :return: list of Action objects
78 | """
79 | unloads = []
80 | # create all Unload ground actions from the domain Unload action
81 | for p in self.planes:
82 | for a in self.airports:
83 | for c in self.cargos:
84 | precond_pos = [expr("In({}, {})".format(c, p)),
85 | expr("At({}, {})".format(p, a))]
86 |
87 | effect_add = [expr("At({}, {})".format(c, a))]
88 | effect_rem = [expr("In({}, {})".format(c, p))]
89 | act = Action(expr("Unload({}, {}, {})".format(c, p, a)),
90 | [precond_pos, []],
91 | [effect_add, effect_rem])
92 | unloads.append(act)
93 |
94 | return unloads
95 |
96 | def fly_actions():
97 | """Create all concrete Fly actions and return a list
98 |
99 | :return: list of Action objects
100 | """
101 | flys = []
102 |
103 | for fr in self.airports:
104 | for to in self.airports:
105 | if fr != to:
106 | for p in self.planes:
107 | precond_pos = [expr("At({}, {})".format(p, fr))]
108 | # precond_neg = []
109 | effect_add = [expr("At({}, {})".format(p, to))]
110 | effect_rem = [expr("At({}, {})".format(p, fr))]
111 | fly = Action(expr("Fly({}, {}, {})".format(p, fr, to)),
112 | [precond_pos, []],
113 | [effect_add, effect_rem])
114 | flys.append(fly)
115 | return flys
116 | data = load_actions() + unload_actions() + fly_actions()
117 | shuffle(data)
118 | return data
119 |
120 | def one_action(self, state: str) -> list:
121 | """ Return the first action that can be executed in the given state.
122 |
123 | :param state: str
124 | state represented as T/F string of mapped fluents (state variables)
125 | e.g. 'FTTTFF'
126 | :return: the first applicable Action object, or [] if none applies
127 | """
128 | kb = PropKB()
129 | kb.tell(decode_state(state, self.state_map).pos_sentence())
130 |
131 | for action in self.actions_list:
132 | is_possible = True
133 | for clause in action.precond_pos:
134 | if clause not in kb.clauses:
135 | is_possible = False
136 | break
137 | for clause in action.precond_neg:
138 | if clause in kb.clauses:
139 | is_possible = False
140 | break
141 | if is_possible:
142 | return action
143 | return []
144 |
145 | def actions(self, state: str) -> list:
146 | possible_actions = []
147 | kb = PropKB()
148 | kb.tell(decode_state(state, self.state_map).pos_sentence())
149 | for action in self.actions_list:
150 | is_possible = True
151 | for clause in action.precond_pos:
152 | if clause not in kb.clauses:
153 | is_possible = False
154 | break
155 | for clause in action.precond_neg:
156 | if clause in kb.clauses:
157 | is_possible = False
158 | break
159 | if is_possible:
160 | possible_actions.append(action)
161 | return possible_actions
162 |
163 | def result(self, state: str, action: Action):
164 | """ Return the state that results from executing the given
165 | action in the given state. The action must be one of
166 | self.actions(state).
167 |
168 | :param state: state entering node
169 | :param action: Action applied
170 | :return: resulting state after action
171 | """
172 | new_state = FluentState([], [])
173 | old_state = decode_state(state, self.state_map)
174 |
175 | for fluent in old_state.pos:
176 | if fluent not in action.effect_rem:
177 | new_state.pos.append(fluent)
178 |
179 | for fluent in action.effect_add:
180 | if fluent not in new_state.pos:
181 | new_state.pos.append(fluent)
182 |
183 | for fluent in old_state.neg:
184 | if fluent not in action.effect_add:
185 | new_state.neg.append(fluent)
186 |
187 | for fluent in action.effect_rem:
188 | if fluent not in new_state.neg:
189 | new_state.neg.append(fluent)
190 | return encode_state(new_state, self.state_map)
191 |
192 | def goal_test(self, state: str) -> bool:
193 | """ Test the state to see if goal is reached
194 |
195 | :param state: str representing state
196 | :return: bool
197 | """
198 | kb = PropKB()
199 | kb.tell(decode_state(state, self.state_map).pos_sentence())
200 | for clause in self.goal:
201 | if clause not in kb.clauses:
202 | return False
203 | return True
204 |
205 | def h_1(self, node: Node):
206 | # note that this is not a true heuristic
207 | h_const = 1
208 | return h_const
209 |
210 | @lru_cache(maxsize=8192)
211 | def h_pg_levelsum(self, node: Node):
212 | """This heuristic uses a planning graph representation of the problem
213 | state space to estimate the sum of all actions that must be carried
214 | out from the current state in order to satisfy each individual goal
215 | condition.
216 | """
217 | pg = PlanningGraph(self, node.state)
218 | pg_levelsum = pg.h_levelsum()
219 | return pg_levelsum
220 |
221 | @lru_cache(maxsize=8192)
222 | def h_ignore_preconditions(self, node: Node):
223 | """This heuristic estimates the minimum number of actions that must be
224 | carried out from the current state in order to satisfy all of the goal
225 | conditions by ignoring the preconditions required for an action to be
226 | executed.
227 | """
228 | kb = PropKB()
229 | kb.tell(decode_state(node.state, self.state_map).pos_sentence())
230 | return len([c for c in self.goal if c not in kb.clauses])
231 |
232 |
233 | def posneg_helper(cargos, planes, airports, pos_exprs):
234 | """given the positive expressions and entities, generate the negative expressions"""
235 | pos_set = set(pos_exprs)
236 | neg = []
237 | for c in cargos:
238 | for p in planes:
239 | neg.append(expr("In({}, {})".format(c, p)))
240 | for a in airports:
241 | neg.append(expr("At({}, {})".format(c, a)))
242 | for a in airports:
243 | for p in planes:
244 | neg.append(expr("At({}, {})".format(p, a)))
245 | return list(set(neg).difference(pos_set))
246 |
247 | #def ¬(x):
248 | # print(x) well, Python does not allow arbitrary symbols, I guess.
249 | #pass
250 |
251 | def air_cargo_p1() -> AirCargoProblem:
252 | cargos = ['C1', 'C2']
253 | planes = ['P1', 'P2']
254 | airports = ['JFK', 'SFO']
255 | pos = [expr('At(C1, SFO)'),
256 | expr('At(C2, JFK)'),
257 | expr('At(P1, SFO)'),
258 | expr('At(P2, JFK)')]
259 | neg = posneg_helper(cargos, planes, airports, pos)
260 | init = FluentState(pos, neg)
261 | goal = [expr('At(C1, JFK)'),
262 | expr('At(C2, SFO)')]
263 | return AirCargoProblem(cargos, planes, airports, init, goal)
264 |
265 |
266 | def air_cargo_p2() -> AirCargoProblem:
267 | cargos = ['C1', 'C2', 'C3']
268 | planes = ['P1', 'P2', 'P3']
269 | airports = ['JFK', 'SFO', 'ATL']
270 | pos = [expr('At(C1, SFO)'),
271 | expr('At(C2, JFK)'),
272 | expr('At(C3, ATL)'),
273 | expr('At(P1, SFO)'),
274 | expr('At(P2, JFK)'),
275 | expr('At(P3, ATL)')]
276 | neg = posneg_helper(cargos, planes, airports, pos)
277 | init = FluentState(pos, neg)
278 | goal = [expr('At(C1, JFK)'),
279 | expr('At(C2, SFO)'),
280 | expr('At(C3, SFO)')]
281 | return AirCargoProblem(cargos, planes, airports, init, goal)
282 |
283 |
284 | def air_cargo_p3() -> AirCargoProblem:
285 | cargos = ['C1', 'C2', 'C3', 'C4']
286 | planes = ['P1', 'P2']
287 | airports = ['JFK', 'SFO', 'ATL', 'ORD']
288 | pos = [expr('At(C1, SFO)'),
289 | expr('At(C2, JFK)'),
290 | expr('At(C3, ATL)'),
291 | expr('At(C4, ORD)'),
292 | expr('At(P1, SFO)'),
293 | expr('At(P2, JFK)')]
294 | neg = posneg_helper(cargos, planes, airports, pos)
295 | init = FluentState(pos, neg)
296 | goal = [expr('At(C1, JFK)'),
297 | expr('At(C2, SFO)'),
298 | expr('At(C3, JFK)'),
299 | expr('At(C4, SFO)')]
300 | return AirCargoProblem(cargos, planes, airports, init, goal)
301 |
302 | #####################################################################
303 | ### GENERATORS
304 | #####################################################################
305 | # this is all for generating DNC data
306 |
307 | import numpy as np
308 | import random
309 | from random import randint
310 | random.seed()
311 |
312 |
313 | def bin_gen(num, one_hot_size):
314 | """Return num as a list of one_hot_size bits, in case a binary
315 | encoding is needed instead of one_hot."""
316 | return [int(bit) for bit in bin(num)[2:].zfill(one_hot_size)]
317 |
318 |
319 |
320 | def air_cargo_generator_v1(num_plane, num_airport, num_cargo,
321 | one_hot_size=None, vec_fn=None):
322 | """Init:
323 | each plane must be at airport
324 | each cargo must be at an airport
325 | each cargo has a destination
326 |
327 | Returned tuples of as one hot vecs of equal size
328 | zeros vector represents nothing at that slot.
329 | so 'At(C1, SFO)' with one_hot_size = 4 ->
330 | S, Act, A, P, C
331 | +---------------+
332 | | 1-0 2-0 3-0 |
333 | +---------------+
334 |
335 | 'state' size becomes max A * P * 1h + A * C * 1h (4, 4, 4, 10)
336 | """
337 | if one_hot_size is None:
338 | one_hot_size = max([num_airport, num_cargo, num_plane]) + 3
339 | if vec_fn is None:
340 | vec_fn = np.eye
341 |
342 | assert num_airport <= one_hot_size
343 | encoding_vec = vec_fn(one_hot_size)
344 |
345 | o_h_planes = encoding_vec[:num_plane].astype(int)
346 | o_h_airprt = encoding_vec[:num_airport].astype(int)
347 | o_h_cargos = encoding_vec[:num_cargo].astype(int)
348 | o_h_types = vec_fn(3).astype(int)
349 |
350 | zeros = np.zeros(one_hot_size, dtype=int)
351 | zero_type = np.zeros(3, dtype=int)
352 | zero_vec = np.concatenate([zero_type, zeros], axis=0)
353 |
354 | o_h_state, o_h_goals = [], []
355 | idx_state, idx_goals = [], []
356 | airprt_dict, cargos_dict, planes_dict = {}, {}, {}
357 |
358 | for idx, o_h_airpr in enumerate(o_h_airprt):
359 | o_h_airpr = np.concatenate([o_h_types[0], o_h_airpr], axis=0)
360 | airprt_dict[str(o_h_airpr)] = 'A{}'.format(idx)
361 |
362 | #planes can go wherever
363 | for idx, o_h_plane in enumerate(o_h_planes):
364 | #generate random
365 | airprt_idx = randint(0, num_airport - 1)
366 | #make one_hot vecs
367 | airprt_vec = np.concatenate([o_h_types[0], o_h_airprt[airprt_idx]], axis=0)
368 | plane_vec = np.concatenate([o_h_types[1], o_h_plane], axis=0)
369 | #make final state vecs and add to states.
370 | oh_plane = np.concatenate([airprt_vec, plane_vec, zero_vec], axis=0)
371 | o_h_state.append(oh_plane)
372 | #lookup idxs
373 | pl_ent = 'P{}'.format(idx)
374 | planes_dict[str(plane_vec)] = pl_ent
375 | #make state dicts and add to states.
376 | ar_ent = airprt_dict[str(airprt_vec)]
377 | idx_state.append('At({}, {})'.format(pl_ent, ar_ent))
378 |
379 | # each cargo's start and end airports are mutually exclusive
380 | # for the purposes of this problem. Erp erp.
381 | for idx, o_h_cargo in enumerate(o_h_cargos):
382 | # generate random
383 | init_idx = random.randint(0, num_airport-1)
384 | allowed_values = set(range(num_airport))
385 | allowed_values.discard(init_idx)
386 | goal_idx = random.choice(list(allowed_values))
387 | # make one_hot vecs
388 | air_in_vec = np.concatenate([o_h_types[0], o_h_airprt[init_idx]], axis=0)
389 | air_gl_vec = np.concatenate([o_h_types[0], o_h_airprt[goal_idx]], axis=0)
390 | cargo_vec = np.concatenate([o_h_types[2], o_h_cargo], axis=0)
391 | # make final state vecs and add to states.
392 | oh_init = np.concatenate([air_in_vec, zero_vec, cargo_vec], axis=0)
393 | oh_goal = np.concatenate([air_gl_vec, zero_vec, cargo_vec], axis=0)
394 | o_h_state.append(oh_init)
395 | o_h_goals.append(oh_goal)
396 | # lookup idxs
397 | cr_ent = 'C{}'.format(idx)
398 | cargos_dict[str(cargo_vec)] = cr_ent
399 | state_ar_ent = airprt_dict[str(air_in_vec)]
400 | goals_ar_ent = airprt_dict[str(air_gl_vec)]
401 | # make state dicts and add to states.
402 | idx_state.append('At({}, {})'.format(cr_ent, state_ar_ent))
403 | idx_goals.append('At({}, {})'.format(cr_ent, goals_ar_ent))
404 |
405 | return [[np.asarray(o_h_state), np.asarray(o_h_goals)],
406 | [idx_state, idx_goals],
407 | [airprt_dict, cargos_dict, planes_dict]]
408 |
409 | def reverse_lookup(problem_ents, o_h_idx):
410 | key = next(key for key, value in problem_ents.items() if np.array_equal(value, o_h_idx))
411 | return key
412 |
413 | noop = np.asarray([0, 0], dtype=int)
414 |
415 |
416 | def air_cargo_generator_v2(num_airport, num_cargo, num_plane,
417 | one_hot_size=None, vec_fn=None):
418 | """Init:
419 | each plane must be at airport
420 | each cargo must be at an airport
421 | each cargo has a destination
422 |
423 | Returns typed state expressions giving the type and instance of each object
424 | 0-0 indicates
425 | A P C
426 | +---------------+
427 | | 1-0 2-0 3-0 |
428 | +---------------+
429 | """
430 | if one_hot_size is None:
431 | one_hot_size = max([num_airport, num_cargo, num_plane]) + 3
432 | if vec_fn is None:
433 | vec_fn = np.eye
434 |
435 | assert num_airport <= one_hot_size
436 | # print(num_airport)
437 | o_h_airprt = np.asarray([[0, i] for i in range(num_airport)], dtype=int)
438 | o_h_planes = np.asarray([[1, i] for i in range(num_plane)], dtype=int)
439 | o_h_cargos = np.asarray([[2, i] for i in range(num_cargo)], dtype=int)
440 |
441 | o_h_state, o_h_goals = [], []
442 | idx_state, idx_goals = [], []
443 | airprt_dict, cargos_dict, planes_dict = {}, {}, {}
444 |
445 | for idx, o_h_airpr in enumerate(o_h_airprt):
446 | airprt_dict['A{}'.format(idx)] = o_h_airpr
447 |
448 | # planes can go wherever
449 | for idx, o_h_plane in enumerate(o_h_planes):
450 | # generate random
451 | airprt_idx = randint(0, num_airport - 1)
452 | # make final state vecs and add to states.
453 | airport_vec = o_h_airprt[airprt_idx]
454 | oh_plane_state = np.concatenate([airport_vec, o_h_plane, noop], axis=0)
455 | o_h_state.append(oh_plane_state)
456 | # lookup idxs
457 | pl_ent = 'P{}'.format(idx)
458 | planes_dict[pl_ent] = o_h_plane
459 | # make state dicts and add to states.
460 | ar_ent = reverse_lookup(airprt_dict, airport_vec)
461 | idx_state.append('At({}, {})'.format(pl_ent, ar_ent))
462 |
463 | # each cargo's start and end airports are mutually exclusive
464 | # for the purposes of this problem. Erp erp.
465 | for idx, o_h_cargo in enumerate(o_h_cargos):
466 | # generate random
467 | init_idx = random.randint(0, num_airport - 1)
468 | allowed_values = set(range(num_airport))
469 | #
470 | allowed_values.discard(init_idx)
471 | goal_idx = random.choice(list(allowed_values))
472 | # make one_hot vecs
473 | air_in_vec = o_h_airprt[init_idx]
474 | air_gl_vec = o_h_airprt[goal_idx]
475 | # make final state vecs and add to states.
476 | oh_init = np.concatenate([air_in_vec, noop, o_h_cargo], axis=0)
477 | oh_goal = np.concatenate([air_gl_vec, noop, o_h_cargo], axis=0)
478 | o_h_state.append(oh_init)
479 | o_h_goals.append(oh_goal)
480 | # lookup idxs
481 | cr_ent = 'C{}'.format(idx)
482 | cargos_dict[cr_ent] = o_h_cargo
483 | # make state dicts and add to states.
484 | idx_state.append('At({}, {})'.format(cr_ent, reverse_lookup(airprt_dict, air_in_vec)))
485 | idx_goals.append('At({}, {})'.format(cr_ent, reverse_lookup(airprt_dict, air_gl_vec)))
486 |
487 | return [[np.asarray(o_h_state), np.asarray(o_h_goals)],
488 | [idx_state, idx_goals],
489 | [airprt_dict, cargos_dict, planes_dict]]
490 |
491 |
492 | def ent_(label, num):
493 | return '{}{}'.format(label, num)
494 |
495 |
496 | def entity_ix_generator(label, num_ents, start=0):
497 | start += 1
498 | ents_to_ix = {ent_(label, idx): i for idx, i in enumerate(range(start, start + num_ents))}
499 | ix_to_ents = {i: ent_(label, idx) for idx, i in enumerate(range(start, start + num_ents))}
500 | return ents_to_ix, ix_to_ents
501 |
502 |
503 | def entity_2ix_generator(label, type_num, num_ents, start=0):
504 | start += 1
505 | ents_to_ix = {ent_(label, idx): (type_num, i) for idx, i in enumerate(range(start, start + num_ents))}
506 | ix_to_ents = {(type_num, i): ent_(label, idx) for idx, i in enumerate(range(start, start + num_ents))}
507 | return ents_to_ix, ix_to_ents
508 |
509 |
510 | def air_cargo_generator_v3(num_airport, num_cargo, num_plane, encoding=1):
511 | """Init:
512 | each plane must be at airport
513 | each cargo must be at an airport
514 | each cargo has a destination
515 |
516 | Returns typed state expressions giving the type and instance of each object
517 | 0-0 indicates
518 | A P C
519 | +---------------+
520 | | 1-0 2-0 3-0 |
521 | +---------------+
522 | """
523 | if encoding == 1:
524 | airpt_to_ix, ix_to_airpt = entity_ix_generator("A", num_airport)
525 | cargo_to_ix, ix_to_cargo = entity_ix_generator("C", num_cargo, start=num_airport)
526 | plane_to_ix, ix_to_plane = entity_ix_generator("P", num_plane, start=num_cargo + num_airport)
527 | else:
528 | airpt_to_ix, ix_to_airpt = entity_2ix_generator("A", 1, num_airport)
529 | cargo_to_ix, ix_to_cargo = entity_2ix_generator("C", 2, num_cargo)
530 | plane_to_ix, ix_to_plane = entity_2ix_generator("P", 3, num_plane)
531 |
532 | state_exprs, goal_exprs = [], []
533 | ents_to_ix = {**airpt_to_ix, **cargo_to_ix, **plane_to_ix}
534 | ix_to_ents = {**ix_to_airpt, **ix_to_cargo, **ix_to_plane}
535 |
536 | # planes can go wherever
537 | for idx, plane in ix_to_plane.items():
538 | # find an airport to put the plane at
539 | airprt_idx = random.choice(list(airpt_to_ix.keys()))
540 | state_exprs.append('At({}, {})'.format(plane, airprt_idx))
541 |
542 | # each cargo's start and end airports are mutually exclusive
543 | # for the purposes of this problem. Erp erp.
544 | for idx, cargo in ix_to_cargo.items():
545 |
546 | # generate random
547 | allowed_values = set(list(airpt_to_ix.keys()))
548 | init_idx = random.choice(list(allowed_values))
549 |
550 | # set a goal airport
551 | allowed_values.discard(init_idx)
552 | goal_idx = random.choice(list(allowed_values))
553 |
554 | # make state dicts and add to states.
555 | state_exprs.append('At({}, {})'.format(cargo, init_idx))
556 | goal_exprs.append('At({}, {})'.format(cargo, goal_idx))
557 |
558 | return [[ents_to_ix, ix_to_ents],
559 | [state_exprs, goal_exprs]]
560 |
561 |
562 |
563 | def arbitrary_ACP(n_airport, n_cargo, n_plane, one_hot_size=None):
564 | """
565 | Generate ACP of arbitrary specified size in a roundabout way.
566 | Problem Object goes to engine,
567 | one_hot vecs go to dnc for training.
568 | """
569 | # create one hot and entity and state dictionaries
570 | o_h, k_b, ent_dic = air_cargo_generator_v2(n_airport, n_cargo, n_plane,
571 | one_hot_size=one_hot_size)
572 |
573 | airprt_dict, cargos_dict, planes_dict = ent_dic
574 | mp_merged = {**airprt_dict, **cargos_dict, **planes_dict}
575 |
576 | airports = list(airprt_dict.keys())
577 | cargos = list(cargos_dict.keys())
578 | planes = list(planes_dict.keys())
579 |
580 | init_exprs, goal_exprs = k_b
581 | o_h_state, o_h_goals = o_h
582 |
583 | pos = [expr(x) for x in init_exprs]
584 | neg = posneg_helper(cargos, planes, airports, pos)
585 | init = FluentState(pos, neg)
586 |
587 | goal = [expr(x) for x in goal_exprs]
588 | acp = AirCargoProblem(cargos, planes, airports, init, goal)
589 | return (acp, o_h_state, o_h_goals, mp_merged)
590 |
591 |
592 | def arbitrary_ACP2(n_airport, n_cargo, n_plane, encoding=1):
593 | """
594 | Generate ACP of arbitrary specified size in a roundabout way.
595 | Problem Object goes to engine,
596 | one_hot vecs go to dnc for training.
597 | """
598 | # create one hot and entity and state dictionaries
599 | (e_ix, ix_e), (state, goals) = air_cargo_generator_v3(n_airport, n_cargo, n_plane, encoding)
600 |
601 | airports = [e for e in e_ix.keys() if 'A' in e]
602 | cargos = [e for e in e_ix.keys() if 'C' in e]
603 | planes = [e for e in e_ix.keys() if 'P' in e]
604 |
605 | pos = [expr(x) for x in state]
606 | neg = posneg_helper(cargos, planes, airports, pos)
607 | init = FluentState(pos, neg)
608 |
609 | goal = [expr(x) for x in goals]
610 | acp = AirCargoProblem(cargos, planes, airports, init, goal)
611 | return acp, (e_ix, ix_e), (pos, goal)
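With `encoding=1`, `air_cargo_generator_v3` gives every entity a single integer id: airports first, then cargo, then planes, each block starting one past the previous offset. It then places each cargo at a random airport and draws its goal from the remaining airports. The sketch below reproduces those two steps in isolation. `entity_ids`, `sample_cargo_routes` and the seeded `random.Random` instance are assumptions made for illustration and reproducibility, not functions from this module.

```python
import random


def entity_ids(label, n, start=0):
    """Map 'A0', 'A1', ... to start + 1, start + 2, ...
    (the same numbering entity_ix_generator produces)."""
    return {"{}{}".format(label, i): start + 1 + i for i in range(n)}


def sample_cargo_routes(airports, cargos, rng):
    """Place each cargo at a random airport and pick a different goal airport."""
    state, goals = [], []
    for cargo in cargos:
        start = rng.choice(airports)
        # Excluding the start guarantees the goal is not already satisfied
        goal = rng.choice([a for a in airports if a != start])
        state.append("At({}, {})".format(cargo, start))
        goals.append("At({}, {})".format(cargo, goal))
    return state, goals


n_airport, n_cargo, n_plane = 3, 2, 2
airports = entity_ids("A", n_airport)                         # ids 1..3
cargos = entity_ids("C", n_cargo, start=n_airport)            # ids 4..5
planes = entity_ids("P", n_plane, start=n_airport + n_cargo)  # ids 6..7

state, goals = sample_cargo_routes(list(airports), list(cargos), random.Random(0))
print(state, goals)
```

Ids run from 1 to `n_airport + n_cargo + n_plane` with no gaps, so 0 stays free, which appears to be how `ix_to_vec` represents an empty argument slot. The goal draw needs at least two airports: with a single airport the candidate list is empty and `random.choice` raises `IndexError`, as it would in the original generator.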
--------------------------------------------------------------------------------
/problem/my_planning_graph.py:
--------------------------------------------------------------------------------
1 | from .planning import Action
2 | from .search import Problem
3 | from .utils import expr
4 | from .lp_utils import decode_state
5 | #import svgwrite
6 |
7 | class PgNode():
8 | """Base class for planning graph nodes.
9 |
10 | includes instance sets common to both types of nodes used in a planning graph
11 | parents: the set of nodes in the previous level
12 | children: the set of nodes in the subsequent level
13 | mutex: the set of sibling nodes that are mutually exclusive with this node
14 | """
15 |
16 | def __init__(self):
17 | self.parents = set()
18 | self.children = set()
19 | self.mutex = set()
20 |
21 | def is_mutex(self, other) -> bool:
22 | """Boolean test for mutual exclusion
23 |
24 | :param other: PgNode
25 | the other node to compare with
26 | :return: bool
27 | True if this node and the other are marked mutually exclusive (mutex)
28 | """
29 | if other in self.mutex:
30 | return True
31 | return False
32 |
33 | def show(self):
34 | """helper print for debugging shows counts of parents, children, siblings
35 |
36 | :return:
37 | print only
38 | """
39 | print("{} parents".format(len(self.parents)))
40 | print("{} children".format(len(self.children)))
41 | print("{} mutex".format(len(self.mutex)))
42 |
43 |
44 | class PgNode_s(PgNode):
45 | """A planning graph node representing a state (literal fluent) from a
46 | planning problem.
47 |
48 | Args:
49 | ----------
50 | symbol : str
51 | A string representing a literal expression from a planning problem
52 | domain.
53 |
54 | is_pos : bool
55 | Boolean flag indicating whether the literal expression is positive or
56 | negative.
57 | """
58 |
59 | def __init__(self, symbol: str, is_pos: bool):
60 | """S-level Planning Graph node constructor
61 |
62 | :param symbol: expr
63 | :param is_pos: bool
64 | Instance variables calculated:
65 | literal: expr
66 | fluent in its literal form including negative operator if applicable
67 | Instance variables inherited from PgNode:
68 | parents: set of nodes connected to this node in previous A level; initially empty
69 | children: set of nodes connected to this node in next A level; initially empty
70 | mutex: set of sibling S-nodes that this node has mutual exclusion with; initially empty
71 | """
72 | PgNode.__init__(self)
73 | self.symbol = symbol
74 | self.is_pos = is_pos
75 | self.__hash = None
76 |
77 | def show(self):
78 | """helper print for debugging shows literal plus counts of parents,
79 | children, siblings
80 |
81 | :return:
82 | print only
83 | """
84 | if self.is_pos:
85 | print("\n*** {}".format(self.symbol))
86 | else:
87 | print("\n*** ~{}".format(self.symbol))
88 | PgNode.show(self)
89 |
90 | def __eq__(self, other):
91 | """equality test for nodes - compares only the literal for equality
92 |
93 | :param other: PgNode_s
94 | :return: bool
95 | """
96 | return (isinstance(other, self.__class__) and
97 | self.is_pos == other.is_pos and
98 | self.symbol == other.symbol)
99 |
100 | def __hash__(self):
101 | self.__hash = self.__hash or hash(self.symbol) ^ hash(self.is_pos)
102 | return self.__hash
103 |
104 |
105 | class PgNode_a(PgNode):
106 | """A-type (action) Planning Graph node - inherited from PgNode """
107 |
108 |
109 | def __init__(self, action: Action):
110 | """A-level Planning Graph node constructor
111 |
112 | :param action: Action
113 | a ground action, i.e. this action cannot contain any variables
114 | Instance variables calculated:
115 | An A-level will always have an S-level as its parent and an S-level as its child.
116 | The preconditions and effects will become the parents and children of the A-level node
117 | However, when this node is created, it is not yet connected to the graph
118 | prenodes: set of *possible* parent S-nodes
119 | effnodes: set of *possible* child S-nodes
120 | is_persistent: bool True if this is a persistence action, i.e. a no-op action
121 | Instance variables inherited from PgNode:
122 | parents: set of nodes connected to this node in previous S level; initially empty
123 | children: set of nodes connected to this node in next S level; initially empty
124 | mutex: set of sibling A-nodes that this node has mutual exclusion with; initially empty
125 | """
126 | PgNode.__init__(self)
127 | self.action = action
128 | self.prenodes = self.precond_s_nodes()
129 | self.effnodes = self.effect_s_nodes()
130 | self.is_persistent = self.prenodes == self.effnodes
131 | self.__hash = None
132 |
133 | def show(self):
134 | """helper print for debugging shows action plus counts of parents, children, siblings
135 |
136 | :return:
137 | print only
138 | """
139 | print("\n*** {!s}".format(self.action))
140 | PgNode.show(self)
141 |
142 | def precond_s_nodes(self):
143 | """precondition literals as S-nodes (represents possible parents for this node).
144 | It is computationally expensive to call this function; it is only called by the
145 | class constructor to populate the `prenodes` attribute.
146 |
147 | :return: set of PgNode_s
148 | """
149 | nodes = set()
150 | for p in self.action.precond_pos:
151 | nodes.add(PgNode_s(p, True))
152 | for p in self.action.precond_neg:
153 | nodes.add(PgNode_s(p, False))
154 | return nodes
155 |
156 | def effect_s_nodes(self):
157 | """effect literals as S-nodes (represents possible children for this node).
158 | It is computationally expensive to call this function; it is only called by the
159 | class constructor to populate the `effnodes` attribute.
160 |
161 | :return: set of PgNode_s
162 | """
163 | nodes = set()
164 | for e in self.action.effect_add:
165 | nodes.add(PgNode_s(e, True))
166 | for e in self.action.effect_rem:
167 | nodes.add(PgNode_s(e, False))
168 | return nodes
169 |
170 | def __eq__(self, other):
171 | """equality test for nodes - compares persistence, action name, and action arguments
172 |
173 | :param other: PgNode_a
174 | :return: bool
175 | """
176 | return (isinstance(other, self.__class__) and
177 | self.is_persistent == other.is_persistent and
178 | self.action.name == other.action.name and
179 | self.action.args == other.action.args)
180 |
181 | def __hash__(self):
182 | self.__hash = self.__hash or hash(self.action.name) ^ hash(self.action.args)
183 | return self.__hash
184 |
185 |
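The S- and A-level builders depend on `PgNode_s` comparing by value (symbol plus sign) rather than by identity. A freshly built precondition node must match the node already stored in a level, and every action that produces the same literal should attach to one shared node. Below is a minimal sketch of both ideas. It uses a hypothetical stand-in class `Lit` with the same `__eq__`/`__hash__` semantics as `PgNode_s`, and plain strings in place of `expr` literals:

```python
class Lit:
    """Hypothetical stand-in for PgNode_s: equality and hash use only (symbol, is_pos)."""

    def __init__(self, symbol, is_pos):
        self.symbol = symbol
        self.is_pos = is_pos
        self.parents = set()  # producing actions, filled in when a level is built

    def __eq__(self, other):
        return (isinstance(other, Lit) and self.symbol == other.symbol
                and self.is_pos == other.is_pos)

    def __hash__(self):
        return hash((self.symbol, self.is_pos))


# An S level holding the graph's own node objects.
s_level = {Lit("At(C1, SFO)", True), Lit("At(P1, SFO)", True)}

# Precondition nodes are fresh objects, but value equality makes the subset test work.
preconds = {Lit("At(C1, SFO)", True), Lit("At(P1, SFO)", True)}
applicable = preconds.issubset(s_level)

# Two actions produce the same literal: keep one canonical node so both become its parents.
canonical = {}
for action, effect in [("Load", Lit("In(C1, P1)", True)), ("Noop", Lit("In(C1, P1)", True))]:
    node = canonical.setdefault(effect, effect)
    node.parents.add(action)

shared = next(iter(canonical.values()))
print(applicable, sorted(shared.parents))  # True ['Load', 'Noop']
```

If the second effect object were added to a set directly, the set would silently keep the first object. The second object's parent link would then be lost, which is why a lookup such as `dict.setdefault` is used here.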
186 |
187 |
188 | def mutexify(node1: PgNode, node2: PgNode):
189 | """ adds sibling nodes to each other's mutual exclusion (mutex) set. These should be sibling nodes!
190 |
191 | :param node1: PgNode (or inherited PgNode_a, PgNode_s types)
192 | :param node2: PgNode (or inherited PgNode_a, PgNode_s types)
193 | :return:
194 | node mutex sets modified
195 | """
196 | if type(node1) != type(node2):
197 | raise TypeError('Attempted to mutex two nodes of different types')
198 | node1.mutex.add(node2)
199 | node2.mutex.add(node1)
200 |
201 | def crude_inspect(obj):
202 | for prop, value in vars(obj).items():
203 | print(prop)
204 | print(value)
205 |
206 |
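`PlanningGraph` alternates S and A levels until two consecutive S levels are identical. `h_levelsum` then adds up the index of the first S level in which each goal appears. The sketch below shows the same expansion and level-sum computation in a deliberately simplified form. It ignores mutexes and negative literals (a relaxed, positive-only graph), and its `(name, preconditions, add_effects)` action tuples are a hypothetical encoding, not this module's `Action` class:

```python
def level_costs(initial, actions):
    """Return {literal: first S level containing it} for a relaxed (mutex-free) planning graph."""
    cost = {lit: 0 for lit in initial}
    level = set(initial)
    depth = 0
    while True:
        # A level: every action whose preconditions all appear in the current S level.
        # No-ops are implicit because the next S level starts as a copy of this one.
        next_level = set(level)
        for _name, pre, add in actions:
            if pre <= level:
                next_level |= add
        depth += 1
        if next_level == level:  # leveled off: two identical consecutive S levels
            return cost
        for lit in next_level - level:
            cost[lit] = depth
        level = next_level


def level_sum(initial, actions, goals):
    costs = level_costs(initial, actions)
    # this sketch treats an unreachable goal as infinitely costly
    return sum(costs.get(g, float("inf")) for g in goals)


actions = [
    ("Load(C1, P1, SFO)", {"At(C1, SFO)", "At(P1, SFO)"}, {"In(C1, P1)"}),
    ("Fly(P1, SFO, JFK)", {"At(P1, SFO)"}, {"At(P1, JFK)"}),
    ("Unload(C1, P1, JFK)", {"In(C1, P1)", "At(P1, JFK)"}, {"At(C1, JFK)"}),
]
init = {"At(C1, SFO)", "At(P1, SFO)"}
print(level_sum(init, actions, ["At(C1, JFK)", "At(P1, JFK)"]))  # 3
```

Here `At(P1, JFK)` first appears in S1 and `At(C1, JFK)` in S2, so the level sum is 1 + 2 = 3.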
207 | class PlanningGraph():
208 | """
209 | A planning graph as described in chapter 10 of the AIMA text. The planning
210 | graph can be used to reason about literal reachability and to compute search heuristics such as the level sum.
211 | """
212 |
213 | def __init__(self, problem: Problem, state: str, serial_planning=True):
214 | """
215 | :param problem: PlanningProblem (or subclass such as AirCargoProblem or HaveCakeProblem)
216 | :param state: str (will be in form TFTTFF... representing fluent states)
217 | :param serial_planning: bool (whether or not to assume that only one
218 | action can occur at a time)
219 | Instance variable calculated:
220 | fs: FluentState
221 | the state represented as positive and negative fluent literal lists
222 | all_actions: list of the PlanningProblem valid ground actions
223 | combined with calculated no-op actions
224 | s_levels: list of sets of PgNode_s, where each set in the list
225 | represents an S-level in the planning graph
226 | a_levels: list of sets of PgNode_a, where each set in the list
227 | represents an A-level in the planning graph
228 | """
229 | self.problem = problem
230 | self.fs = decode_state(state, problem.state_map)
231 | self.serial = serial_planning
232 | self.all_actions = self.problem.actions_list + self.noop_actions(self.problem.state_map)
233 | self.s_levels = []
234 | self.a_levels = []
235 | self.create_graph()
236 |
237 | def show_all(self):
238 | print("\n------problem-----")
239 | print(self.problem)
240 | print("\n------FluentStates------")
241 | print("------ pos------")
242 | print(self.fs.pos)
243 | print("------ neg------")
244 | print(self.fs.neg)
245 |
246 | print("\n------S_levels------")
247 | for sl in self.s_levels:
248 | for s in sl:
249 | s.show()
250 |
251 | print("\n------A_levels------")
252 | for a in self.a_levels:
253 | print(a)
254 |
255 | print("\n------actions------")
256 | for action in self.all_actions:
257 | print(action)
258 | print("-------------------------------------------------------")
259 |
260 | def noop_actions(self, literal_list):
261 | """create persistent action for each possible fluent
262 |
263 | "No-Op" actions are virtual actions (i.e., actions that only exist in
264 | the planning graph, not in the planning problem domain) that operate
265 | on each fluent (literal expression) from the problem domain. No op
266 | actions "pass through" the literal expressions from one level of the
267 | planning graph to the next.
268 |
269 | The no-op action list requires both a positive and a negative action
270 | for each literal expression. Positive no-op actions require the literal
271 | as a positive precondition and add the literal expression as an effect
272 | in the output, and negative no-op actions require the literal as a
273 | negative precondition and remove the literal expression as an effect in
274 | the output.
275 |
276 | This function should only be called by the class constructor.
277 |
278 | :param literal_list: iterable of fluent expressions (the problem's state_map)
279 | :return: list of Action
280 | """
281 | action_list = []
282 | for fluent in literal_list:
283 | act1 = Action(expr("Noop_pos({})".format(fluent)), ([fluent], []), ([fluent], []))
284 | action_list.append(act1)
285 | act2 = Action(expr("Noop_neg({})".format(fluent)), ([], [fluent]), ([], [fluent]))
286 | action_list.append(act2)
287 | return action_list
288 |
289 | def create_graph(self):
290 | """ build a Planning Graph as described in Russell-Norvig 3rd Ed 10.3 or 2nd Ed 11.4
291 |
292 | The S0 initial level has been implemented for you. It has no parents and includes all of
293 | the literal fluents that are part of the initial state passed to the constructor. At the
294 | start of a problem planning search, this will be the same as the initial state of the problem.
295 | However, the planning graph can be built from any state in the Planning Problem.
296 |
297 | This function should only be called by the class constructor.
298 |
299 | :return:
300 | builds the graph by filling s_levels[] and a_levels[]
301 | lists with node sets for each level
302 | """
303 | # the graph should only be built during class construction
304 | if (len(self.s_levels) != 0) or (len(self.a_levels) != 0):
305 | raise Exception(
306 | 'Planning Graph already created; construct a new planning graph for each new state in the planning sequence')
307 |
308 | # initialize S0 to literals in initial state provided.
309 | leveled = False
310 | level = 0
311 | self.s_levels.append(set()) # S0 set of s_nodes - empty to start
312 | # for each fluent in the initial state, add the correct literal PgNode_s
313 | for literal in self.fs.pos:
314 | self.s_levels[level].add(PgNode_s(literal, True))
315 | for literal in self.fs.neg:
316 | self.s_levels[level].add(PgNode_s(literal, False))
317 |
318 | # continue to build the graph alternating A, S levels until
319 | # last two S levels contain the same literals, i.e. until it is "leveled"
320 | while not leveled:
321 | self.add_action_level(level)
322 | self.update_a_mutex(self.a_levels[level])
323 |
324 | level += 1
325 | self.add_literal_level(level)
326 | self.update_s_mutex(self.s_levels[level])
327 |
328 | if self.s_levels[level] == self.s_levels[level - 1]:
329 | leveled = True
330 |
331 | def get_levels(self):
332 | return [[s.symbol for s in s_level] for s_level in self.s_levels]
333 |
334 | def add_action_level(self, level):
335 | """ add an A (action) level to the Planning Graph
336 |
337 | :param level: int
338 | the level number alternates S0, A0, S1, A1, S2, .... etc
339 | the level number is also used as the
340 | index for the node set lists self.a_levels[] and self.s_levels[]
341 | :return:
342 | adds A nodes to the current level in self.a_levels[level]
343 | """
344 | self.a_levels.append(set())
345 | for action in self.all_actions:
346 | node_a = PgNode_a(action)
347 | if node_a.prenodes.issubset(self.s_levels[level]):  # ALL preconditions must be present
348 | node_a.parents = {s for s in self.s_levels[level] if s in node_a.prenodes}  # link the level's own nodes
349 | for s_literal in node_a.parents:
350 | s_literal.children.add(node_a)
351 | self.a_levels[level].add(node_a)
352 |
353 | def add_literal_level(self, level):
354 | """ add an S (literal) level to the Planning Graph
355 |
356 | :param level: int
357 | the level number alternates S0, A0, S1, A1, S2, ... etc.; it is also the index into self.a_levels[] and self.s_levels[]
358 | :return: adds S nodes to the current level in self.s_levels[level]
359 | """
360 | self.s_levels.append(set())
361 | level_nodes = {}  # literal -> the single PgNode_s kept for it in this level
362 | for a_node in self.a_levels[level-1]:
363 | for eff_literal in a_node.effnodes:
364 | s_node = level_nodes.setdefault(eff_literal, eff_literal)  # reuse the node if another action already produced this literal
365 | a_node.children.add(s_node)
366 | s_node.parents.add(a_node)
367 | self.s_levels[level].add(s_node)
368 |
369 | def update_a_mutex(self, nodeset):
370 | """ Determine and update sibling mutual exclusion for A-level nodes
371 |
372 | Mutex action tests section from 3rd Ed. 10.3 or 2nd Ed. 11.4
373 | A mutex relation holds between two actions at a given level
374 | if the planning graph is a serial planning graph and both actions are non-persistent,
375 | or if any of the three conditions holds between the pair:
376 | Inconsistent Effects
377 | Interference
378 | Competing needs
379 |
380 | :param nodeset: set of PgNode_a (siblings in the same level)
381 | :return:
382 | mutex set in each PgNode_a in the set is appropriately updated
383 | """
384 | nodelist = list(nodeset)
385 | for i, n1 in enumerate(nodelist[:-1]):
386 | for n2 in nodelist[i + 1:]:
387 | if (self.serialize_actions(n1, n2) or
388 | self.inconsistent_effects_mutex(n1, n2) or
389 | self.interference_mutex(n1, n2) or
390 | self.competing_needs_mutex(n1, n2)):
391 | mutexify(n1, n2)
392 |
393 | def serialize_actions(self, node_a1: PgNode_a, node_a2: PgNode_a) -> bool:
394 | """
395 | Test a pair of actions for mutual exclusion, returning True if the
396 | planning graph is serial and neither action is persistent; otherwise
397 | return False. In a serial planning graph, two actions are mutually
398 | exclusive if they are both non-persistent.
399 |
400 | :param node_a1: PgNode_a
401 | :param node_a2: PgNode_a
402 | :return: bool
403 | """
404 | if not self.serial:
405 | return False
406 | if node_a1.is_persistent or node_a2.is_persistent:
407 | return False
408 | return True
409 |
410 | def inconsistent_effects_mutex(self, node_a1: PgNode_a, node_a2: PgNode_a) -> bool:
411 | """
412 | Test a pair of actions for inconsistent effects, returning True if
413 | one action negates an effect of the other, and False otherwise.
414 |
415 | HINT: The Action instance associated with an action node is accessible
416 | through the PgNode_a.action attribute. See the Action class
417 | documentation for details on accessing the effects and preconditions of
418 | an action.
419 |
420 | :param node_a1: PgNode_a
421 | :param node_a2: PgNode_a
422 | :return: bool
423 | """
424 | a1 = node_a1.action
425 | a2 = node_a2.action
426 |
427 | # one action adds a literal that the other removes; set intersection is symmetric,
428 | # so checking each add/remove direction once is sufficient
429 | if (set(a1.effect_add) & set(a2.effect_rem) or
430 | set(a2.effect_add) & set(a1.effect_rem)):
431 | return True
432 | return False
433 |
434 | def interference_mutex(self, node_a1: PgNode_a, node_a2: PgNode_a) -> bool:
435 | """
436 | Test a pair of actions for mutual exclusion, returning True if the
437 | effect of one action is the negation of a precondition of the other.
438 |
439 | HINT: The Action instance associated with an action node is accessible
440 | through the PgNode_a.action attribute. See the Action class
441 | documentation for details on accessing the effects and preconditions of
442 | an action.
443 |
444 | :param node_a1: PgNode_a
445 | :param node_a2: PgNode_a
446 | :return: bool
447 | """
448 | a1 = node_a1.action
449 | a2 = node_a2.action
450 |
451 | if (set(a1.effect_add) & set(a2.precond_neg) or
452 | set(a2.effect_add) & set(a1.precond_neg) or
453 | set(a1.effect_rem) & set(a2.precond_pos) or
454 | set(a2.effect_rem) & set(a1.precond_pos)):
455 | return True
456 | return False
457 |
458 | def competing_needs_mutex(self, node_a1: PgNode_a, node_a2: PgNode_a) -> bool:
459 | """
460 | Test a pair of actions for mutual exclusion, returning True if a
461 | precondition of one action is mutex with a precondition of the
462 | other action, and False otherwise.
463 |
464 | :param node_a1: PgNode_a
465 | :param node_a2: PgNode_a
466 | :return: bool
467 | """
468 | for pre_a1 in node_a1.parents:
469 | for pre_a2 in node_a2.parents:
470 | if pre_a1.is_mutex(pre_a2):
471 | return True
472 | return False
473 |
474 | def update_s_mutex(self, nodeset: set):
475 | """ Determine and update sibling mutual exclusion for S-level nodes
476 |
477 | Mutex literal tests section from 3rd Ed. 10.3 or 2nd Ed. 11.4
478 | A mutex relation holds between literals at a given level
479 | if either of the two conditions holds between the pair:
480 | Negation
481 | Inconsistent support
482 |
483 | :param nodeset: set of PgNode_s (siblings in the same level)
484 | :return:
485 | mutex set in each PgNode_s in the set is appropriately updated
486 | """
487 | nodelist = list(nodeset)
488 | for i, n1 in enumerate(nodelist[:-1]):
489 | for n2 in nodelist[i + 1:]:
490 | if self.negation_mutex(n1, n2) or self.inconsistent_support_mutex(n1, n2):
491 | mutexify(n1, n2)
492 |
493 | def negation_mutex(self, node_s1: PgNode_s, node_s2: PgNode_s) -> bool:
494 | """
495 | Test a pair of state literals for mutual exclusion, returning True if
496 | one node is the negation of the other, and False otherwise.
497 |
498 | HINT: PgNode_s.__eq__ defines the notion of equivalence for literal
499 | expression nodes, and the class tracks whether the literal is positive
500 | or negative.
501 |
502 | :param node_s1: PgNode_s
503 | :param node_s2: PgNode_s
504 | :return: bool
505 | """
506 | same_symbol = node_s1.symbol == node_s2.symbol
507 | opposite_sign = node_s1.is_pos != node_s2.is_pos
508 | return same_symbol and opposite_sign
509 |
510 | def inconsistent_support_mutex(self, node_s1: PgNode_s, node_s2: PgNode_s) -> bool:
511 | """
512 | Test a pair of state literals for mutual exclusion, returning True if
513 | there are no actions that could achieve the two literals at the same
514 | time, and False otherwise.
515 | In other words, the two literal nodes are
516 | mutex if all of the actions that could achieve the first literal node
517 | are pairwise mutually exclusive with all of the actions that could
518 | achieve the second literal node.
519 |
520 | HINT: The PgNode.is_mutex method can be used to test whether two nodes
521 | are mutually exclusive.
522 |
523 | :param node_s1: PgNode_s
524 | :param node_s2: PgNode_s
525 | :return: bool
526 | """
527 | # mutex only if EVERY pair of achievers is mutex; a shared or non-mutex pair gives support
528 | for pre_a1 in node_s1.parents:
529 | for pre_a2 in node_s2.parents:
530 | if pre_a1 == pre_a2 or not pre_a1.is_mutex(pre_a2):
531 | return False
532 | return True
533 |
534 | def h_levelsum(self) -> int:
535 | """The sum of the level costs of the individual goals
536 | (inadmissible in general; works well when the goals are largely independent)
537 | for each goal in the problem, determine the level cost,
538 | then add them together
539 | :return: int
540 | """
541 | level_sum = 0
542 | for goal in self.problem.goal:
543 | for idx, level in enumerate(self.s_levels):
544 | if any(((s.symbol == goal and s.is_pos is True) for s in level)):
545 | level_sum += idx
546 | break
547 | return level_sum
548 |
549 |
550 |
551 | def visualize_plan(self):
552 | """
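553 | Draft (unfinished): lay out the planning graph as an SVG; sv_struc is meant to map {symbol: {is_pos: vertical position}}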
554 |
555 | :return:
556 | """
557 | outpath = 'graph.svg'
558 | h = 500
559 | w = 1000
560 | def swap_rows(r1, r2):
561 | pass
562 |
563 | def add_node(s_node):
564 | pass
565 |
566 | num_levels = len(self.s_levels)
567 | # the last S level always has the most nodes: no-op actions carry every
568 | # literal forward, so the node count never decreases from level to level
569 | max_nodes = len(self.s_levels[-1])
570 |
571 | all_levels = zip(self.s_levels, self.a_levels)
572 |
573 | level_h = h // max_nodes
574 | level_w = w // num_levels
575 |
576 | bx_h = level_h // 2
577 | bx_w = level_w // 2
578 |
579 | # dwg = svgwrite.Drawing(h, w)
580 |
581 | sv_struc = {}
582 |
583 | for idx, (s_level, a_level) in enumerate(all_levels):
584 | num_elements = len(s_level)
585 |
586 | #start with last level of graph:
587 | for idx, s_level in enumerate(reversed(self.s_levels)):
588 | num_elements = len(s_level)
589 | for s_node in s_level:
590 | if s_node.symbol in sv_struc:
591 | if s_node.is_pos in sv_struc[s_node.symbol]:
592 | sv_node_h = sv_struc[s_node.symbol][s_node.is_pos]
593 | # create rectangle at x = level_idx *
594 | # svgwrite.shapes.Rect(insert=( ,idx* ), size=(bx_w, bx_h))
595 |
596 | pass
597 |
598 |
599 |
600 |
601 |
602 |
603 |
604 |
605 |
606 |
607 |
608 |
609 |
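The three action-mutex tests (inconsistent effects, interference, competing needs) and the two literal-mutex tests (negation, inconsistent support) in `PlanningGraph` all reduce to set lookups on preconditions and effects. Below is a self-contained sketch of those definitions on the classic have-cake example. It assumes a hypothetical `ToyAction` named tuple whose literals are strings, with a leading `~` marking negation. Mutex relations from the previous level are passed in as plain sets of pairs. This illustrates the definitions rather than reproducing the node-based implementation above:

```python
from itertools import product
from typing import FrozenSet, Iterable, NamedTuple, Set


def negate(lit: str) -> str:
    """Literals are plain strings; a leading '~' marks a negative literal."""
    return lit[1:] if lit.startswith("~") else "~" + lit


class ToyAction(NamedTuple):
    """Hypothetical STRIPS-style action used only for this illustration."""
    name: str
    pre: FrozenSet[str]  # precondition literals, positive or '~'-negated
    eff: FrozenSet[str]  # effect literals: 'X' adds X, '~X' removes X


def inconsistent_effects(a1: ToyAction, a2: ToyAction) -> bool:
    # an effect of one action is the negation of an effect of the other
    return any(negate(e) in a2.eff for e in a1.eff)


def interference(a1: ToyAction, a2: ToyAction) -> bool:
    # an effect of either action negates a precondition of the other
    return (any(negate(e) in a2.pre for e in a1.eff) or
            any(negate(e) in a1.pre for e in a2.eff))


def competing_needs(a1: ToyAction, a2: ToyAction, literal_mutexes: Set[FrozenSet[str]]) -> bool:
    # some precondition of a1 is mutex with some precondition of a2 in the previous S level
    return any(frozenset((p, q)) in literal_mutexes for p, q in product(a1.pre, a2.pre))


def negation(l1: str, l2: str) -> bool:
    # one literal is the negation of the other
    return l1 == negate(l2)


def inconsistent_support(achievers1: Iterable[str], achievers2: Iterable[str],
                         action_mutexes: Set[FrozenSet[str]]) -> bool:
    # mutex only if EVERY pair of achieving actions is mutex; a shared achiever never is
    return all(a != b and frozenset((a, b)) in action_mutexes
               for a, b in product(achievers1, achievers2))


eat = ToyAction("Eat(Cake)", frozenset({"Have(Cake)"}), frozenset({"Eaten(Cake)", "~Have(Cake)"}))
keep = ToyAction("Noop(Have(Cake))", frozenset({"Have(Cake)"}), frozenset({"Have(Cake)"}))
bake = ToyAction("Bake(Cake)", frozenset({"~Have(Cake)"}), frozenset({"Have(Cake)"}))

print(inconsistent_effects(eat, keep), interference(eat, keep))  # True True
print(competing_needs(eat, bake, {frozenset({"Have(Cake)", "~Have(Cake)"})}))  # True
print(inconsistent_support(["Eat(Cake)"], ["Noop(Have(Cake))"],
                           {frozenset({"Eat(Cake)", "Noop(Have(Cake))"})}))  # True
```

The last check mirrors why `Eaten(Cake)` and `Have(Cake)` are mutex in S1. Their only achievers, `Eat(Cake)` and the no-op, are themselves mutex in A0.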
--------------------------------------------------------------------------------
/problem/planning.py:
--------------------------------------------------------------------------------
1 | """Planning (Chapters 10-11)
2 | """
3 |
4 | from .utils import Expr
5 |
6 |
7 | class Action:
8 | """
9 | Defines an action schema using preconditions and effects
10 | Use this to describe actions in PDDL
11 | action is an Expr where variables are given as arguments(args)
12 | Precondition and effect are both lists with positive and negated literals
13 | Example:
14 | precond_pos = [expr("Human(person)"), expr("Hungry(person)")]
15 | precond_neg = [expr("Eaten(food)")]
16 | effect_add = [expr("Eaten(food)")]
17 | effect_rem = [expr("Hungry(person)")]
18 | eat = Action(expr("Eat(person, food)"), [precond_pos, precond_neg], [effect_add, effect_rem])
19 | """
20 |
21 | def __init__(self, action, precond, effect):
22 | self.name = action.op
23 | self.args = action.args
24 | self.precond_pos = precond[0]
25 | self.precond_neg = precond[1]
26 | self.effect_add = effect[0]
27 | self.effect_rem = effect[1]
28 |
29 | def __call__(self, kb, args):
30 | return self.act(kb, args)
31 |
32 | def __str__(self):
33 | return "{}{!s}".format(self.name, self.args)
34 |
35 | def substitute(self, e, args):
36 | """Replaces variables in expression with their
37 | respective Propositional symbol"""
38 | new_args = list(e.args)
39 | for num, x in enumerate(e.args):
40 | for i in range(len(self.args)):
41 | if self.args[i] == x:
42 | new_args[num] = args[i]
43 | return Expr(e.op, *new_args)
44 |
45 | def check_precond(self, kb, args):
46 | """Checks if the precondition is satisfied in the current state"""
47 | # check for positive clauses
48 | for clause in self.precond_pos:
49 | if self.substitute(clause, args) not in kb.clauses:
50 | return False
51 | # check for negative clauses
52 | for clause in self.precond_neg:
53 | if self.substitute(clause, args) in kb.clauses:
54 | return False
55 | return True
56 |
57 | def act(self, kb, args):
58 | """Executes the action on the state's kb"""
59 | # check if the preconditions are satisfied
60 | if not self.check_precond(kb, args):
61 | raise Exception("Action pre-conditions not satisfied")
62 | # remove negative literals
63 | for clause in self.effect_rem:
64 | kb.retract(self.substitute(clause, args))
65 | # add positive literals
66 | for clause in self.effect_add:
67 | kb.tell(self.substitute(clause, args))
68 |
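`Action.act` checks the positive and negative preconditions, retracts `effect_rem`, and then tells `effect_add`, substituting the action's arguments for the schema variables. Below is a minimal sketch of the same STRIPS update applied to a plain set of ground literals instead of a knowledge base. The tuple encoding of literals and the helper names `substitute`/`apply_action` here are hypothetical and independent of this module:

```python
from typing import FrozenSet, Iterable, Sequence, Tuple

Literal = Tuple[str, ...]  # ("Hungry", "person") stands for Hungry(person)


def substitute(lit: Literal, params: Sequence[str], args: Sequence[str]) -> Literal:
    """Replace schema variables with the corresponding ground arguments."""
    binding = dict(zip(params, args))
    return (lit[0],) + tuple(binding.get(x, x) for x in lit[1:])


def apply_action(state: FrozenSet[Literal], params: Sequence[str],
                 precond_pos: Iterable[Literal], precond_neg: Iterable[Literal],
                 effect_add: Iterable[Literal], effect_rem: Iterable[Literal],
                 args: Sequence[str]) -> FrozenSet[Literal]:
    """Return the successor state, or raise ValueError if a precondition fails."""
    def ground(lits: Iterable[Literal]) -> FrozenSet[Literal]:
        return frozenset(substitute(lit, params, args) for lit in lits)

    if not (ground(precond_pos) <= state) or (ground(precond_neg) & state):
        raise ValueError("Action pre-conditions not satisfied")
    # remove negative effects first, then add positive effects (the order Action.act uses)
    return (state - ground(effect_rem)) | ground(effect_add)


eat_params = ("person", "food")
eat_schema = dict(
    precond_pos=[("Human", "person"), ("Hungry", "person")],
    precond_neg=[("Eaten", "food")],
    effect_add=[("Eaten", "food")],
    effect_rem=[("Hungry", "person")],
)
s0 = frozenset({("Human", "Ann"), ("Hungry", "Ann")})
s1 = apply_action(s0, eat_params, args=("Ann", "Cake"), **eat_schema)
print(sorted(s1))  # [('Eaten', 'Cake'), ('Human', 'Ann')]
```

Applying the same action to `s1` raises `ValueError`, because `Hungry(Ann)` no longer holds.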
--------------------------------------------------------------------------------
/problem/search.py:
--------------------------------------------------------------------------------
1 | """Search (Chapters 3-4)
2 |
3 | The way to use this code is to subclass Problem to create a class of problems,
4 | then create problem instances and solve them with calls to the various search
5 | functions."""
6 |
7 | from .utils import (
8 | is_in, memoize, print_table, Stack, FIFOQueue, PriorityQueue, name
9 | )
10 |
11 | import sys
12 |
13 | infinity = float('inf')
14 |
15 | # ______________________________________________________________________________
16 |
17 |
18 | class Problem:
19 |
20 | """The abstract class for a formal problem. You should subclass
21 | this and implement the methods actions and result, and possibly
22 | __init__, goal_test, and path_cost. Then you will create instances
23 | of your subclass and solve them with the various search functions."""
24 |
25 | def __init__(self, initial, goal=None):
26 | """The constructor specifies the initial state, and possibly a goal
27 | state, if there is a unique goal. Your subclass's constructor can add
28 | other arguments."""
29 | self.initial = initial
30 | self.goal = goal
31 |
32 | def actions(self, state):
33 | """Return the actions that can be executed in the given
34 | state. The result would typically be a list, but if there are
35 | many actions, consider yielding them one at a time in an
36 | iterator, rather than building them all at once."""
37 | raise NotImplementedError
38 |
39 | def result(self, state, action):
40 | """Return the state that results from executing the given
41 | action in the given state. The action must be one of
42 | self.actions(state)."""
43 | raise NotImplementedError
44 |
45 | def goal_test(self, state):
46 | """Return True if the state is a goal. The default method compares the
47 | state to self.goal or checks for state in self.goal if it is a
48 | list, as specified in the constructor. Override this method if
49 | checking against a single self.goal is not enough."""
50 | if isinstance(self.goal, list):
51 | return is_in(state, self.goal)
52 | else:
53 | return state == self.goal
54 |
55 | def path_cost(self, c, state1, action, state2):
56 | """Return the cost of a solution path that arrives at state2 from
57 | state1 via action, assuming cost c to get up to state1. If the problem
58 | is such that the path doesn't matter, this function will only look at
59 | state2. If the path does matter, it will consider c and maybe state1
60 | and action. The default method costs 1 for every step in the path."""
61 | return c + 1
62 |
63 | def value(self, state):
64 | """For optimization problems, each state has a value. Hill-climbing
65 | and related algorithms try to maximize this value."""
66 | raise NotImplementedError
67 | # ______________________________________________________________________________
68 |
69 |
70 | class Node:
71 |
72 | """A node in a search tree. Contains a pointer to the parent (the node
73 | that this is a successor of) and to the actual state for this node. Note
74 | that if a state is arrived at by two paths, then there are two nodes with
75 | the same state. Also includes the action that got us to this state, and
76 | the total path_cost (also known as g) to reach the node. Other functions
77 | may add an f and h value; see best_first_graph_search and astar_search for
78 | an explanation of how the f and h values are handled. You will not need to
79 | subclass this class."""
80 |
81 | def __init__(self, state, parent=None, action=None, path_cost=0):
82 | "Create a search tree Node, derived from a parent by an action."
83 | self.state = state
84 | self.parent = parent
85 | self.action = action
86 | self.path_cost = path_cost
87 | self.depth = 0
88 | if parent:
89 | self.depth = parent.depth + 1
90 |
91 | def __repr__(self):
92 | return "<Node %s>" % (self.state,)
93 |
94 | def __lt__(self, node):
95 | return self.state < node.state
96 |
97 | def expand(self, problem):
98 | "List the nodes reachable in one step from this node."
99 | return [self.child_node(problem, action)
100 | for action in problem.actions(self.state)]
101 |
102 | def child_node(self, problem, action):
103 | "[Figure 3.10]"
104 | next = problem.result(self.state, action)
105 | return Node(next, self, action,
106 | problem.path_cost(self.path_cost, self.state,
107 | action, next))
108 |
109 | def solution(self):
110 | "Return the sequence of actions to go from the root to this node."
111 | return [node.action for node in self.path()[1:]]
112 |
113 | def path(self):
114 | "Return a list of nodes forming the path from the root to this node."
115 | node, path_back = self, []
116 | while node:
117 | path_back.append(node)
118 | node = node.parent
119 | return list(reversed(path_back))
120 |
121 | # We want a queue of nodes in breadth_first_search or
122 | # astar_search to have no duplicated states, so we treat nodes
123 | # with the same state as equal. [Problem: this may not be what you
124 | # want in other contexts.]
125 |
126 | def __eq__(self, other):
127 | return isinstance(other, Node) and self.state == other.state
128 |
129 | def __hash__(self):
130 | return hash(self.state)
131 |
132 | # ______________________________________________________________________________
133 | # Uninformed Search algorithms
134 |
135 |
136 | def tree_search(problem, frontier):
137 | """Search through the successors of a problem to find a goal.
138 | The argument frontier should be an empty queue.
139 | Don't worry about repeated paths to a state. [Figure 3.7]"""
140 | frontier.append(Node(problem.initial))
141 | while frontier:
142 | node = frontier.pop()
143 | if problem.goal_test(node.state):
144 | return node
145 | frontier.extend(node.expand(problem))
146 | return None
147 |
148 |
149 | def graph_search(problem, frontier):
150 | """Search through the successors of a problem to find a goal.
151 | The argument frontier should be an empty queue.
152 | If two paths reach a state, only use the first one. [Figure 3.7]"""
153 | frontier.append(Node(problem.initial))
154 | explored = set()
155 | while frontier:
156 | node = frontier.pop()
157 | if problem.goal_test(node.state):
158 | return node
159 | explored.add(node.state)
160 | frontier.extend(child for child in node.expand(problem)
161 | if child.state not in explored and
162 | child not in frontier)
163 | return None
164 |
165 |
166 | def breadth_first_tree_search(problem):
167 | "Search the shallowest nodes in the search tree first."
168 | return tree_search(problem, FIFOQueue())
169 |
170 |
171 | def depth_first_tree_search(problem):
172 | "Search the deepest nodes in the search tree first."
173 | return tree_search(problem, Stack())
174 |
175 |
176 | def depth_first_graph_search(problem):
177 | "Search the deepest nodes in the search tree first."
178 | return graph_search(problem, Stack())
179 |
180 |
181 | def breadth_first_search(problem):
182 | "[Figure 3.11]"
183 | node = Node(problem.initial)
184 | if problem.goal_test(node.state):
185 | return node
186 | frontier = FIFOQueue()
187 | frontier.append(node)
188 | explored = set()
189 | while frontier:
190 | node = frontier.pop()
191 | explored.add(node.state)
192 | for child in node.expand(problem):
193 | if child.state not in explored and child not in frontier:
194 | if problem.goal_test(child.state):
195 | return child
196 | frontier.append(child)
197 | return None
198 |
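`breadth_first_search` applies the goal test when a child is generated rather than when it is popped, which saves expanding one extra level. Below is a self-contained sketch of the same idea on a hypothetical toy problem: states are integers, the actions are `+1` and `*2`, and states are capped at 100 so the space is finite. It uses `collections.deque` in place of the module's `FIFOQueue`, plus a set for constant-time frontier membership:

```python
from collections import deque


def bfs_actions(initial, goal, actions, result):
    """Return the action sequence from initial to goal, or None (Figure 3.11-style BFS)."""
    if initial == goal:
        return []
    frontier = deque([(initial, [])])
    in_frontier = {initial}  # O(1) membership instead of scanning the queue
    explored = set()
    while frontier:
        state, path = frontier.popleft()
        in_frontier.discard(state)
        explored.add(state)
        for action in actions(state):
            child = result(state, action)
            if child not in explored and child not in in_frontier:
                if child == goal:  # goal test at generation time
                    return path + [action]
                frontier.append((child, path + [action]))
                in_frontier.add(child)
    return None


def toy_actions(n):
    return ["+1", "*2"] if n < 100 else []


def toy_result(n, action):
    return n + 1 if action == "+1" else n * 2


print(bfs_actions(1, 10, toy_actions, toy_result))  # ['+1', '*2', '+1', '*2']
```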
199 |
200 | def best_first_graph_search(problem, f):
201 | """Search the nodes with the lowest f scores first.
202 | You specify the function f(node) that you want to minimize; for example,
203 | if f is a heuristic estimate to the goal, then we have greedy best
204 | first search; if f is node.depth then we have breadth-first search.
205 | There is a subtlety: the line "f = memoize(f, 'f')" means that the f
206 | values will be cached on the nodes as they are computed. So after doing
207 | a best first search you can examine the f values of the path returned."""
208 | f = memoize(f, 'f')
209 | node = Node(problem.initial)
210 | if problem.goal_test(node.state):
211 | return node
212 | frontier = PriorityQueue(min, f)
213 | frontier.append(node)
214 | explored = set()
215 | while frontier:
216 | node = frontier.pop()
217 | if problem.goal_test(node.state):
218 | return node
219 | explored.add(node.state)
220 | for child in node.expand(problem):
221 | if child.state not in explored and child not in frontier:
222 | frontier.append(child)
223 | elif child in frontier:
224 | incumbent = frontier[child]
225 | if f(child) < f(incumbent):
226 | # del frontier[incumbent]
227 | frontier.append(child)
228 | return None
229 |
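The `best_first_graph_search` docstring says that `memoize(f, 'f')` caches each node's f value on the node itself. `memoize` is defined in this package's `utils`, which is not shown here. The hypothetical re-implementation below, `memoize_attr`, only sketches the attribute-caching behavior that docstring describes:

```python
def memoize_attr(fn, slot):
    """Cache fn(obj) on obj under attribute `slot` (sketch of the memoize(f, 'f') behavior)."""
    def memoized_fn(obj):
        if hasattr(obj, slot):
            return getattr(obj, slot)
        value = fn(obj)
        setattr(obj, slot, value)
        return value
    return memoized_fn


class ToyNode:
    def __init__(self, path_cost):
        self.path_cost = path_cost


calls = []


def f(node):
    calls.append(node)  # record how often the underlying function actually runs
    return node.path_cost + 1


f = memoize_attr(f, "f")
n = ToyNode(4)
print(f(n), f(n), n.f, len(calls))  # 5 5 5 1
```

Because the value is stored on the node, the f values along a returned path can be inspected after the search. This is what the docstring points out.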
230 |
231 | def uniform_cost_search(problem):
232 | "[Figure 3.14]"
233 | return best_first_graph_search(problem, lambda node: node.path_cost)
234 |
235 |
236 | def depth_limited_search(problem, limit=50):
237 | "[Figure 3.17]"
238 | def recursive_dls(node, problem, limit):
239 | if problem.goal_test(node.state):
240 | return node
241 | elif limit == 0:
242 | return 'cutoff'
243 | else:
244 | cutoff_occurred = False
245 | for child in node.expand(problem):
246 | result = recursive_dls(child, problem, limit - 1)
247 | if result == 'cutoff':
248 | cutoff_occurred = True
249 | elif result is not None:
250 | return result
251 | return 'cutoff' if cutoff_occurred else None
252 |
253 | # Body of depth_limited_search:
254 | return recursive_dls(Node(problem.initial), problem, limit)
255 |
256 |
257 | def iterative_deepening_search(problem):
258 | "[Figure 3.18]"
259 | for depth in range(sys.maxsize):
260 | result = depth_limited_search(problem, depth)
261 | if result != 'cutoff':
262 | return result
263 |
264 | # ______________________________________________________________________________
265 | # Informed (Heuristic) Search
266 |
267 | greedy_best_first_graph_search = best_first_graph_search
268 | # Greedy best-first search is accomplished by specifying f(n) = h(n).
269 |
270 |
271 | def astar_search(problem, h=None):
272 | """A* search is best-first graph search with f(n) = g(n)+h(n).
273 | You need to specify the h function when you call astar_search, or
274 | else in your Problem subclass."""
275 | h = memoize(h or problem.h, 'h')
276 | return best_first_graph_search(problem, lambda n: n.path_cost + h(n))
277 |
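`astar_search` is `best_first_graph_search` ordered by f(n) = g(n) + h(n). Below is a compact standalone sketch of that ordering using `heapq`, on a hypothetical weighted graph with an admissible heuristic table. When a cheaper path to a state is found, the state is pushed again and stale heap entries are skipped on pop. This is one common alternative to deleting the incumbent from the frontier:

```python
import heapq
from itertools import count


def astar(start, goal, neighbors, h):
    """Return (cost, path) of a cheapest path, or None. neighbors(s) yields (t, step_cost)."""
    tie = count()  # tie-breaker so heapq never has to compare paths
    frontier = [(h(start), next(tie), 0, start, [start])]
    best_g = {start: 0}
    while frontier:
        _f, _tie, g, state, path = heapq.heappop(frontier)
        if state == goal:  # goal test on pop keeps A* optimal with an admissible h
            return g, path
        if g > best_g.get(state, float("inf")):
            continue  # stale entry: a cheaper path to this state was pushed later
        for nxt, step in neighbors(state):
            g2 = g + step
            if g2 < best_g.get(nxt, float("inf")):
                best_g[nxt] = g2
                heapq.heappush(frontier, (g2 + h(nxt), next(tie), g2, nxt, path + [nxt]))
    return None


graph = {
    "A": [("B", 1), ("C", 4)],
    "B": [("C", 2), ("D", 5)],
    "C": [("D", 1)],
    "D": [],
}
h_table = {"A": 3, "B": 2, "C": 1, "D": 0}  # never overestimates the true remaining cost
print(astar("A", "D", lambda s: graph[s], h_table.get))  # (4, ['A', 'B', 'C', 'D'])
```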
278 | # ______________________________________________________________________________
279 | # Other search algorithms
280 |
281 |
282 | def recursive_best_first_search(problem, h=None):
283 | "[Figure 3.26]"
284 | h = memoize(h or problem.h, 'h')
285 |
286 | def RBFS(problem, node, flimit):
287 | if problem.goal_test(node.state):
288 | return node, 0 # (The second value is immaterial)
289 | successors = node.expand(problem)
290 | if len(successors) == 0:
291 | return None, infinity
292 | for s in successors:
293 | s.f = max(s.path_cost + h(s), node.f)
294 | while True:
295 | # Order by lowest f value
296 | successors.sort(key=lambda x: x.f)
297 | best = successors[0]
298 | if best.f > flimit:
299 | return None, best.f
300 | if len(successors) > 1:
301 | alternative = successors[1].f
302 | else:
303 | alternative = infinity
304 | result, best.f = RBFS(problem, best, min(flimit, alternative))
305 | if result is not None:
306 | return result, best.f
307 |
308 | node = Node(problem.initial)
309 | node.f = h(node)
310 | result, bestf = RBFS(problem, node, infinity)
311 | return result
312 |
313 | # ______________________________________________________________________________
314 |
315 | # Code to compare searchers on various problems.
316 |
317 |
318 | class InstrumentedProblem(Problem):
319 |
320 | """Delegates to a problem, and keeps statistics."""
321 |
322 | def __init__(self, problem):
323 | self.problem = problem
324 | self.succs = self.goal_tests = self.states = 0
325 | self.found = None
326 |
327 | def actions(self, state):
328 | self.succs += 1
329 | return self.problem.actions(state)
330 |
331 | def result(self, state, action):
332 | self.states += 1
333 | return self.problem.result(state, action)
334 |
335 | def goal_test(self, state):
336 | self.goal_tests += 1
337 | result = self.problem.goal_test(state)
338 | if result:
339 | self.found = state
340 | return result
341 |
342 | def path_cost(self, c, state1, action, state2):
343 | return self.problem.path_cost(c, state1, action, state2)
344 |
345 | def value(self, state):
346 | return self.problem.value(state)
347 |
348 | def __getattr__(self, attr):
349 | return getattr(self.problem, attr)
350 |
351 | def __repr__(self):
352 | return '<%4d/%4d/%4d/%s>' % (self.succs, self.goal_tests,
353 | self.states, str(self.found)[:4])
354 |
355 |
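`InstrumentedProblem` overrides each method it wants to count and falls back to `__getattr__` for everything else. Below is a generic sketch of that delegation pattern. The `Counting` and `Toy` classes are hypothetical and for illustration only.

```python
class Counting:
    """Wrap any object, count calls to the named methods, delegate everything else."""
    def __init__(self, target, counted):
        self.target = target
        self.calls = {name: 0 for name in counted}

    def __getattr__(self, attr):
        # Called only when normal lookup fails, i.e. for anything not set on the wrapper.
        value = getattr(self.target, attr)
        if attr in self.calls and callable(value):
            def wrapper(*args, **kwargs):
                self.calls[attr] += 1
                return value(*args, **kwargs)
            return wrapper
        return value

class Toy:
    initial = 0
    def actions(self, state):
        return [1, 2]
    def goal_test(self, state):
        return state == 3

p = Counting(Toy(), ['actions', 'goal_test'])
p.actions(0)
p.actions(1)
p.goal_test(3)
print(p.calls, p.initial)  # {'actions': 2, 'goal_test': 1} 0
```

The explicit per-method overrides in the original are more verbose but easier to read and introspect. The generic wrapper avoids that boilerplate.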
356 | def compare_searchers(problems, header,
357 | searchers=[breadth_first_tree_search,
358 | breadth_first_search,
359 | depth_first_graph_search,
360 | iterative_deepening_search,
361 | depth_limited_search,
362 | recursive_best_first_search]):
363 | def do(searcher, problem):
364 | p = InstrumentedProblem(problem)
365 | searcher(p)
366 | return p
367 | table = [[name(s)] + [do(s, p) for p in problems] for s in searchers]
368 | print_table(table, header)
369 |
--------------------------------------------------------------------------------
/problem/utils.py:
--------------------------------------------------------------------------------
1 | """Provides some utilities widely used by other modules"""
2 |
3 | import bisect
4 | import collections
5 | import collections.abc
6 | import functools
7 | import operator
8 | import os.path
9 | import random
10 | import math
11 |
12 | import heapq
13 | from collections import defaultdict
14 |
15 | # ______________________________________________________________________________
16 | # Functions on Sequences and Iterables
17 |
18 |
19 | def sequence(iterable):
20 | "Coerce iterable to sequence, if it is not already one."
21 | return (iterable if isinstance(iterable, collections.abc.Sequence)
22 | else tuple(iterable))
23 |
24 |
25 | def removeall(item, seq):
26 | """Return a copy of seq (or string) with all occurrences of item removed."""
27 | if isinstance(seq, str):
28 | return seq.replace(item, '')
29 | else:
30 | return [x for x in seq if x != item]
31 |
32 |
33 | def unique(seq): # TODO: replace with set
34 | """Remove duplicate elements from seq. Assumes hashable elements."""
35 | return list(set(seq))
36 |
37 |
38 | def count(seq):
39 | """Count the number of items in sequence that are interpreted as true."""
40 | return sum(bool(x) for x in seq)
41 |
42 |
43 | def product(numbers):
44 | """Return the product of the numbers, e.g. product([2, 3, 10]) == 60"""
45 | result = 1
46 | for x in numbers:
47 | result *= x
48 | return result
49 |
50 |
51 | def first(iterable, default=None):
52 | "Return the first element of an iterable or the next element of a generator; or default."
53 | try:
54 | return iterable[0]
55 | except IndexError:
56 | return default
57 | except TypeError:
58 | return next(iterable, default)
59 |
60 |
61 | def is_in(elt, seq):
62 | """Similar to (elt in seq), but compares with 'is', not '=='."""
63 | return any(x is elt for x in seq)
64 |
65 | # ______________________________________________________________________________
66 | # argmin and argmax
67 |
68 | identity = lambda x: x
69 |
70 | argmin = min
71 | argmax = max
72 |
73 |
74 | def argmin_random_tie(seq, key=identity):
75 | """Return a minimum element of seq; break ties at random."""
76 | return argmin(shuffled(seq), key=key)
77 |
78 |
79 | def argmax_random_tie(seq, key=identity):
80 | "Return an element with highest fn(seq[i]) score; break ties at random."
81 | return argmax(shuffled(seq), key=key)
82 |
83 |
84 | def shuffled(iterable):
85 | "Randomly shuffle a copy of iterable."
86 | items = list(iterable)
87 | random.shuffle(items)
88 | return items
89 |
90 |
91 |
92 | # ______________________________________________________________________________
93 | # Statistical and mathematical functions
94 |
95 |
96 | def histogram(values, mode=0, bin_function=None):
97 | """Return a list of (value, count) pairs, summarizing the input values.
98 | Sorted by increasing value, or if mode=1, by decreasing count.
99 | If bin_function is given, map it over values first."""
100 | if bin_function:
101 | values = map(bin_function, values)
102 |
103 | bins = {}
104 | for val in values:
105 | bins[val] = bins.get(val, 0) + 1
106 |
107 | if mode:
108 | return sorted(list(bins.items()), key=lambda x: (x[1], x[0]),
109 | reverse=True)
110 | else:
111 | return sorted(bins.items())
112 |
113 |
114 | def dotproduct(X, Y):
115 | """Return the sum of the element-wise product of vectors X and Y."""
116 | return sum(x * y for x, y in zip(X, Y))
117 |
118 |
119 | def element_wise_product(X, Y):
120 | """Return vector as an element-wise product of vectors X and Y"""
121 | assert len(X) == len(Y)
122 | return [x * y for x, y in zip(X, Y)]
123 |
124 |
125 | def matrix_multiplication(X_M, *Y_M):
126 | """Return the matrix product of X_M and an arbitrary number of matrices *Y_M, multiplied left to right."""
127 |
128 | def _mat_mult(X_M, Y_M):
129 | """Return a matrix as a matrix-multiplication of two matrices X_M and Y_M
130 | matrix_multiplication([[1, 2, 3],
131 | [2, 3, 4]],
132 | [[3, 4],
133 | [1, 2],
134 | [1, 0]])
135 | [[8, 8],[13, 14]]
136 | """
137 | assert len(X_M[0]) == len(Y_M)
138 |
139 | result = [[0 for i in range(len(Y_M[0]))] for j in range(len(X_M))]
140 | for i in range(len(X_M)):
141 | for j in range(len(Y_M[0])):
142 | for k in range(len(Y_M)):
143 | result[i][j] += X_M[i][k] * Y_M[k][j]
144 | return result
145 |
146 | result = X_M
147 | for Y in Y_M:
148 | result = _mat_mult(result, Y)
149 |
150 | return result
151 |
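`matrix_multiplication` folds the pairwise `_mat_mult` over `*Y_M` from left to right. The same fold can be written with `functools.reduce`. In this sketch, the `mat_mult` and `chain` names are illustrative and not part of the module, `zip(*B)` walks the columns of `B`, and the docstring's example is reused as a check.

```python
from functools import reduce

def mat_mult(A, B):
    """Naive product of an n x m and an m x p matrix given as lists of lists."""
    assert len(A[0]) == len(B), "inner dimensions must agree"
    # zip(*B) iterates over the columns of B
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def chain(*matrices):
    """Left-to-right product M1 @ M2 @ ..., like matrix_multiplication(X_M, *Y_M)."""
    return reduce(mat_mult, matrices)

X = [[1, 2, 3], [2, 3, 4]]
Y = [[3, 4], [1, 2], [1, 0]]
print(chain(X, Y))              # [[8, 8], [13, 14]]
print(chain(X, Y, [[1], [1]]))  # [[16], [27]]
```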
152 |
153 | def vector_to_diagonal(v):
154 | """Converts a vector to a diagonal matrix with vector elements
155 | as the diagonal elements of the matrix"""
156 | diag_matrix = [[0 for i in range(len(v))] for j in range(len(v))]
157 | for i in range(len(v)):
158 | diag_matrix[i][i] = v[i]
159 |
160 | return diag_matrix
161 |
162 |
163 | def vector_add(a, b):
164 | """Component-wise addition of two vectors."""
165 | return tuple(map(operator.add, a, b))
166 |
167 |
168 |
169 | def scalar_vector_product(X, Y):
170 | """Return vector as a product of a scalar and a vector"""
171 | return [X * y for y in Y]
172 |
173 |
174 | def scalar_matrix_product(X, Y):
175 | return [scalar_vector_product(X, y) for y in Y]
176 |
177 |
178 | def inverse_matrix(X):
179 | """Return the inverse of a 2x2 square matrix."""
180 | assert len(X) == 2
181 | assert len(X[0]) == 2
182 | det = X[0][0] * X[1][1] - X[0][1] * X[1][0]
183 | assert det != 0
184 | inv_mat = scalar_matrix_product(1.0/det, [[X[1][1], -X[0][1]], [-X[1][0], X[0][0]]])
185 |
186 | return inv_mat
187 |
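For a 2x2 matrix [[a, b], [c, d]], `inverse_matrix` computes (1/det) * [[d, -b], [-c, a]] with det = ad - bc. Below is a quick check of that formula using exact `fractions.Fraction` arithmetic. It assumes integer entries, and `inverse_2x2` and `mat_mult_2x2` are hypothetical helpers.

```python
from fractions import Fraction

def inverse_2x2(M):
    """Inverse of [[a, b], [c, d]] with integer entries, using exact fractions."""
    (a, b), (c, d) = M
    det = a * d - b * c
    if det == 0:
        raise ValueError("matrix is singular")
    k = Fraction(1, det)  # exact 1/det, no floating-point rounding
    return [[k * d, -k * b], [-k * c, k * a]]

def mat_mult_2x2(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)] for i in range(2)]

M = [[4, 7], [2, 6]]
inv = inverse_2x2(M)          # det = 10, entries 3/5, -7/10, -1/5, 2/5
print(mat_mult_2x2(M, inv))   # the identity matrix, as Fractions
```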
188 |
189 | def probability(p):
190 | "Return true with probability p."
191 | return p > random.uniform(0.0, 1.0)
192 |
193 |
194 | def weighted_sample_with_replacement(seq, weights, n):
195 | """Pick n samples from seq at random, with replacement, with the
196 | probability of each element in proportion to its corresponding
197 | weight."""
198 | sample = weighted_sampler(seq, weights)
199 |
200 | return [sample() for _ in range(n)]
201 |
202 |
203 | def weighted_sampler(seq, weights):
204 | "Return a random-sample function that picks from seq weighted by weights."
205 | totals = []
206 | for w in weights:
207 | totals.append(w + totals[-1] if totals else w)
208 |
209 | return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]))]
210 |
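`weighted_sampler` builds the running totals of the weights once, then maps a uniform draw to an index with `bisect`. Below is a seeded sketch of the same technique that uses `itertools.accumulate` for the totals. `make_sampler` is a hypothetical name. The standard library's `random.choices(seq, weights, k=n)` offers the same behaviour directly.

```python
import bisect
import itertools
import random

def make_sampler(seq, weights, rng=random):
    """Return a zero-arg function that picks seq[i] with probability weights[i] / sum."""
    totals = list(itertools.accumulate(weights))  # e.g. [1, 3, 6] for weights [1, 2, 3]
    def sample():
        r = rng.uniform(0, totals[-1])
        # bisect_right finds the first bucket whose running total exceeds r;
        # min() guards against r == totals[-1], which uniform() may return.
        return seq[min(bisect.bisect_right(totals, r), len(seq) - 1)]
    return sample

rng = random.Random(0)
sample = make_sampler('abc', [1, 2, 3], rng)
draws = [sample() for _ in range(6000)]
print({c: draws.count(c) for c in 'abc'})  # roughly 1000 / 2000 / 3000
```

The `min()` clamp covers an edge case that the original lambda does not guard against. The `random.uniform` documentation notes that the end point may be returned because of floating-point rounding.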
211 |
212 | def rounder(numbers, d=4):
213 | "Round a single number, or sequence of numbers, to d decimal places."
214 | if isinstance(numbers, (int, float)):
215 | return round(numbers, d)
216 | else:
217 | constructor = type(numbers) # Can be list, set, tuple, etc.
218 | return constructor(rounder(n, d) for n in numbers)
219 |
220 |
221 | def num_or_str(x):
222 | """The argument is a string; convert to a number if
223 | possible, or strip it.
224 | """
225 | try:
226 | return int(x)
227 | except ValueError:
228 | try:
229 | return float(x)
230 | except ValueError:
231 | return str(x).strip()
232 |
233 |
234 | def normalize(dist):
235 | """Multiply each number by a constant such that the sum is 1.0"""
236 | if isinstance(dist, dict):
237 | total = sum(dist.values())
238 | for key in dist:
239 | dist[key] = dist[key] / total
240 | assert 0 <= dist[key] <= 1, "Probabilities must be between 0 and 1."
241 | return dist
242 | total = sum(dist)
243 | return [(n / total) for n in dist]
244 |
245 |
246 | def clip(x, lowest, highest):
247 | """Return x clipped to the range [lowest..highest]."""
248 | return max(lowest, min(x, highest))
249 |
250 |
251 | def sigmoid(x):
252 | """Return activation value of x with sigmoid function"""
253 | return 1/(1 + math.exp(-x))
254 |
255 |
256 | def step(x):
257 | """Return activation value of x with the step function: 1 if x >= 0, else 0."""
258 | return 1 if x >= 0 else 0
259 |
260 | try: # math.isclose was added in Python 3.5; but we might be in 3.4
261 | from math import isclose
262 | except ImportError:
263 | def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
264 | "Return true if numbers a and b are close to each other."
265 | return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
266 |
267 | # ______________________________________________________________________________
268 | # Misc Functions
269 |
270 |
271 | # TODO: Use functools.lru_cache memoization decorator
272 |
273 |
274 | def memoize(fn, slot=None):
275 | """Memoize fn: make it remember the computed value for any argument list.
276 | If slot is specified, store result in that slot of first argument.
277 | If slot is false, store results in a dictionary."""
278 | if slot:
279 | def memoized_fn(obj, *args):
280 | if hasattr(obj, slot):
281 | return getattr(obj, slot)
282 | else:
283 | val = fn(obj, *args)
284 | setattr(obj, slot, val)
285 | return val
286 | else:
287 | def memoized_fn(*args):
288 | if args not in memoized_fn.cache:
289 | memoized_fn.cache[args] = fn(*args)
290 | return memoized_fn.cache[args]
291 |
292 | memoized_fn.cache = {}
293 |
294 | return memoized_fn
295 |
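`memoize` has two modes. With `slot`, the result is stored as an attribute of the first argument, which is how `astar_search` and `recursive_best_first_search` cache `h` on each node. Without `slot`, results go into a dict keyed by the argument tuple, and the TODO above suggests `functools.lru_cache` for that mode. The sketch below shows both, with a call log proving the wrapped functions run only once per key. `Node`, `cache_on_attr`, and `heuristic` are hypothetical stand-ins.

```python
import functools

calls = []  # records every real (uncached) evaluation

@functools.lru_cache(maxsize=None)  # dict mode: cache keyed on the argument tuple
def slow_square(x):
    calls.append(x)
    return x * x

class Node:
    """Hypothetical stand-in for a search node."""
    def __init__(self, cost):
        self.cost = cost

def cache_on_attr(fn, slot):
    """Slot mode: store fn(obj) as an attribute of obj, so the cache lives with obj."""
    def wrapper(obj, *args):
        if not hasattr(obj, slot):
            setattr(obj, slot, fn(obj, *args))
        return getattr(obj, slot)
    return wrapper

def heuristic(node):
    calls.append('h')
    return node.cost * 2

h = cache_on_attr(heuristic, 'h')

slow_square(4)
slow_square(4)  # served from the cache
slow_square(5)
n = Node(3)
h(n)
h(n)            # read back from n.h
print(calls, n.h)  # [4, 5, 'h'] 6
```

In slot mode, extra arguments are ignored once the attribute is set, and the object must accept new attributes. `lru_cache` instead requires every argument to be hashable.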
296 |
297 | def name(obj):
298 | "Try to find some reasonable name for the object."
299 | return (getattr(obj, 'name', 0) or getattr(obj, '__name__', 0) or
300 | getattr(getattr(obj, '__class__', 0), '__name__', 0) or
301 | str(obj))
302 |
303 |
304 | def isnumber(x):
305 | "Is x a number?"
306 | return hasattr(x, '__int__')
307 |
308 |
309 | def issequence(x):
310 | "Is x a sequence?"
311 | return isinstance(x, collections.abc.Sequence)
312 |
313 |
314 | def print_table(table, header=None, sep=' ', numfmt='%g'):
315 | """Print a list of lists as a table, so that columns line up nicely.
316 | header, if specified, will be printed as the first row.
317 | numfmt is the format for all numbers; you might want e.g. '%6.2f'.
318 | (If you want different formats in different columns,
319 | don't use print_table.) sep is the separator between columns."""
320 | justs = ['rjust' if isnumber(x) else 'ljust' for x in table[0]]
321 |
322 | if header:
323 | table.insert(0, header)
324 |
325 | table = [[numfmt % x if isnumber(x) else x for x in row]
326 | for row in table]
327 |
328 | sizes = list(
329 | map(lambda seq: max(map(len, seq)),
330 | list(zip(*[map(str, row) for row in table]))))
331 |
332 | for row in table:
333 | print(sep.join(getattr(
334 | str(x), j)(size) for (j, size, x) in zip(justs, sizes, row)))
335 |
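`print_table` aligns columns in three steps. It formats the numbers, finds each column's widest cell via `zip(*rows)`, then left-justifies text and right-justifies numbers. Below is a compact sketch that returns the lines instead of printing them. `format_table` is a hypothetical name, and unlike the original it builds a new list rather than inserting the header into the caller's table.

```python
def format_table(rows, header=None, sep='  ', numfmt='%g'):
    """Return rows as aligned text lines: numbers right-justified, text left-justified."""
    is_num = lambda x: isinstance(x, (int, float))
    justs = ['rjust' if is_num(x) else 'ljust' for x in rows[0]]
    body = [[numfmt % x if is_num(x) else str(x) for x in row] for row in rows]
    if header:
        body = [list(map(str, header))] + body  # new list; the caller's rows are untouched
    widths = [max(len(cell) for cell in col) for col in zip(*body)]
    return [sep.join(getattr(cell, j)(w) for cell, j, w in zip(row, justs, widths))
            for row in body]

lines = format_table([['bfs', 12, 0.5], ['astar', 3, 1.25]],
                     header=['searcher', 'nodes', 'time'])
for line in lines:
    print(line)
```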
336 |
337 | def AIMAFile(components, mode='r'):
338 | "Open a file based at the AIMA root directory."
339 | aima_root = os.path.dirname(__file__)
340 |
341 | aima_file = os.path.join(aima_root, *components)
342 |
343 | return open(aima_file, mode)
344 |
345 |
346 | def DataFile(name, mode='r'):
347 | "Return a file in the AIMA /aimacode-data directory."
348 | return AIMAFile(['aimacode-data', name], mode)
349 |
350 |
351 | # ______________________________________________________________________________
352 | # Expressions
353 |
354 | # See https://docs.python.org/3/reference/expressions.html#operator-precedence
355 | # See https://docs.python.org/3/reference/datamodel.html#special-method-names
356 |
357 | class Expr(object):
358 | """A mathematical expression with an operator and 0 or more arguments.
359 | op is a str like '+' or 'sin'; args are Expressions.
360 | Expr('x') or Symbol('x') creates a symbol (a nullary Expr).
361 | Expr('-', x) creates a unary; Expr('+', x, 1) creates a binary."""
362 |
363 | def __init__(self, op, *args):
364 | self.op = str(op)
365 | self.args = args
366 | self.__hash = None
367 |
368 | # Operator overloads
369 | def __neg__(self): return Expr('-', self)
370 | def __pos__(self): return Expr('+', self)
371 | def __invert__(self): return Expr('~', self)
372 | def __add__(self, rhs): return Expr('+', self, rhs)
373 | def __sub__(self, rhs): return Expr('-', self, rhs)
374 | def __mul__(self, rhs): return Expr('*', self, rhs)
375 | def __pow__(self, rhs): return Expr('**',self, rhs)
376 | def __mod__(self, rhs): return Expr('%', self, rhs)
377 | def __and__(self, rhs): return Expr('&', self, rhs)
378 | def __xor__(self, rhs): return Expr('^', self, rhs)
379 | def __rshift__(self, rhs): return Expr('>>', self, rhs)
380 | def __lshift__(self, rhs): return Expr('<<', self, rhs)
381 | def __truediv__(self, rhs): return Expr('/', self, rhs)
382 | def __floordiv__(self, rhs): return Expr('//', self, rhs)
383 | def __matmul__(self, rhs): return Expr('@', self, rhs)
384 |
385 | def __or__(self, rhs):
386 | "Allow both P | Q, and P |'==>'| Q."
387 | if isinstance(rhs, Expression):
388 | return Expr('|', self, rhs)
389 | else:
390 | return PartialExpr(rhs, self)
391 |
392 | # Reverse operator overloads
393 | def __radd__(self, lhs): return Expr('+', lhs, self)
394 | def __rsub__(self, lhs): return Expr('-', lhs, self)
395 | def __rmul__(self, lhs): return Expr('*', lhs, self)
396 | def __rdiv__(self, lhs): return Expr('/', lhs, self)
397 | def __rpow__(self, lhs): return Expr('**', lhs, self)
398 | def __rmod__(self, lhs): return Expr('%', lhs, self)
399 | def __rand__(self, lhs): return Expr('&', lhs, self)
400 | def __rxor__(self, lhs): return Expr('^', lhs, self)
401 | def __ror__(self, lhs): return Expr('|', lhs, self)
402 | def __rrshift__(self, lhs): return Expr('>>', lhs, self)
403 | def __rlshift__(self, lhs): return Expr('<<', lhs, self)
404 | def __rtruediv__(self, lhs): return Expr('/', lhs, self)
405 | def __rfloordiv__(self, lhs): return Expr('//', lhs, self)
406 | def __rmatmul__(self, lhs): return Expr('@', lhs, self)
407 |
408 | def __call__(self, *args):
409 | "Call: if 'f' is a Symbol, then f(0) == Expr('f', 0)."
410 | if self.args:
411 | raise ValueError('can only do a call for a Symbol, not an Expr')
412 | else:
413 | return Expr(self.op, *args)
414 |
415 | # Equality and repr
416 | def __eq__(self, other):
417 | "'x == y' evaluates to True or False; does not build an Expr."
418 | return (isinstance(other, Expr)
419 | and self.op == other.op
420 | and self.args == other.args)
421 |
422 | def __hash__(self):
423 | self.__hash = self.__hash or hash(self.op) ^ hash(self.args)
424 | return self.__hash
425 |
426 | def __repr__(self):
427 | op = self.op
428 | args = [str(arg) for arg in self.args]
429 | if op.isidentifier(): # f(x) or f(x, y)
430 | return '{}({})'.format(op, ', '.join(args)) if args else op
431 | elif len(args) == 1: # -x or -(x + 1)
432 | return op + args[0]
433 | else: # (x - y)
434 | opp = (' ' + op + ' ')
435 | return '(' + opp.join(args) + ')'
436 |
437 | # An 'Expression' is either an Expr or a Number.
438 | # Symbol is not an explicit type; it is any Expr with 0 args.
439 |
440 | Number = (int, float, complex)
441 | Expression = (Expr, Number)
442 |
443 |
444 | def Symbol(name):
445 | "A Symbol is just an Expr with no args."
446 | return Expr(name)
447 |
448 |
449 | def symbols(names):
450 | "Return a tuple of Symbols; names is a comma/whitespace delimited str."
451 | return tuple(Symbol(name) for name in names.replace(',', ' ').split())
452 |
453 |
454 | def subexpressions(x):
455 | "Yield the subexpressions of an Expression (including x itself)."
456 | yield x
457 | if isinstance(x, Expr):
458 | for arg in x.args:
459 | yield from subexpressions(arg)
460 |
461 |
462 | def arity(expression):
463 | "The number of sub-expressions in this expression."
464 | if isinstance(expression, Expr):
465 | return len(expression.args)
466 | else: # expression is a number
467 | return 0
468 |
469 | # For operators that are not defined in Python, we allow new InfixOps:
470 |
471 |
472 | class PartialExpr:
473 | """Given 'P |'==>'| Q, first form PartialExpr('==>', P), then combine with Q."""
474 | def __init__(self, op, lhs): self.op, self.lhs = op, lhs
475 | def __or__(self, rhs): return Expr(self.op, self.lhs, rhs)
476 | def __repr__(self): return "PartialExpr('{}', {})".format(self.op, self.lhs)
477 |
478 |
479 | def expr(x):
480 | """Shortcut to create an Expression. x is a str in which:
481 | - identifiers are automatically defined as Symbols.
482 | - ==> is treated as an infix |'==>'|, as are <== and <=>.
483 | If x is already an Expression, it is returned unchanged. Example:
484 | >>> expr('P & Q ==> Q')
485 | ((P & Q) ==> Q)
486 | """
487 | if isinstance(x, str):
488 | return eval(expr_handle_infix_ops(x), defaultkeydict(Symbol))
489 | else:
490 | return x
491 |
492 | infix_ops = '==> <== <=>'.split()
493 |
494 |
495 | def expr_handle_infix_ops(x):
496 | """Given a str, return a new str with ==> replaced by |'==>'|, etc.
497 | >>> expr_handle_infix_ops('P ==> Q')
498 | "P |'==>'| Q"
499 | """
500 | for op in infix_ops:
501 | x = x.replace(op, '|' + repr(op) + '|')
502 | return x
503 |
504 |
505 | class defaultkeydict(collections.defaultdict):
506 | """Like defaultdict, but the default_factory is a function of the key.
507 | >>> d = defaultkeydict(len); d['four']
508 | 4
509 | """
510 | def __missing__(self, key):
511 | self[key] = result = self.default_factory(key)
512 | return result
513 |
514 |
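`expr` rewrites `'P ==> Q'` into `"P |'==>'| Q"` and evaluates it. Because `P | '==>'` has a string on the right, `Expr.__or__` returns a `PartialExpr`, whose own `__or__` then combines with `Q`. Below is a stripped-down sketch of that trick. The `E` and `Partial` classes are simplified stand-ins for `Expr` and `PartialExpr`, not the module's classes.

```python
class E:
    """Minimal expression node: an operator string plus arguments."""
    def __init__(self, op, *args):
        self.op, self.args = op, args

    def __or__(self, rhs):
        # E | E is ordinary disjunction; E | 'str' starts a custom infix operator.
        if isinstance(rhs, E):
            return E('|', self, rhs)
        return Partial(rhs, self)

    def __and__(self, rhs):
        return E('&', self, rhs)

    def __repr__(self):
        if not self.args:
            return self.op
        return '(' + (' %s ' % self.op).join(map(repr, self.args)) + ')'

class Partial:
    """Holds the operator and left operand until the right operand arrives."""
    def __init__(self, op, lhs):
        self.op, self.lhs = op, lhs

    def __or__(self, rhs):
        return E(self.op, self.lhs, rhs)

P, Q = E('P'), E('Q')
# & binds tighter than |, so this parses as ((P & Q) | '==>') | Q
print(P & Q |'==>'| Q)  # ((P & Q) ==> Q)
```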
515 | # ______________________________________________________________________________
516 | # Queues: Stack, FIFOQueue, PriorityQueue
517 |
518 | # TODO: Possibly use queue.Queue, queue.PriorityQueue
519 | # TODO: Priority queues may not belong here -- see treatment in search.py
520 |
521 |
522 | class Queue:
523 |
524 | """Queue is an abstract class/interface. There are three types:
525 | Stack(): A Last In First Out Queue.
526 | FIFOQueue(): A First In First Out Queue.
527 | PriorityQueue(order, f): Queue in sorted order (default min-first).
528 | Each type supports the following methods and functions:
529 | q.append(item) -- add an item to the queue
530 | q.extend(items) -- equivalent to: for item in items: q.append(item)
531 | q.pop() -- return the top item from the queue
532 | len(q) -- number of items in q (also q.__len__())
533 | item in q -- does q contain item?
534 | Note that isinstance(Stack(), Queue) is false, because we implement stacks
535 | as lists. If Python ever gets interfaces, Queue will be an interface."""
536 |
537 | def __init__(self):
538 | raise NotImplementedError
539 |
540 | def extend(self, items):
541 | for item in items:
542 | self.append(item)
543 |
544 |
545 | def Stack():
546 | """Return an empty list, suitable as a Last-In-First-Out Queue."""
547 | return []
548 |
549 |
550 | class FIFOQueue(Queue):
551 |
552 | """A First-In-First-Out Queue."""
553 |
554 | def __init__(self):
555 | self.A = []
556 | self.start = 0
557 |
558 | def append(self, item):
559 | self.A.append(item)
560 |
561 | def __len__(self):
562 | return len(self.A) - self.start
563 |
564 | def extend(self, items):
565 | self.A.extend(items)
566 |
567 | def pop(self):
568 | e = self.A[self.start]
569 | self.start += 1
570 | if self.start > 5 and self.start > len(self.A) / 2:
571 | self.A = self.A[self.start:]
572 | self.start = 0
573 | return e
574 |
575 | def __contains__(self, item):
576 | return item in self.A[self.start:]
577 |
578 |
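`FIFOQueue.pop` avoids the O(n) cost of `list.pop(0)`. It only advances `start`, and it drops the consumed prefix once that prefix is more than half the list, which gives amortized O(1) pops. Below is a sketch of the same scheme checked against `collections.deque`, the standard-library double-ended queue. `IndexQueue` is a hypothetical name.

```python
from collections import deque

class IndexQueue:
    """FIFO queue over a list: pop advances an index; storage is compacted lazily."""
    def __init__(self):
        self.items, self.start = [], 0

    def append(self, item):
        self.items.append(item)

    def pop(self):
        if self.start >= len(self.items):
            raise IndexError('pop from empty queue')
        item = self.items[self.start]
        self.start += 1
        # Compact once the consumed prefix exceeds half the list (same rule as FIFOQueue).
        if self.start > 5 and self.start > len(self.items) / 2:
            self.items, self.start = self.items[self.start:], 0
        return item

    def __len__(self):
        return len(self.items) - self.start

q, d = IndexQueue(), deque()
for i in range(20):
    q.append(i)
    d.append(i)
out_q = [q.pop() for _ in range(15)]
out_d = [d.popleft() for _ in range(15)]  # deque gives O(1) popleft directly
print(out_q == out_d, len(q), q.items)
```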
579 | class PriorityQueue(Queue):
580 | """A queue in which the minimum element (as determined by f) is returned
581 | first. The order argument is accepted but ignored. Also supports dict-like lookup.
582 |
583 | MODIFIED FROM AIMA VERSION
584 | - Use heapq
585 | - Use an additional dict to track membership
586 | - remove __delitem__ (AIMA version contains error)
587 | """
588 |
589 | def __init__(self, order=None, f=lambda x: x):
590 | self.A = []
591 | self._A = defaultdict(lambda: 0)
592 | self.f = f
593 |
594 | def append(self, item):
595 | heapq.heappush(self.A, (self.f(item), item))
596 | self._A[item] += 1
597 |
598 | def __len__(self):
599 | return len(self.A)
600 |
601 | def pop(self):
602 | _, item = heapq.heappop(self.A)
603 | self._A[item] -= 1
604 | return item
605 |
606 | def __contains__(self, item):
607 | return self._A[item] > 0
608 |
609 | def __getitem__(self, key):
610 | if self._A[key] > 0:
611 | return key
612 |
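`PriorityQueue` pushes `(f(item), item)` pairs. When two priorities tie, `heapq` falls back to comparing the items themselves, which raises `TypeError` if the items define no ordering. A common remedy is to insert a monotonically increasing counter as a tie-breaker. The sketch below shows it as an assumption, not the module's current behaviour. `TieSafePQ` and `Task` are hypothetical names, and membership is tracked with a `Counter` as the original does with its `defaultdict`.

```python
import heapq
import itertools
from collections import Counter

class TieSafePQ:
    """Min-priority queue keyed by f(item); equal priorities pop in insertion order."""
    def __init__(self, f=lambda x: x):
        self.heap, self.f = [], f
        self.count = Counter()             # item -> number of copies queued
        self.tiebreak = itertools.count()  # unique, so items are never compared

    def append(self, item):
        heapq.heappush(self.heap, (self.f(item), next(self.tiebreak), item))
        self.count[item] += 1

    def pop(self):
        _, _, item = heapq.heappop(self.heap)
        self.count[item] -= 1
        return item

    def __contains__(self, item):
        return self.count[item] > 0  # Counter returns 0 for missing keys without storing them

    def __len__(self):
        return len(self.heap)

class Task:
    """Hashable but unorderable item (no __lt__)."""
    def __init__(self, name, cost):
        self.name, self.cost = name, cost

pq = TieSafePQ(f=lambda t: t.cost)
for t in [Task('x', 2), Task('y', 1), Task('z', 2)]:
    pq.append(t)
names = [pq.pop().name for _ in range(3)]
print(names)  # ['y', 'x', 'z']
```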
613 | # ______________________________________________________________________________
614 | # Useful Shorthands
615 |
616 |
617 | class Bool(int):
618 | """Just like `bool`, except values display as 'T' and 'F' instead of 'True' and 'False'"""
619 | __str__ = __repr__ = lambda self: 'T' if self else 'F'
620 |
621 | T = Bool(True)
622 | F = Bool(False)
623 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import dnc_arity_list as dnc
2 | import numpy as np
3 | from utils import running_avg, flat, save, _variable
4 | import utils as u
5 | import torch
6 | import torch.nn as nn
7 | from problem import generators_v2 as gen
8 | import torch.optim as optim
9 | import time, random
10 | from visualize import logger as sl
11 | import os
12 | import losses as L
13 | from arg import args
14 |
15 | random.seed()
16 | batch_size = 1
17 | dnc_args = {'num_layers': 2,
18 | 'num_read_heads': 2,
19 | 'hidden_size': 250,
20 | 'num_write_heads': 1,
21 | 'memory_size': 100, #50
22 | 'batch_size': batch_size}
23 |
24 |
25 | def generate_data_spec(args, num_ents=2, solve=True):
26 | if args.typed == 1:
27 | ix_size = [4, args.max_ents, 4, args.max_ents, 4, args.max_ents]
28 | encoding = 2
29 | else:
30 | ix_state = args.max_ents * 3
31 | ix_size = [ix_state, ix_state, ix_state]
32 | encoding = 1
33 | return {'num_plane': num_ents, 'num_cargo': num_ents, 'num_airport': num_ents,
34 | 'one_hot_size': ix_size, 'plan_phase': num_ents * 3, 'cuda': args.cuda,
35 | 'batch_size': 1, 'encoding': encoding, 'solve': solve, 'mapping': None}
36 |
37 |
38 | def setupLSTM(args):
39 | data = gen.AirCargoData(**generate_data_spec(args))
40 | dnc_args['output_size'] = data.nn_in_size # output has no phase component
41 | dnc_args['word_len'] = data.nn_out_size
42 | print(dnc_args)
43 | # input_size = self.output_size + word_len * num_read_heads
44 | Dnc = dnc.VanillaLSTM(batch_size=1, num_layers=2, input_size=data.nn_in_size,
45 | output_size=data.nn_out_size, hidden_size=250, num_reads=2)
46 | previous_out, (ho1, hc1), (ho2, hc2) = Dnc.init_state()
47 | if args.opt == 'adam':
48 | optimizer = optim.Adam([{'params': Dnc.parameters()}, {'params': ho1}, {'params': ho2},
49 | {'params': hc1}, {'params': hc2}], lr=args.lr)
50 | else:
51 | optimizer = optim.SGD([{'params': Dnc.parameters()}, {'params': ho1}, {'params': ho2},
52 | {'params': hc1}, {'params': hc2}], lr=args.lr)
53 |
54 | lstm_state = (previous_out, (ho1, hc1), (ho2, hc2))
55 | return data, Dnc, optimizer, lstm_state
56 |
57 |
58 | def setupDNC(args):
59 | """
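60 | Load a DNC and optimizer from the checkpoint named by args.load, or build new ones.
61 | :param args: parsed command-line arguments (see arg.py)
62 | :return: (data, Dnc, optimizer, lstm_state)
63 | """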
64 | if args.algo == 'lstm':
65 | return setupLSTM(args)
66 | data = gen.AirCargoData(**generate_data_spec(args))
67 | dnc_args['output_size'] = data.nn_in_size # output has no phase component
68 | dnc_args['word_len'] = data.nn_out_size
69 | print('dnc_args:\n', dnc_args, '\n')
70 | if args.load == '':
71 | Dnc = dnc.DNC(**dnc_args)
72 | if args.opt == 'adam':
73 | optimizer = optim.Adam(Dnc.parameters(), lr=args.lr)
74 | elif args.opt == 'sgd':
75 | optimizer = optim.SGD(Dnc.parameters(), lr=args.lr)
76 | else:
77 | optimizer = None
78 | else:
79 | model_path, optim_path = u.get_chkpt(args.load)
80 | print('loading', model_path)
81 | Dnc = dnc.DNC(**dnc_args)
82 | Dnc.load_state_dict(torch.load(model_path))
83 |
84 | optimizer = optim.Adam(Dnc.parameters(), lr=args.lr)
85 | if os.path.exists(optim_path):
86 | optimizer.load_state_dict(torch.load(optim_path))
87 |
88 | if args.cuda is True:
89 | Dnc = Dnc.cuda()
90 | lstm_state = Dnc.init_rnn()
91 | return data, Dnc, optimizer, lstm_state
92 |
93 |
94 | def tick(n_total, n_correct, truth, pred):
95 | n_total += 1
96 | n_correct += 1 if truth == pred else 0
97 | sl.global_step += 1
98 | return n_total, n_correct
99 |
100 |
101 | def train_qa2(args, data, DNC, optimizer):
102 | """
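103 | Sanity test: present a problem in phases, then ask questions about its state.
104 |
105 | 0 - describe state.
106 | 1 - describe goal.
107 | 2 - do actions.
108 | 3 - ask some questions.
109 | :param args: parsed command-line arguments (see arg.py)
110 | :return: (DNC, optimizer, dnc_state, running accuracy)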
111 | """
112 | criterion = nn.CrossEntropyLoss()
113 | cum_correct, cum_total = [], []
114 |
115 | for trial in range(args.iters):
116 | phase_masks = data.make_new_problem()
117 | n_total, n_correct, loss = 0, 0, 0
118 | dnc_state = DNC.init_state(grad=False)
119 | optimizer.zero_grad()
120 |
121 | for phase_idx in phase_masks:
122 | if phase_idx == 0 or phase_idx == 1:
123 | inputs = _variable(data.getitem_combined())
124 | logits, dnc_state = DNC(inputs, dnc_state)
125 | else:
126 | final_moves = data.get_actions(mode='one')
127 | if final_moves == []:
128 | break
129 | data.send_action(final_moves[0])
130 | mask = data.phase_oh[2].unsqueeze(0)
131 | inputs2 = _variable(torch.cat([mask, data.vec_to_ix(final_moves[0])], 1))
132 | logits, dnc_state = DNC(inputs2, dnc_state)
133 |
134 | for _ in range(args.num_tests):
135 | # ask where is ---?
136 | if args.zero_at == 'step':
137 | optimizer.zero_grad()
138 | masked_input, mask_chunk, ground_truth = data.masked_input()
139 | logits, dnc_state = DNC(_variable(masked_input), dnc_state)
140 | expanded_logits = data.ix_input_to_ixs(logits)
141 |
142 | # losses
143 | lstep = L.action_loss(expanded_logits, ground_truth, criterion, log=True)
144 | if args.opt_at == 'problem':
145 | loss += lstep
146 | else:
147 | lstep.backward(retain_graph=args.ret_graph)
148 | optimizer.step()
149 | loss = lstep
150 |
151 | # update counters
152 | prediction = u.get_prediction(expanded_logits, [3, 4])
153 | n_total, n_correct = tick(n_total, n_correct, mask_chunk, prediction)
154 |
155 | if args.opt_at == 'problem':
156 | loss.backward(retain_graph=args.ret_graph)
157 | optimizer.step()
158 | sl.writer.add_scalar('losses.end', loss.data[0], sl.global_step)
159 |
160 | cum_total.append(n_total)
161 | cum_correct.append(n_correct)
162 | sl.writer.add_scalar('recall.pct_correct', n_correct / n_total, sl.global_step)
163 | print("trial: {}, step:{}, accy {:0.4f}, cum_score {:0.4f}, loss: {:0.4f}".format(
164 | trial, sl.global_step, n_correct / n_total, running_avg(cum_correct, cum_total), loss.data[0]))
165 | return DNC, optimizer, dnc_state, running_avg(cum_correct, cum_total)
166 |
167 |
168 | def random_seq(args, data, DNC, lstm_state, optimizer):
169 | pass
170 |
171 |
172 | def train_rl(args, data, DNC, lstm_state, optimizer):
173 | """
174 |
175 | :param args:
176 | :param data:
177 | :param DNC: a tuple of value and action networks
178 | :param lstm_state:
179 | :param optimizer:
180 | :return:
181 | """
182 | for trial in range(args.iters):
183 | start_prob = time.time()
184 | phase_masks = data.make_new_problem()
185 |
186 |
187 | pass
188 |
189 |
190 | def train_plan(args, data, DNC, lstm_state, optimizer):
191 | """
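192 | Train the DNC to plan: after the state and goal phases it picks one action per step,
193 | trained against the optimal actions with probability args.beta, otherwise against all available actions.
194 | Things to test after some iterations:
195 | - with goals: choose a goal and work toward it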
196 | :param args:
197 | :return:
198 | """
199 | criterion = nn.CrossEntropyLoss().cuda() if args.cuda is True else nn.CrossEntropyLoss()
200 | cum_correct, cum_total, prob_times, n_success = [], [], [], 0
201 | penalty = 1.1
202 |
203 | for trial in range(args.iters):
204 | start_prob = time.time()
205 | phase_masks = data.make_new_problem()
206 | n_total, n_correct, prev_action, loss, stats = 0, 0, None, 0, []
207 | dnc_state = DNC.init_state(grad=False)
208 | lstm_state = DNC.init_rnn(grad=False) # lstm_state,
209 | optimizer.zero_grad()
210 |
211 | for phase_idx in phase_masks:
212 |
213 | if phase_idx == 0 or phase_idx == 1:
214 | inputs = _variable(data.getitem_combined())
215 | logits, dnc_state, lstm_state = DNC(inputs, lstm_state, dnc_state)
216 | _, prev_action = data.strip_ix_mask(logits)
217 |
218 | elif phase_idx == 2:
219 | mask = _variable(data.getmask())
220 | inputs = torch.cat([mask, prev_action], 1)
221 | logits, dnc_state, lstm_state = DNC(inputs, lstm_state, dnc_state)
222 | _, prev_action = data.strip_ix_mask(logits)
223 |
224 | else:
225 | # sample from best moves
226 | actions_star, all_actions = data.get_actions(mode='both')
227 | if not actions_star:
228 | break
229 | if args.zero_at == 'step':
230 | optimizer.zero_grad()
231 |
232 | mask = data.getmask()
233 | prev_action = prev_action.cuda() if args.cuda is True else prev_action
234 | pr = u.depackage(prev_action)
235 |
236 | final_inputs = _variable(torch.cat([mask, pr], 1))
237 | logits, dnc_state, lstm_state = DNC(final_inputs, lstm_state, dnc_state)
238 | exp_logits = data.ix_input_to_ixs(logits)
239 |
240 | guided = random.random() < args.beta
241 | # thing 1
242 | if guided: # guided loss
243 | final_action, lstep = L.naive_loss(exp_logits, actions_star, criterion, log=True)
244 | else: # pick own move
245 | final_action, lstep = L.naive_loss(exp_logits, all_actions, criterion, log=True)
246 |
247 | # penalize the loss when the predicted action is not an available action (TODO: test this)
248 | action_own = u.get_prediction(exp_logits)
249 | if args.penalty and action_own not in [tuple(flat(t)) for t in all_actions]:
250 | final_loss = lstep * _variable([args.penalty])
251 | else:
252 | final_loss = lstep
253 |
254 | if args.opt_at == 'problem':
255 | loss += final_loss
256 | else:
257 |
258 | final_loss.backward(retain_graph=args.ret_graph)
259 | if args.clip:
260 | torch.nn.utils.clip_grad_norm(DNC.parameters(), args.clip)
261 | optimizer.step()
262 | loss = lstep
263 |
264 | data.send_action(final_action)
265 |
266 | if (trial + 1) % args.show_details == 0:
267 | action_accs = u.human_readable_res(data, all_actions, actions_star,
268 | action_own, guided, lstep.data[0])
269 | stats.append(action_accs)
270 | n_total, _ = tick(n_total, n_correct, action_own, flat(final_action))
271 | n_correct += 1 if action_own in [tuple(flat(t)) for t in actions_star] else 0
272 | prev_action = data.vec_to_ix(final_action)
273 |
274 | if stats:
275 | arr = np.array(stats)
276 | correct = len([1 for i in list(arr.sum(axis=1)) if i == len(stats[0])]) / len(stats)
277 | sl.log_acc(list(arr.mean(axis=0)), correct)
278 |
279 | if args.opt_at == 'problem':
280 | floss = loss / n_total
281 | floss.backward(retain_graph=args.ret_graph)
282 | if args.clip:
283 | torch.nn.utils.clip_grad_norm(DNC.parameters(), args.clip)
284 | optimizer.step()
285 | sl.writer.add_scalar('losses.end', floss.data[0], sl.global_step)
286 |
287 | n_success += 1 if n_correct / n_total > args.passing else 0
288 | cum_total.append(n_total)
289 | cum_correct.append(n_correct)
290 | sl.add_scalar('recall.pct_correct', n_correct / n_total, sl.global_step)
291 | print("trial {}, step {} trial accy: {}/{}, {:0.2f}, running total {}/{}, running avg {:0.4f}, loss {:0.4f} ".format(
292 | trial, sl.global_step, n_correct, n_total, n_correct / n_total, n_success, trial,
293 | running_avg(cum_correct, cum_total), loss.data[0]
294 | ))
295 | end_prob = time.time()
296 | prob_times.append(end_prob - start_prob)
297 | print("solved {} out of {} -> {}".format(n_success, args.iters, n_success / args.iters))
298 | return DNC, optimizer, lstm_state, running_avg(cum_correct, cum_total)
299 |
300 |
301 | def train_manager(args, train_fn):
302 | """
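303 | Train the DNC on a curriculum of increasingly large air-cargo problems.
304 | :param args: args object. See arg.py or run.py -h for details.
305 | :param train_fn: training function (e.g. train_plan or train_qa2), called once per epoch; returns (DNC, optimizer, lstm_state, score)
306 | :return: None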
307 | """
308 | datspec = generate_data_spec(args)
309 | print('\nInitial Spec', datspec)
310 |
311 | _, DNC, optimizer, lstm_state = setupDNC(args)
312 | start_ents, score, global_epoch = args.n_init_start, 0, args.start_epoch
313 | print('\nDnc structure', DNC)
314 |
315 | for problem_size in range(args.max_ents):
316 | test_size = problem_size + start_ents
317 | passing = False
318 | data_spec = generate_data_spec(args, num_ents=test_size, solve=test_size * 3)
319 | data = gen.AirCargoData(**data_spec)
320 |
321 | print("beginning new training Size: {}".format(test_size))
322 | for train_epoch in range(args.n_phases):
323 | ep_start = time.time()
324 | global_epoch += 1
325 | print("\nStarting Epoch {}".format(train_epoch))
326 |
327 | DNC, optimizer, lstm_state, score = train_fn(args, data, DNC, lstm_state, optimizer)
328 | if (train_epoch + 1) % args.checkpoint_every == 0 and args.save != '':
329 | save(DNC, optimizer, lstm_state, args, global_epoch)
330 |
331 | ep_end = time.time()
332 | ttl_s = ep_end - ep_start
333 | print('finished epoch: {}, score: {}, ttl-time: {:0.4f}, time/prob: {:0.4f}'.format(
334 | train_epoch, score, ttl_s, ttl_s / args.iters
335 | ))
336 | if score > args.passing:
337 | print('model_successful: {}, {} '.format(score, train_epoch))
338 | print('----------------------WOO!!--------------------------')
339 | passing = True
340 | break
341 |
342 | if not passing:
343 | print("Training has FAILED for problem of size: {}, after {} epochs".format(
344 | test_size, args.n_phases
345 | ))
346 | print("final score was {}".format(score))
347 | break
348 |
349 |
350 | if __name__ == "__main__":
351 | print(args)
352 | if args.act == 'plan':
353 | train_manager(args, train_plan)
354 | elif args.act == 'qa':
355 | train_manager(args, train_qa2)
356 | elif args.act == 'clean':
357 | pass
358 | else:
359 | print("wrong action: expected one of 'plan', 'qa', 'clean', got {}".format(args.act))
--------------------------------------------------------------------------------
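Note on `train_manager` in run.py: it runs a simple curriculum. For each problem size it trains for up to `args.n_phases` epochs and moves on to the next size as soon as the epoch score exceeds `args.passing`. It stops the whole run at the first size that never passes. The sketch below reproduces only that control flow in plain Python. It assumes a hypothetical `fake_train_epoch(size, epoch)` in place of `train_plan`/`train_qa2`, and plain arguments in place of the `args` fields. It is an illustration, not the repository's API.

```python
from typing import Callable, List, Tuple


def run_curriculum(train_epoch: Callable[[int, int], float],
                   n_init_start: int, max_ents: int,
                   n_phases: int, passing: float) -> List[Tuple[int, int, float]]:
    """Mirror of the control flow in train_manager (sketch).

    For each problem size, train for up to `n_phases` epochs and advance as soon
    as the score exceeds `passing`; stop the whole curriculum if a size fails.
    Returns (size, epochs_used, final_score) for every size attempted.
    """
    history = []
    for problem_size in range(max_ents):
        size = problem_size + n_init_start
        score, passed, epochs_used = 0.0, False, 0
        for epoch in range(n_phases):
            epochs_used = epoch + 1
            score = train_epoch(size, epoch)
            if score > passing:
                passed = True
                break
        history.append((size, epochs_used, score))
        if not passed:
            break  # curriculum halts at the first size that never passes
    return history


def fake_train_epoch(size: int, epoch: int) -> float:
    # Hypothetical learner: larger problems need more epochs to pass (size >= 2).
    return min(1.0, 0.2 * (epoch + 1) / (size - 1))
```

Returning the per-size history, rather than only printing it, makes the stopping behavior easy to check.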
/setup.sh:
--------------------------------------------------------------------------------
1 |
2 | # apt-get install -y supervisor
3 |
4 | # mkdir /etc/supervisor/conf.d/
5 | # cp tf.conf /etc/supervisor/conf.d/tensorboard.conf
6 | pip install tensorboardX
7 | pip install tensorflow
8 | pip install git+https://github.com/lanpa/tensorboard-pytorch
9 |
10 | # tensorboard --logdir /output/runs
11 |
--------------------------------------------------------------------------------
/tests.py:
--------------------------------------------------------------------------------
1 | import unittest, copy, time
2 | import torch
3 | from problem import generators_v2 as gen
4 | from dnc_arity_list import DNC
5 | import problem.my_air_cargo_problems as mac
6 | from problem.lp_utils import (
7 | decode_state,
8 | )
9 | import run
10 | from visualize import wuddido as viz
11 |
12 |
13 | def test_solve_with_logic(data):
14 | print('\nWITH LOGICAL HEURISTICS')
15 | print('GOAL', data.current_problem.goal, '\n')
16 | start = time.time()
17 | f = 0
18 | for n in range(20):
19 | log_actions = data.best_logic(data.get_raw_actions(mode='all'))
20 | if log_actions != []:
21 | print('chosen', log_actions[0])
22 | data.send_action(data.expr_to_vec(log_actions[0]))
23 | else:
24 | f = n
25 | break
26 | end = time.time()
27 | print('DONE in {} steps, {:0.4f} s'.format(f, end - start))
28 |
29 |
30 | def test_solve_with_algo(data):
31 | print('\nWITH ASTAR')
32 | print('GOAL', data.current_problem.goal, '\n')
33 | f = 0
34 | start2 = time.time()
35 | for n in range(20):
36 | actions = data.get_raw_actions(mode='best')
37 | # print(actions)
38 | if actions != []:
39 | print('chosen', actions[0])
40 | data.send_action(data.expr_to_vec(actions[0]))
41 | else:
42 | f = n
43 | break
44 | end2 = time.time()
45 | print('DONE in {} steps, {:0.4f} s'.format(f, end2 - start2))
46 |
47 |
48 | sample_args = {'num_plane': 2, 'num_cargo': 2,
49 | 'num_airport': 2,
50 | 'one_hot_size': [4, 6, 4, 6, 4, 6],
51 | 'plan_phase': 2 * 3,
52 | 'cuda': False, 'batch_size': 1,
53 | 'encoding': 2, 'solve': True, 'mapping': None}
54 |
55 |
56 | class Misc(unittest.TestCase):
57 | def setUp(self):
58 | self.dataspec = {'solve': True, 'mapping': None,
59 | 'num_plane': 2, 'one_hot_size': [4, 6, 4, 6, 4, 6],
60 | 'num_airport': 2, 'plan_phase': 6,
61 | 'encoding': 2, 'batch_size': 1, 'num_cargo': 2}
62 | self.dataspec2 = {'solve': True, 'mapping': None,
63 | 'num_plane': 3, 'one_hot_size': [4, 6, 4, 6, 4, 6],
64 | 'num_airport': 3, 'plan_phase': 6,
65 | 'encoding': 2, 'batch_size': 1, 'num_cargo': 3}
66 |
67 | def est_cache(self):
68 | data = gen.AirCargoData(**self.dataspec)
69 | data.make_new_problem()
70 | test_solve_with_logic(copy.deepcopy(data))
71 | test_solve_with_algo(copy.deepcopy(data))
72 | print('\n\n ROUND 2')
73 | data.make_new_problem()
74 | test_solve_with_logic(copy.deepcopy(data))
75 | test_solve_with_algo(copy.deepcopy(data))
76 |
77 | def est_searches(self):
78 | problem = mac.air_cargo_p1()
79 | ds = decode_state(problem.initial, problem.state_map)
80 |
81 |
82 | class TestVis(unittest.TestCase):
83 | def setUp(self):
84 | self.folder = '1512692566_clip_cont2_40'
85 | chkpt_num = 109  # epoch number of the checkpoint to load
86 | self.base = './models/{}/checkpts/{}/dnc_model.pkl'
87 | dict1 = torch.load(self.base.format(self.folder, chkpt_num))
88 |
89 | self.data = gen.AirCargoData(**sample_args)
90 |
91 | args = run.dnc_args.copy()
92 | args['output_size'] = self.data.nn_in_size
93 | args['word_len'] = self.data.nn_out_size
94 |
95 | self.Dnc = DNC(**args)
96 | self.Dnc.load_state_dict(dict1)
97 | pass
98 |
99 | def test_show_state(self):
100 | rand_vec = torch.randn(39, 1)
101 | viz.ix_to_color(rand_vec)
102 |
103 | def test_run(self):
104 | record = viz.recorded_step(self.data, self.Dnc)
105 | viz.make_usage_viz(record)
106 |
107 |
108 | if __name__ == '__main__':
109 | unittest.main()
110 |
--------------------------------------------------------------------------------
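Note on `test_solve_with_logic` and `test_solve_with_algo` in tests.py: both repeatedly send the first suggested action until the generator suggests none, then report the step at which that happened. If the list never empties within 20 steps, `f` keeps its initial value of 0, so the printed step count is misleading. The sketch below implements the same loop but distinguishes the two cases. Like those helpers, it assumes that an empty action list means the goal has been reached. `propose` and `apply` are hypothetical stand-ins for `data.get_raw_actions(...)` and `data.send_action(data.expr_to_vec(...))`.

```python
from typing import Callable, Optional, Sequence, TypeVar

A = TypeVar('A')


def run_policy(propose: Callable[[], Sequence[A]],
               apply: Callable[[A], None],
               max_steps: int = 20) -> Optional[int]:
    """Apply the first proposed action until no action is proposed.

    Returns the number of actions applied before the proposal list came back
    empty, or None if it was still non-empty after `max_steps` steps.
    """
    for n in range(max_steps):
        actions = propose()
        if not actions:
            return n
        apply(actions[0])
    return None


# Toy environment (hypothetical): a counter of moves left before the goal.
remaining = [3]


def propose_toy():
    return ['move'] if remaining[0] > 0 else []


def apply_toy(action):
    remaining[0] -= 1
```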
/tf.conf:
--------------------------------------------------------------------------------
1 | logdir /output/runs
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | from visualize import logger as sl
6 | import utils as u
7 | import losses as l
8 |
9 | def tick(n_total, n_correct, truth, pred):
10 | n_total += 1
11 | n_correct += 1 if truth == pred else 0
12 | return n_total, n_correct
13 |
14 | class TrainPlanning:
15 | def __init__(self, problem, optim, args):
16 | self.ProblemGen = problem
17 | self.args = args
18 | self.optimizer = optim
19 | self.current_prob = []
20 | self.reset()
21 |
22 | def reset(self):
23 | self.num_correct, self.total_moves, self.prev_action = 0, 0, None
24 | self.n_success = 0
25 |
26 | def phase_fn(self, idx):
27 | if idx == 0: return self.state_phase
28 | if idx == 1: return self.goal_phase
29 | if idx == 2: return self.plan_phase
30 | if idx == 3: return self.response_phase
31 |
32 | def state_phase(self, Dnc, dnc_state, inputs, mask):
33 | pass
34 |
35 | def goal_phase(self, Dnc, dnc_state, inputs, mask):
36 | pass
37 |
38 | def plan_phase(self, Dnc, dnc_state, inputs, mask):
39 | inputs = Variable(torch.cat([mask, self.ProblemGen.ix_input_to_ixs(self.prev_action)], 1))
40 | self.prev_action, dnc_state = Dnc(inputs, dnc_state)
41 | return dnc_state
42 |
43 | def response_phase(self, Dnc, dnc_state, inputs, mask):
44 | sl.global_step += 1
45 | self.total_moves += 1
46 | final_inputs = Variable(torch.cat([mask, self.ProblemGen.ix_input_to_ixs(self.prev_action)], 1))
47 | targets_star, all_actions = self.ProblemGen.get_actions(mode='both')
48 | logits, dnc_state = Dnc(final_inputs, dnc_state)
49 | expanded_logits = self.ProblemGen.ix_input_to_ixs(logits)
50 | loss_fn = nn.CrossEntropyLoss()
51 | chosen_act_own, loss_own = l.naive_loss(expanded_logits, all_actions, loss_fn)
52 | chosen_act_star, loss_star = l.naive_loss(expanded_logits, targets_star, loss_fn, log_itr=sl.global_step)
53 |
54 | # set next input to be the networks current action ...
55 | if random.random() < self.args.beta:
56 | loss = loss_star
57 | final_action = chosen_act_star
58 | else:
59 | loss = loss_own
60 | final_action = chosen_act_own
61 | self.num_correct += 1 if chosen_act_own == chosen_act_star else 0
62 | return Dnc, dnc_state, final_action
63 |
64 | def train(self, Dnc):
65 | for n in range(self.args.iters):
66 | self.current_prob = self.ProblemGen.make_new_problem()
67 | dnc_state = Dnc.init_state(grad=False)
68 | self.optimizer.zero_grad()
69 |
70 | for idx, _ in enumerate(self.current_prob):
71 | inputs, mask = self.ProblemGen.getitem()
72 | phase_fn = self.phase_fn(idx)
73 | phase_fn(Dnc, dnc_state, inputs, mask)
74 | pass
75 |
76 | def step(self):
77 | pass
78 |
79 | def end_problem(self):
80 | print("solved {} out of {} -> {}".format(self.n_success, self.args.iters, self.n_success / self.args.iters))
81 | pass
82 |
83 |
84 |
85 | def play_qa_readable(args, data, DNC):
86 | criterion = nn.CrossEntropyLoss()
87 | cum_correct, cum_total = [], []
88 |
89 | for trial in range(args.iters):
90 | phase_masks = data.make_new_problem()
91 | n_total, n_correct, loss = 0, 0, 0
92 | dnc_state = DNC.init_state(grad=False)
93 |
94 |
95 | for phase_idx in phase_masks:
96 | if phase_idx == 0 or phase_idx == 1:
97 |
98 | inputs, msk = data.getitem()
99 | print(data.human_readable(inputs, msk))
100 |
101 | inputs = Variable(torch.cat([msk, inputs], 1))
102 | logits, dnc_state = DNC(inputs, dnc_state)
103 | else:
104 | final_moves = data.get_actions(mode='one')
105 | if final_moves == []:
106 | break
107 | data.send_action(final_moves[0])
108 | mask = data.phase_oh[2].unsqueeze(0)
109 | vec = data.vec_to_ix(final_moves[0])
110 | print('\n')
111 | print(data.human_readable(vec, mask))
112 |
113 | inputs2 = Variable(torch.cat([mask, vec], 1))
114 | logits, dnc_state = DNC(inputs2, dnc_state)
115 |
116 | for _ in range(args.num_tests):
117 | # ask where is ---?
118 |
119 | masked_input, mask_chunk, ground_truth = data.masked_input()
120 | print("Context:", data.human_readable(ground_truth))
121 | print("Q:")
122 |
123 | logits, dnc_state = DNC(Variable(masked_input), dnc_state)
124 | expanded_logits = data.ix_input_to_ixs(logits)
125 |
126 | #losses
127 | lstep = l.action_loss(expanded_logits, ground_truth, criterion, log=True)
128 | loss = loss + lstep
129 | #update counters
130 | prediction = u.get_prediction(expanded_logits, [3, 4])
131 | print("A:")
132 | n_total, n_correct = tick(n_total, n_correct, mask_chunk, prediction)
133 | print("correct:", mask_chunk == prediction)
134 |
135 |
136 | cum_total.append(n_total)
137 | cum_correct.append(n_correct)
138 | sl.writer.add_scalar('recall.pct_correct', n_correct / n_total, sl.global_step)
139 | print("trial: {}, step:{}, accy {:0.4f}, cum_score {:0.4f}, loss: {:0.4f}".format(
140 | trial, sl.global_step, n_correct / n_total, u.running_avg(cum_correct, cum_total), loss.data[0]))
141 | return DNC, dnc_state, u.running_avg(cum_correct, cum_total)
142 |
143 |
--------------------------------------------------------------------------------
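Note on the guided step in `TrainPlanning.response_phase` and the matching code in run.py: with probability `args.beta`, the loss is computed against the expert's optimal actions. Otherwise it is computed against all legal actions. The chosen action is then fed back as the next input. The body of `losses.naive_loss` is not shown in this section. The sketch below therefore assumes that it selects the candidate action with the lowest summed per-slot cross-entropy. It is written in plain Python rather than PyTorch. `choose_action`, `action_nll`, and the toy logits are hypothetical names for illustration.

```python
import math
import random
from typing import List, Sequence, Tuple

Action = Tuple[int, ...]


def log_softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one slot's logits."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def action_nll(slot_logits: Sequence[Sequence[float]], action: Action) -> float:
    """Summed cross-entropy of `action` (one index per slot) under per-slot softmaxes."""
    return -sum(log_softmax(logits)[ix] for logits, ix in zip(slot_logits, action))


def choose_action(slot_logits: Sequence[Sequence[float]],
                  expert_actions: Sequence[Action],
                  legal_actions: Sequence[Action],
                  beta: float,
                  rng: random.Random) -> Tuple[Action, float, bool]:
    """With probability beta the step is guided and the candidates are the expert
    actions; otherwise the candidates are all legal actions. Returns the
    lowest-loss candidate, its loss, and whether the step was guided."""
    guided = rng.random() < beta
    candidates = expert_actions if guided else legal_actions
    best = min(candidates, key=lambda a: action_nll(slot_logits, a))
    return best, action_nll(slot_logits, best), guided
```

Raising `beta` pushes training toward teacher forcing. Lowering it lets the network follow its own legal choices more often.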
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.cuda
3 | from torch.autograd import Variable
4 | import os, glob
5 | from arg import args
6 |
7 | eps = 10e-6
8 | MODEL_NAME = 'dnc_model.pkl'
9 | OPTIM_NAME = 'optimizer.pkl'
10 |
11 |
12 | def _variable(xs, **kwargs):
13 | if args.cuda is True:
14 | return Variable(xs, **kwargs).cuda()
15 | else:
16 | return Variable(xs, **kwargs)
17 |
18 |
19 | def repackage(xs):
20 | """Wraps hidden states in new Variables, to detach them from their history."""
21 | if type(xs) == Variable:
22 | return Variable(xs.data)
23 | else:
24 | return tuple(repackage(v) for v in xs)
25 |
26 |
27 | def depackage(xs):
28 | """Unwraps Variables (recursively, for tuples) into their underlying tensors."""
29 | if type(xs) == Variable:
30 | return xs.data
31 | elif torch.is_tensor(xs):
32 | return xs
33 | elif type(xs) == torch.cuda.FloatTensor:
34 | return xs
35 | else:
36 | return tuple(depackage(v) for v in xs)
37 |
38 |
39 | def running_avg(cum_correct, cum_total_move, last=-100):
40 | return sum(cum_correct[last:]) / sum(cum_total_move[last:])
41 |
42 |
43 | def flat(container):
44 | acc = []
45 | for i in container:
46 | if isinstance(i, (list, tuple)):
47 | for j in flat(i):
48 | acc.append(j)
49 | else:
50 | acc.append(i)
51 | return acc
52 |
53 |
54 | def show(tensors, m=""):
55 | print("--")
56 | if torch.is_tensor(tensors):
57 | print(m, tensors.size())
58 | else:
59 | print(m)
60 | [print(t.size()) for t in tensors]
61 |
62 |
63 | def clean_runs(dirs, lim=10):
64 | def convert_bytes(num):
65 | for x in ['bytes', 'KB', 'MB']:
66 | if num < 1024.0:
67 | return "%3.1f %s" % (num, x)
68 | num /= 1024.0
69 | for d in glob.glob(dirs + '*'):
70 | if os.stat(d).st_size < lim * 1024:  # assumes lim is given in kilobytes
71 | os.remove(d)
72 |
73 |
74 | def get_chkpt(name, idx=0):
75 | search = './models/*{}/checkpts/*'.format(name)
76 | chkpt = sorted(glob.glob(search), key=lambda p: int(os.path.basename(p)))[-1] + '/'
77 | return chkpt + MODEL_NAME, chkpt + OPTIM_NAME
78 |
79 |
80 | def dnc_checksum(state):
81 | return [state[i].data.sum() for i in range(6)]
82 |
83 |
84 | def save(model, optimizer, lstm_state, args, global_epoch):
85 | chkpt = '{}{}/{}/'.format(args.base_dir, 'checkpts', global_epoch)
86 | os.makedirs(chkpt, exist_ok=True)
87 |
88 | torch.save(model.state_dict(), chkpt + MODEL_NAME)
89 | torch.save(model, chkpt + 'dnc_model_full.pkl')
90 | torch.save(lstm_state, chkpt + 'lstm_state.pkl')
91 | torch.save(optimizer.state_dict(), chkpt + OPTIM_NAME)
92 | torch.save(optimizer, chkpt + 'optimizer_full.pkl')
93 | print("Saving ... file...{}, chkpt_num:{}".format(args.base_dir, global_epoch))
94 |
95 |
96 | def get_prediction(expanded_logits, idxs='all'):
97 | max_idxs = []
98 | if idxs == 'all':
99 | idxs = range(len(expanded_logits))
100 | for idx in idxs:
101 | _, pidx = expanded_logits[idx].data.topk(1)
102 | max_idxs.append(pidx.squeeze()[0])
103 | return tuple(max_idxs)
104 |
105 |
106 | def closest_action(pred, actions):
107 | best, chosen_action = 0, None
108 | for action in [flat(a) for a in actions]:
109 | scores = [1 if pred[i] == action[i] else 0 for i in range(len(pred))]
110 | if sum(scores) >= best:
111 | best, chosen_action = sum(scores), scores
112 | return chosen_action
113 |
114 |
115 | def human_readable_res(Data, all_actions, best_actions, pred, guided, loss_data):
116 |
117 | base_prob = 1 / len(all_actions)
118 | action_own = [pred[0], (pred[1], pred[2]), (pred[3], pred[4]), (pred[5], pred[6])]
119 | correct = action_own in best_actions
120 |
121 | action_der = closest_action(pred, best_actions)
122 | best_move_exprs = [Data.vec_to_expr(t) for t in best_actions]
123 | # all_move_exprs = [Data.vec_to_expr(t) for t in all_actions]
124 | chos_move, crest = Data.vec_to_expr(action_own)
125 |
126 | # print("all {}".format(', '.join(["{} {}".format(m[0], m[1]) for m in all_move_exprs])))
127 | print("best {}".format(', '.join(["{} {}".format(m[0], m[1]) for m in best_move_exprs])))
128 | print("chosen: {} {}, guided {}, prob {:0.2f}, T? {}---loss {:0.4f}".format(
129 | chos_move, crest, guided, base_prob, correct, loss_data
130 | ))
131 | return action_der
132 |
133 |
134 | def interface_part(num_reads, W):
135 | partition = [num_reads * W, num_reads, W, 1, W, W, num_reads, 1, 1, num_reads * 3]
136 | ds = []
137 | cntr = 0
138 | for idx in partition:
139 | tn = [cntr, cntr + idx]
140 | ds.append(tn)
141 | cntr += idx
142 | return ds
143 |
144 |
--------------------------------------------------------------------------------
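Note on `interface_part(num_reads, W)` in utils.py: it returns `[start, end)` bounds that cut the flat DNC interface vector into ten pieces. The pieces follow the order of the `interface` names in visualize/logger.py. Summing the partition gives a total length of num_reads*W + 3W + 5*num_reads + 3. The sketch below copies the function so it runs on its own. It also adds a hypothetical `split_interface` helper that slices a plain list by name. Pairing bounds with names this way is an illustration, not code from the repository.

```python
from typing import Dict, List, Sequence

# Same order as `interface` in visualize/logger.py.
INTERFACE_NAMES = ['read_keys', 'read_str', 'write_key', 'write_str',
                   'erase_vec', 'write_vec', 'free_gates', 'alloc_gate',
                   'write_gate', 'read_modes']


def interface_part(num_reads: int, W: int) -> List[List[int]]:
    """Copy of utils.interface_part: [start, end) bounds for each interface piece."""
    partition = [num_reads * W, num_reads, W, 1, W, W, num_reads, 1, 1, num_reads * 3]
    bounds, cntr = [], 0
    for size in partition:
        bounds.append([cntr, cntr + size])
        cntr += size
    return bounds


def split_interface(vec: Sequence[float], num_reads: int, W: int) -> Dict[str, Sequence[float]]:
    """Hypothetical helper: slice a flat interface vector into named pieces."""
    bounds = interface_part(num_reads, W)
    expected = bounds[-1][1]  # num_reads*W + 3W + 5*num_reads + 3
    if len(vec) != expected:
        raise ValueError('expected interface of length {}, got {}'.format(expected, len(vec)))
    return {name: vec[s:e] for name, (s, e) in zip(INTERFACE_NAMES, bounds)}
```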
/visualize/logger.py:
--------------------------------------------------------------------------------
1 | # from tensorboardX import SummaryWriter
2 | import torchvision.utils as vutils
3 | from torch.autograd import Variable
4 | from arg import writer
5 |
6 | global_step = 0
7 | log_step = 1
8 | # state = ["access_output", "memory", "read_weights", "write_wghts", "link_matrix", "link_weights", "usage"]
9 | interface = ['read_keys', 'read_str', 'write_key', 'write_str',
10 | 'erase_vec', 'write_vec', 'free_gates', 'alloc_gate', 'write_gate', 'read_modes']
11 | losses_desc = ['action', 'ent1-type', 'ent1', 'ent2-type', 'ent2', 'ent3-type', 'ent3']
12 | losses_desc2 = ['action', 'ent1', 'ent2', 'ent3']
13 |
14 |
15 | def to_log(tnsr):
16 | if writer:
17 | clone = tnsr.clone()
18 | if type(clone) == Variable:
19 | clone = clone.data
20 | sizes = list(clone.size())
21 | if len(sizes) > 1 and not all((s == 1 for s in sizes)):
22 | return clone.cpu().squeeze().numpy()
23 | else:
24 | return clone.cpu().numpy()
25 |
26 |
27 | def log_if(name, data, f=None):
28 | if writer:
29 | if global_step % log_step == 0:
30 | if f is None:
31 | writer.add_histogram(name, data, global_step, bins='sturges')
32 | else:
33 | f(name, data, global_step)
34 |
35 |
36 | def log_interface(interface_vec, step):
37 | if writer:
38 | for name, data in zip(interface, interface_vec):
39 | writer.add_histogram("interface." + name, to_log(data), step, bins='sturges')
40 |
41 |
42 | def log_model(model):
43 | if writer:
44 | for name, param in model.named_parameters():
45 | writer.add_histogram(name, to_log(param), global_step, bins='sturges')
46 |
47 |
48 | def log_loss(losses, loss):
49 | if global_step % log_step == 0 and writer:
50 | type_descs = losses_desc2 if len(losses) == 4 else losses_desc
51 | for name, param in zip(type_descs, losses):
52 | writer.add_scalar("losses." + name, to_log(param), global_step)
53 | writer.add_scalar('losses.total', loss.clone().cpu().data[0], global_step)
54 |
55 |
56 | def log_acc(accs, total):
57 | if global_step % log_step == 0 and writer:
58 | type_descs = losses_desc2 if len(accs) == 4 else losses_desc
59 | for name, param in zip(type_descs, accs):
60 | writer.add_scalar("acc." + name, param, global_step)
61 | writer.add_scalar('acc.total', total, global_step)
62 |
63 | def add_scalar(*args, **kwdargs):
64 | if writer:
65 | writer.add_scalar(*args, **kwdargs)
66 |
67 |
68 | def log_loss_qa(ent, inst, loss):
69 | if global_step % log_step == 0 and writer:
70 | writer.add_scalar('losses.ent1-type', ent.clone().cpu().data[0], global_step)
71 | writer.add_scalar('losses.ent1', inst.clone().cpu().data[0], global_step)
72 | writer.add_scalar('losses.total', loss.clone().cpu().data[0], global_step)
73 |
74 |
75 | def log_state(state):
76 | if global_step % log_step == 0 and len(state) == 8 and writer:
77 | out, mem, r_wghts, w_wghts, links, l_wghts, usage, hidden = state
78 | writer.add_histogram("state.access_output", to_log(out), global_step, bins='sturges')
79 | writer.add_histogram("state.memory", to_log(mem), global_step, bins='sturges')
80 | writer.add_histogram("state.read_weights", to_log(r_wghts), global_step, bins='sturges')
81 | writer.add_histogram("state.write_wghts", to_log(w_wghts), global_step, bins='sturges')
82 | writer.add_histogram("state.link_matrix", to_log(links), global_step, bins='sturges')
83 | writer.add_histogram("state.link_weights", to_log(l_wghts), global_step, bins='sturges')
84 | writer.add_histogram("state.usage", to_log(usage), global_step, bins='sturges')
85 |
--------------------------------------------------------------------------------
/visualize/wuddido.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib.figure as fig
3 | import matplotlib
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import losses as L
8 | from utils import depackage, _variable
9 | import copy
10 | from itertools import accumulate
11 | from torch.autograd import Variable
12 | import plotly.plotly as py
13 | import plotly.graph_objs as go
14 | from matplotlib.ticker import FuncFormatter, MaxNLocator
15 |
16 | """
17 | During the state phase, the network must learn to write each tuple to a separate memory location.
18 | During the goal phase, the goal tuples are recorded.
19 | --todo-- the triple stored at each location can be recovered by a logistic-regression decoder
20 | --todo--
21 | During state phase read modes -
22 |
23 | Extended Data Figure 2
24 | """
25 |
26 | c_map = plt.get_cmap('jet')
27 |
28 |
29 | def ix_to_color(ix_vec):
30 | """
31 |
32 | :param ix_vec:
33 | :return: Image
34 | """
35 | ix_np = ix_vec.numpy()
36 |
37 |
38 | def recorded_step(data, DNC):
39 | """
40 | state phase answer phase
41 | m w r
42 | e w r
43 | m w r
44 | w r
45 | l
46 | o
47 | c
48 | ------------time------------->
49 | :param state:
50 | :return:
51 | """
52 | criterion = nn.CrossEntropyLoss()
53 |
54 | phase_masks = data.make_new_problem()
55 | state = DNC.init_state(grad=False)
56 | lstm_state = DNC.init_rnn(grad=False)
57 |
58 | prev_action, time_steps = None, []
59 | # time_steps.append(copy.copy(state) + [-1])
60 |
61 | for phase_idx in phase_masks:
62 | if phase_idx == 0 or phase_idx == 1:
63 | inputs = Variable(data.getitem_combined())
64 | logits, state, lstm_state = DNC(inputs, lstm_state, state)
65 | #
66 | _, prev_action = data.strip_ix_mask(logits)
67 | time_steps.append({'state': copy.copy(state), 'input': inputs.data,
68 | 'outputs': logits.data, 'phase': phase_idx})
69 |
70 | elif phase_idx == 2:
71 | inputs = torch.cat([Variable(data.getmask()), prev_action], 1)
72 | logits, dnc_state, lstm_state = DNC(inputs, lstm_state, state)
73 | #
74 | _, prev_action = data.strip_ix_mask(logits)
75 | time_steps.append({'state': copy.copy(state), 'input': inputs.data,
76 | 'outputs': logits.data, 'phase': phase_idx})
77 | else:
78 | _, all_actions = data.get_actions(mode='both')
79 | if data.goals_idx == {}:
80 | break
81 | mask, pr = data.getmask(), depackage(prev_action)
82 | final_inputs = Variable(torch.cat([mask, pr], 1))
83 | #
84 | logits, state, lstm_state = DNC(final_inputs, lstm_state, state)
85 | exp_logits = data.ix_input_to_ixs(logits)
86 | # time_steps.append(copy.copy(state) + [phase_idx])
87 | final_action, _ = L.naive_loss(exp_logits, all_actions, criterion)
88 |
89 | print(final_action)
90 | # send action to Data Generator, and set locally
91 | data.send_action(final_action)
92 | prev_action = data.vec_to_ix(final_action)
93 | time_steps.append({'state': copy.copy(state), 'input': final_inputs.data,
94 | 'outputs': logits.data, 'phase': phase_idx})
95 | return time_steps
96 |
97 |
98 | def make_usage_viz(states):
99 | """
100 |
101 | :param states:
102 | :return:
103 | """
104 | # access, memory, read_wghts, write_wghts, \
105 | # link, link_wghts, usage, phase_idx = state
106 |
107 | mem_size = list(states[0]['state'][1].size())[1]
108 | map, writes = [], []
109 | phases = [c['phase'] for c in states]
110 |
111 | for idx, state in enumerate(states):
112 |
113 | # usage Vectors to calculate write positions
114 | current_usage = state['state'][6][0].data
115 | prev_usage = states[idx-1]['state'][6][0].data if idx > 0 else current_usage
116 |
117 | # max diff in usages =>
118 | diffs = current_usage - prev_usage
119 | diff_idx, w = max(enumerate(diffs), key=lambda iw: iw[1])
120 |
121 | # decode max position
122 | map.append(np.abs(diffs.numpy()))
123 | write_vec = state['state'][1][0][diff_idx].data.numpy()
124 | print(write_vec)
125 |
126 | counts = dict((x, phases.count(x)) for x in set(phases))
127 | pos = [0] + list(accumulate(counts.values()))[:-1]
128 | ax = plt.gca()
129 |
130 | # fig = matplotlib.figure(figsize=(3, 3))
131 | rotated = np.rot90(np.asarray(map))
132 |
133 | ax.xaxis.set_ticks(pos)
134 | ax.xaxis.set_ticklabels(list(range(4)))
135 |
136 | plot = plt.imshow(rotated, cmap='hot', interpolation='nearest',
137 | extent=[0, len(states), 0, mem_size])
138 |
139 | plt.colorbar()
140 | plt.show()
141 |
142 | pass
143 |
144 |
145 |
146 |
147 |
--------------------------------------------------------------------------------
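Note on `make_usage_viz` in visualize/wuddido.py: it estimates where each step wrote to memory by taking the slot whose usage grew the most since the previous step. It then plots the absolute usage change per step as a heat map. Calling `max(enumerate(diffs))` compares `(index, value)` tuples, so it returns the last index rather than the largest increase; the argmax needs a key on the value. The NumPy sketch below shows both computations. It assumes the usage vectors are stacked into a `(T, N)` array (T steps, N memory slots). `write_locations`, `usage_heatmap`, and the toy numbers are hypothetical.

```python
import numpy as np


def write_locations(usages: np.ndarray) -> np.ndarray:
    """For usage vectors of shape (T, N), return for each step t >= 1 the slot
    whose usage increased the most since step t - 1 (an estimate of the write location)."""
    diffs = np.diff(usages, axis=0)   # shape (T-1, N): u_t - u_{t-1}
    return diffs.argmax(axis=1)       # index of the largest increase per step


def usage_heatmap(usages: np.ndarray) -> np.ndarray:
    """Absolute usage change per step, with a zero row for the first step,
    analogous to the np.abs(diffs) map that make_usage_viz plots."""
    diffs = np.diff(usages, axis=0, prepend=usages[:1])
    return np.abs(diffs)
```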