├── .floydexpt ├── .floydignore ├── .gitignore ├── .idea ├── fun-with-dnc.iml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── .vscode └── settings.json ├── README.md ├── __pycache__ ├── dnc.cpython-35.pyc └── generators.cpython-35.pyc ├── arg.py ├── dnc_arity_list.py ├── floydrequirements.txt ├── images ├── acc.png └── training.png ├── losses.py ├── misc ├── 78fa9003f6c0f735bc3250fc2116f6100463.pdf ├── dnc_sizes.ods └── train.csv ├── old ├── dnc.py ├── dnc_clean.py ├── dnc_model.png ├── dnc_stateful.py ├── dnc_v1_safe.py ├── generators.py ├── notes.txt └── train.py ├── problem ├── __init__.py ├── copy_squence.py ├── generators_v2.py ├── logic.py ├── lp_utils.py ├── my_air_cargo_problems.py ├── my_planning_graph.py ├── planning.py ├── search.py └── utils.py ├── run.py ├── setup.sh ├── tests.py ├── tf.conf ├── training.py ├── utils.py └── visualize ├── logger.py └── wuddido.py /.floydexpt: -------------------------------------------------------------------------------- 1 | {"family_id": "B7wQJjD7KMoyM7nAkQXfJj", "name": "dnc"} -------------------------------------------------------------------------------- /.floydignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .git 3 | .idea 4 | .gitignore 5 | __pycache__ 6 | runs 7 | models 8 | old 9 | misc 10 | images -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__ 3 | runs 4 | 5 | models 6 | -------------------------------------------------------------------------------- /.idea/fun-with-dnc.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.linting.pylintEnabled": true 3 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fun-With-Dnc (Differentiable Neural Computing) 2 | 3 | Pytorch implementation of deepmind paper [Hybrid computing using a neural network with dynamic external memory]: https://pdfs.semanticscholar.org/7635/78fa9003f6c0f735bc3250fc2116f6100463.pdf. The code is based on the tensorflow implementation [here]: https://github.com/deepmind/dnc. 4 | 5 | Todo finish retraining, and better writeup. 6 | 7 | ## Problems and Expirements 8 | There are a few tasks setup. One is the "Air Cargo Prolbem" from Arificial Intelligence (Russell & Norvig). The origional code for the problem is based on the [Udacity Implementation]: https://github.com/udacity/AIND-Planning , and the full description is in the problem repo. 9 | 10 | The Air Cargo problem can be seen as structured prediction, (every prediction step can be seen as changing the state of the problem). The algorithms used to solve it in the book included graphplan, and Astar search of the state space, putting it in the same family of problems as the Blocks Problem (SHRDLU) solved in the origional paper. 11 | 12 | ## Installation 13 | Requires pytorch (no cuda) 14 | 15 | conda install pytorch torchvision cuda80 -c soumith 16 | 17 | ## Tensorboard 18 | Run tensorboard with "--log 10" flag (the number is logging frequency). Below is a shot during training. Losses are recorded seperately for each entity and type, as well as the action for more addictive monitoring. 19 | 20 | ![alt text](images/training.png) 21 | ![alt text](images/acc.png) 22 | 23 | To run with tensorboard: 24 | pip install tensorboardX (for tensorboard) 25 | pip install tensorflow (for tensorboard web server) 26 | 27 | ## Training Scenarios 28 | ### Planning 29 | The code implements a training schedule as in the paper. Start small with the minimum sized problem (2 entities of each kind) 30 | 31 | python run.py --act plan --iters 1000 --ret_graph 1 --zero_at step --n_phases 20 --opt_at step 32 | python run.py --act plan --iters 1000 --ret_graph 1 --zero_at step --n_phases 20 --opt_at step --save opt_zero_step 33 | python run.py --act plan --iters 1000 --ret_graph 0 --opt_at problem --save opt_problem_plan --n_phases 20 34 | 35 | 36 | We humans would think about the problem in terms of actions and type, so I thought the first thing the DNC would start getting correct would be the (Action, typeofthing1, typeofthing2, typeofthing3) 'tuple', since those must be correct in order to reliably get the instance correct. This was indeed the case as can be seen on the 'accuracies' plots during training. By the scemantics of the problem, the last 'type' is always Airplane, so that goes to 100% accuracy immediately. The next chunk of 1/3td of the training bumps up the types to 0.9-1.0 range. Only then does the loss for the entities themselves start dropping consistently. Even then, the ent1 and ent3 were coupled, which in the logic of the problem... 37 | 38 | To show details at each step of what was predicted vs best moves, specify the --detail flag. You will get something like this: 39 | 40 | trial 978, step 19514 trial accy: 6/7, 0.86, pass total 296/978, running avg 0.7463, loss 0.0774 41 | best Load ['C1', 'P1', 'A1'], Fly ['P1', 'A1', 'A0'] 42 | chosen: Load ['C1', 'P1', 'A1'], guided True, prob 0.25, T? True ---loss 0.2553 43 | best Fly ['P1', 'A1', 'A0'], Unload ['C1', 'P1', 'A0'] 44 | chosen: Fly ['P1', 'A1', 'A0'], guided True, prob 0.33, T? True ---loss 0.0784 45 | best Unload ['C1', 'P1', 'A0'], Fly ['P0', 'A0', 'A1'] 46 | chosen: Unload ['C1', 'P1', 'A0'], guided True, prob 0.33, T? True ---loss 0.0830 47 | best Fly ['P0', 'A0', 'A1'], Load ['C0', 'P0', 'A1'] 48 | chosen: Fly ['C0', 'P0', 'A1'], guided True, prob 0.25, T? False ---loss 0.3716 49 | best Load ['C0', 'P0', 'A1'], Fly ['P0', 'A1', 'A0'] 50 | chosen: Load ['C0', 'P0', 'A1'], guided True, prob 0.25, T? True ---loss 0.1288 51 | best Fly ['P0', 'A1', 'A0'], Unload ['C0', 'P0', 'A0'] 52 | chosen: Fly ['P0', 'A1', 'A0'], guided False, prob 0.25, T? True ---loss 1.1554 53 | best Fly ['P0', 'A1', 'A0'], Unload ['C0', 'P0', 'A0'] 54 | chosen: Unload ['C0', 'P1', 'A0'], guided True, prob 0.33, T? False ---loss 0.9901 55 | best Unload ['C0', 'P0', 'A0'] 56 | chosen: Unload ['C0', 'P0', 'A0'], guided False, prob 0.33, T? True ---loss 0.8087 57 | best Fly ['P0', 'A1', 'A0'], Unload ['C0', 'P0', 'A0'] 58 | chosen: Unload ['C0', 'P0', 'A0'], guided True, prob 0.33, T? True ---loss 0.7677 59 | 60 | The best actions are what was deterimined by the problem heurstics (not always optimal to save time). The chosen action is what the DNC ended up chosing. 'Guided' refers to Beta from the paper. 'Prob' is the chance of chosing that action (out of all legal actions), and the loss is there as well. 61 | 62 | 63 | ### Question Answering 64 | Another task that would be interesting I figured would be to give the DNC a problem (initial state, and goal), then make some moves, and ask where a certain Cargo is (which airport is it in? is it in a plane? which plane?). This did not work too well. See the run.py train_qa function. 65 | 66 | python run.py --act qa --iters 1000 --n_phases 20 67 | 68 | ## Training Misc 69 | ### Other Problems 70 | In an initial pass I tested with the sequence Memorization task from the deepmind repo. I have not tested it recently and I doubt it works (see todo). To run this specify the wit the problem 71 | 72 | ### Other Setups 73 | The DNC was tested against vanilla Lstms. The Lstm appears to get stuck on air cargo problem at ~40%. To run the training with LSTM only specify with '--algo LSTM' flag like so: 74 | 75 | python run.py --act plan --algo lstm --iters 1000 --n_phases 20 76 | 77 | ### Misc 78 | Training at each 'level' took 20K steps. This is way more than reported in the paper. On my crappy home CPU, this meant about a day, aka forever. Since I also lost my computer, causing me to need to retrain everything, I only got through the first level of training before having to submit (2 airports, 2 cargos, 2 planes). 79 | 80 | 81 | ## Differences from Original 82 | There was some expirementation here, so there are a bunch of flags on when to optimize. In the paper they calculated loss at end of each problem. This did not work for me, so I ended up with running the optimzer after each response. 83 | 84 | ## Loading Previous Run 85 | 86 | python run.py --act plan --iters 1000 --n_phases 20 --load the_saved_name_or_path --save the_new 87 | 88 | ## Flags 89 | 90 | 91 | ## Running on floydhub 92 | Set the --env flag to floyd. When it gets up there, the script will create all the directories in /output. Tensorboard for pytorch does not appear to work on there for reasons I do not understand. 93 | 94 | floyd run --env pytorch-0.2 --tensorboard "bash setup.sh && python run.py --act dag --iters 1000 --env floyd" 95 | 96 | ## Todo 97 | Upload best models 98 | Test the sequence memorization task. probably does not work. 99 | 100 | 101 | ~~Implement with GPU.~~ 102 | 103 | 104 | ~~Faster problem generator~~ 105 | 106 | 107 | ~~fix tensorboard issues~~ 108 | 109 | 110 | ~~gradient clipping~~ 111 | 112 | 113 | visualization of what dnc is doing internally (per paper) 114 | 115 | 116 | penalty for bad actions when not using the beta coefficient for forcing 117 | 118 | 119 | losses by prediction (fast loss) 120 | 121 | 122 | run whole lstm on input and goal state? 123 | 124 | 125 | Document args in argparse 126 | 127 | 128 | Testing on moar problems. 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /__pycache__/dnc.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/__pycache__/dnc.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/generators.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/__pycache__/generators.cpython-35.pyc -------------------------------------------------------------------------------- /arg.py: -------------------------------------------------------------------------------- 1 | import argparse, os, time, json 2 | 3 | parser = argparse.ArgumentParser(description='Hyperparams') 4 | ########### STANDARD ENV SPEC ############## 5 | parser.add_argument('-a', '--act', nargs='?', type=str, default='train', help='[]') 6 | parser.add_argument('-l', '--load', nargs='?', type=str, default='', help='load model and state') 7 | parser.add_argument('--start_epoch', nargs='?', type=int, default=0, help='load model and state checkpt start') 8 | parser.add_argument('-s', '--save', nargs='?', type=str, default='', help='save if true') 9 | parser.add_argument('-i', '--iters', nargs='?', type=int, default=21, help='number of iterations') 10 | parser.add_argument('--env', type=str, default='', help='') 11 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate 1e-5 in paper') 12 | parser.add_argument('--checkpoint_every', type=int, default=5, help='save models every n epochs') 13 | parser.add_argument('--log', type=int, default=0, help='send to tensorboard frequency. 0 to disable') 14 | parser.add_argument('--notes', type=str, default='', help='any notes') 15 | parser.add_argument('-d', '--show_details', type=int, default=20, help='any notes') 16 | 17 | ########### Algorithm and Optimizer ############## 18 | parser.add_argument('--opt', type=str, default='adam', help='optimizer to use. options are sgd and adam') 19 | parser.add_argument('--algo', type=str, default='dnc', help='dnc or lstm. default is dnc') 20 | parser.add_argument('-c', '--cuda', type=int, default=0, help='device to run - 1 if cuda, 0 if cpu') 21 | parser.add_argument('--clip', type=float, default=0, help='gradient clipping, if zero, no clip') 22 | 23 | ########### CONTROL FLOW ############## 24 | parser.add_argument('--feed_last', type=int, default=1, help='') 25 | parser.add_argument('--opt_at', type=str, default='step', help='') 26 | parser.add_argument('--zero_at', type=str, default='step', help='') 27 | parser.add_argument('--ret_graph', type=int, default=1, help='retain graph todo change to ') 28 | parser.add_argument('--rpkg_step', type=int, default=1, help='repackage input vars at each step') 29 | 30 | ########### PROBLEM SETUP ############## 31 | parser.add_argument('-p', '--n_phases', type=int, default=15, help='number of training phases') 32 | parser.add_argument('--n_cargo', type=int, default=2, help='number of cargo at starts') 33 | parser.add_argument('--n_plane', type=int, default=2, help='number of plane at starts') 34 | parser.add_argument('--n_airport', type=int, default=2, help='number of airports at starts') 35 | parser.add_argument('-n', '--n_init_start', type=int, default=2, help='number of entities to start with') 36 | parser.add_argument('--typed', type=int, default=1, help='1=use typed entity descriptions, 0=one hot each entity') 37 | 38 | ########### PROBLEM CONTROL ############## 39 | parser.add_argument('--passing', type=float, default=0.9, help='passing percentage for a run default 0.9') 40 | parser.add_argument('--num_tests', type=int, default=2, help='') 41 | parser.add_argument('--num_repeats', type=int, default=2, help='') 42 | parser.add_argument('--max_ents', type=int, default=6, help='maximum number of entities') 43 | parser.add_argument('--beta', type=float, default=0.8, help='mixture param from paper') 44 | parser.add_argument('--penalty', type=float, default=0.0, help='mixture param from paper') 45 | args = parser.parse_args() 46 | print('\n\n') 47 | 48 | args.repakge_each_step = True if args.rpkg_step == 1 else False 49 | args.ret_graph = True if args.ret_graph == 1 else False 50 | args.cuda = True if args.cuda == 1 else False 51 | args.clip = None if args.clip == 0 else args.clip 52 | args.prefix = '/output/' if args.env == 'floyd' else './' 53 | args.penalty = None if args.penalty == 0.0 else args.penalty 54 | 55 | if not os.path.exists(args.prefix + 'models'): 56 | os.mkdir(args.prefix + 'models') 57 | 58 | start_timer = time.time() 59 | start = '{:0.0f}'.format(start_timer) 60 | 61 | if args.save != '': 62 | args.base_dir = '{}{}{}_{}/'.format(args.prefix, 'models/', start, args.save) 63 | os.mkdir(args.base_dir) 64 | os.mkdir(args.base_dir + 'checkpts/') 65 | argparse_dict = vars(args) 66 | with open(args.base_dir + 'params.txt', 'w') as outfile: 67 | json.dump(argparse_dict, outfile) 68 | print('Saving in folder {}'.format(args.base_dir)) 69 | 70 | writer = None 71 | if args.log > 0: 72 | from tensorboardX import SummaryWriter 73 | global writer 74 | writer = SummaryWriter() 75 | from visualize import logger 76 | logger.log_step += args.log 77 | 78 | # test +new no log # 79 | # python run.py -a plan -n 3 -p 2 -d 5 --iters 10 --opt_at step --zero_at step 80 | 81 | # test +load # 82 | # python run.py -a plan -n 3 -p 2 -d 5 --load --iters 10 --opt_at step --zero_at step 83 | 84 | # test +new +log 85 | # python run.py -a plan -n 3 -p 2 -d 5 --iters 10 --log 2 --opt_at step --zero_at step 86 | 87 | # test +new +cuda # 88 | # python run.py -a plan -n 3 -p 2 -d 5 --iters 10 --cuda 1 --opt_at step --zero_at step 89 | 90 | # test +new +cuda +clip # 91 | # python run.py -a plan -n 3 -p 2 -d 5 --clip 20 --iters 10 --cuda 1 --opt_at step --zero_at step 92 | 93 | # test +new +load +log # 94 | # python run.py -a plan -n 2 --n_phases 2 --show_details 5 --clip 40 --log 2 --iters 10 --cuda 0 --opt_at step --zero_at step 95 | 96 | 97 | 98 | 99 | # floyd run --env pytorch-0.2 --tensorboard "bash setup.sh && python run.py --act run --opt_at problem --ret_graph 0 --env floyd --save _nopkg --n_phases 2 --iters 10000" 100 | # python run.py --act run --opt_at problem --ret_graph 0 --save _nopkg --n_phases 2 --iters 10000 101 | 102 | # QA training 103 | # floyd run --env pytorch-0.2 --tensorboard "bash setup.sh && python run.py --act dag --iters 1000 --save _new_hidden --ret_graph 1 --opt_at step --env floyd" 104 | # floyd run --cpu --env pytorch-0.2 --tensorboard 'bash setup.sh && python run.py --act dag --iters 1000 --save _new_hidden --ret_graph 1 --opt_at step --zero_at step --env floyd' 105 | -------------------------------------------------------------------------------- /floydrequirements.txt: -------------------------------------------------------------------------------- 1 | 2 | tensorflow 3 | tensorboardX -------------------------------------------------------------------------------- /images/acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/images/acc.png -------------------------------------------------------------------------------- /images/training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/images/training.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from visualize import logger as sl 3 | from utils import flat, repackage, _variable 4 | 5 | 6 | def action_loss(logits, action, criterion, log=None): 7 | """ 8 | Sum of losses of one hot vectors encoding an action 9 | :param logits: network output vector of [action, [[type_i, ent_i], for i in ents]] 10 | :param action: target vector size [7] 11 | :param criterion: loss function 12 | :return: 13 | """ 14 | losses = [] 15 | for idx, action_part in enumerate(flat(action)): 16 | tgt = _variable(torch.LongTensor([action_part])) 17 | losses.append(criterion(logits[idx], tgt)) 18 | loss = torch.stack(losses, 0).mean() 19 | if log is not None: 20 | sl.log_loss(losses, loss) 21 | return loss 22 | 23 | 24 | def get_top_prediction(expanded_logits, idxs=None): 25 | max_idxs = [] 26 | idxs = range(len(expanded_logits)) if idxs is None else idxs 27 | for idx in idxs: 28 | _, pidx = expanded_logits[idx].data.topk(1) 29 | max_idxs.append(pidx.squeeze()[0]) 30 | return tuple(max_idxs) 31 | 32 | 33 | def combined_ent_loss(logits, action, criterion, log=None): 34 | """ 35 | some hand tunining of penalties for illegal actions... 36 | trying to force learning of types. 37 | 38 | action type => type_e... 39 | :param logits: network output vector of one_hot distributions 40 | [action, [type_i, ent_i], for i in ents] 41 | :param action: target vector size [7] 42 | :param criterion: loss function 43 | :return: 44 | """ 45 | losses = [] 46 | for idx, action_part in enumerate(flat(action)): 47 | tgt = _variable(torch.Tensor([action_part]).float()) 48 | losses.append(criterion(logits[idx], tgt)) 49 | lfs = [[losses[0]]] 50 | n = 2 51 | for l in(losses[i:i+n] for i in range(1, len(losses), n)): 52 | lfs.append(torch.stack(losses, 0).sum()) 53 | loss = torch.stack(lfs, 0).mean() 54 | if log is not None: 55 | sl.log_loss(losses, loss) 56 | return loss 57 | 58 | 59 | def naive_loss(logits, targets, criterion, log=None): 60 | """ 61 | Calculate best choice from among targets, and return loss 62 | 63 | :param logits: 64 | :param targets: 65 | :param criterion: 66 | :return: loss 67 | """ 68 | # copy_logits = depackage(logits) 69 | # final_action = closest_action(copy_logits, targets) 70 | loss_idx, _ = min(enumerate([action_loss(repackage(logits), a, criterion) for a in targets])) 71 | final_action = targets[loss_idx] 72 | return final_action, action_loss(logits, final_action, criterion, log=log) 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /misc/78fa9003f6c0f735bc3250fc2116f6100463.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/misc/78fa9003f6c0f735bc3250fc2116f6100463.pdf -------------------------------------------------------------------------------- /misc/dnc_sizes.ods: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/misc/dnc_sizes.ods -------------------------------------------------------------------------------- /old/dnc_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/old/dnc_model.png -------------------------------------------------------------------------------- /old/generators.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | from timeit import default_timer as timer 6 | import torch 7 | parent = os.path.dirname(os.path.realpath(__file__)) 8 | sys.path.insert(0, os.path.dirname(parent)) 9 | import aimacode 10 | import my_air_cargo_problems as mac 11 | from aimacode.search import * 12 | from aimacode.utils import Expr 13 | 14 | props = {'At':0, 'In':1} 15 | actions = {'Fly':0, 'Load':1, 'Unload':2} 16 | actions_1h = {0:'Fly', 1:'Load', 2:'Unload'} 17 | 18 | exprs = ['Fly', 'Load', 'Unload', 'At', 'In'] 19 | 20 | phases = ['State', 'Goal', 'Plan', 'Solve'] 21 | 22 | phase_to_ix = {word: i for i, word in enumerate(phases)} 23 | ix_to_phase = {i: word for i, word in enumerate(phases)} 24 | exprs_to_ix = {exxp: i for i, exxp in enumerate(exprs)} 25 | ix_to_exprs = {i: exxp for i, exxp in enumerate(exprs)} 26 | 27 | def encoding_(): 28 | pass 29 | 30 | def swap_fly(fly_action): 31 | """ 32 | 33 | :param fly_action: Fly(P1, A1, A0) 34 | :return: Fly(P1 _ A0) 35 | """ 36 | 37 | pass 38 | 39 | class Encoded_Expr(): 40 | def __init__(self, op, args): 41 | self.op = str(op) 42 | self.args = args 43 | self.one_hot = [] 44 | 45 | def vec_to_expr(self): 46 | pass 47 | 48 | 49 | 50 | 51 | class EncodedAirCargoProblem(mac.AirCargoProblem): 52 | def __init__(self, problem, init_vec, goals_vec, mp_merged, one_hot_size): 53 | self.problem = problem 54 | self.succs = self.goal_tests = self.states = 0 55 | self.found = None 56 | self.init_state = init_vec 57 | 58 | self.goal_state = goals_vec 59 | # dictionary of mappings of ents to one_hot 60 | self.problem_ents = mp_merged 61 | # dictionary pf mappings of one_hot to ents 62 | self.ent_to_vec = self.flip_encoding(mp_merged) 63 | print(self.problem_ents) 64 | print(self.ent_to_vec) 65 | self.one_hot_size = one_hot_size 66 | self.solution_node = None 67 | self.entity_o_h = torch.eye(one_hot_size).long() 68 | self.action_o_h = torch.eye(3).long() 69 | self.types_o_h = torch.eye(3).long() 70 | self.phases_o_h = torch.LongTensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]) 71 | 72 | def flip_encoding(self, ents): 73 | encoding = {} 74 | for key, value in ents.items(): 75 | encoding[str(list(value))] = key 76 | return encoding 77 | 78 | def reverse_lookup(self, one_hot): 79 | return next(key for key, value in self.problem_ents.items() if value == one_hot) 80 | 81 | def action_expr_to_vec(self, action_expr): 82 | """ 83 | 84 | :param action_expr: Action Expr object Fly(P0, A1, A2) 85 | :return: action vec [0, 86 | """ 87 | action_vec = [actions[action_expr.name]] 88 | for arg in action_expr.args: 89 | e_type, ent = self.problem_ents[str(arg)] 90 | action_vec.append(e_type) 91 | action_vec.append(ent) 92 | print("action_vec", action_vec) 93 | return np.asarray(action_vec, dtype=int) 94 | 95 | def get_best_action_vecs(self, solution_node): 96 | """ 97 | 98 | :param solution_node: Graph Node reporesenting solution 99 | :return: vectors for each expr in solution 100 | """ 101 | action_vecs = [] 102 | for action_expr in solution_node.solution()[0:len(self.problem.planes)]: 103 | action_vec = self.action_expr_to_vec(action_expr) 104 | action_vecs.append(torch.from_numpy(action_vec)) 105 | return action_vecs 106 | 107 | def get_all_actions(self, state): 108 | zz = [torch.from_numpy(self.action_expr_to_vec(a)) for a in self.problem.actions(state)] 109 | return zz 110 | 111 | def encode_solution(self, solution): 112 | zz = [torch.from_numpy(self.action_expr_to_vec(a)) for a in solution] 113 | return zz 114 | 115 | def decode_action(self, coded_action): 116 | sym = actions_1h[coded_action[0]] 117 | ent1 = str(list(coded_action[1:3])) 118 | ent2 = str(list(coded_action[3:5])) 119 | ent3 = str(list(coded_action[5:7])) 120 | ex1 = self.ent_to_vec[ent1] 121 | ex2 = self.ent_to_vec[ent2] 122 | ex3 = self.ent_to_vec[ent3] 123 | return sym, [ex1, ex2, ex3] 124 | 125 | def send_action(self, state, coded_action): 126 | sym, args = self.decode_action(coded_action) 127 | actions_ = self.actions(state) 128 | print(actions_) 129 | final_act = [] 130 | for a in actions_: 131 | if a.name == sym and all((str(ar) == at) for ar, at in zip(a.args, args)): 132 | final_act = a 133 | break 134 | assert final_act != [] 135 | print(final_act) 136 | result_state = self.problem.result(state, final_act) 137 | return result_state 138 | 139 | def actions(self, state): 140 | self.succs += 1 141 | return self.problem.actions(state) 142 | 143 | def result(self, state, action): 144 | self.states += 1 145 | return self.problem.result(state, action) 146 | 147 | def goal_test(self, state): 148 | self.goal_tests += 1 149 | result = self.problem.goal_test(state) 150 | if result: 151 | self.found = state 152 | return result 153 | 154 | def run_search(self, search_fn, parameter=None): 155 | if parameter is not None: 156 | prm = getattr(self.problem, parameter) 157 | node = search_fn(self.problem, prm) 158 | else: 159 | node = search_fn(self.problem) 160 | return node 161 | 162 | def path_cost(self, c, state1, action, state2): 163 | return self.problem.path_cost(c, state1, action, state2) 164 | 165 | def value(self, state): 166 | return self.problem.value(state) 167 | 168 | def __getattr__(self, attr): 169 | return getattr(self.problem, attr) 170 | 171 | 172 | class AirCargoData(): 173 | """ 174 | Flags for 175 | """ 176 | def __init__(self, 177 | num_plane=10, num_cargo=6, batch_size=6, 178 | num_airport=1000, one_hot_size=10, mode='loose', 179 | search_function=astar_search): 180 | self.num_plane = num_plane 181 | self.num_cargo = num_cargo 182 | self.num_airport = num_airport 183 | self.batch_size = batch_size 184 | self._actions_mode = mode 185 | self.one_hot_size = one_hot_size 186 | self.search_fn = search_function 187 | self.STATE = '' 188 | self.search_param = 'h_ignore_preconditions' 189 | self.problem = None 190 | self.current_problem = None 191 | self.current_index = 0 192 | #self.make_new_problem() 193 | 194 | def print_solution(self, node): 195 | for action in node.solution(): 196 | print("{}{}".format(action.name, action.args)) 197 | 198 | def phase_vec(self, tensor_, add_channel): 199 | chans = torch.stack([torch.LongTensor(add_channel) for _ in range(tensor_.size(0))], dim=0) 200 | return torch.cat([chans, tensor_.long()], dim=-1) 201 | 202 | def get_actions(self, mode='strict'): 203 | self.problem.problem.initial = self.STATE 204 | if mode == 'all': 205 | return self.problem.get_all_actions(self.STATE) 206 | else: 207 | solution = self.problem.run_search(self.search_fn, self.search_param) 208 | return self.problem.get_best_action_vecs(solution) 209 | 210 | def encode_action(self, action_obj): 211 | return torch.from_numpy(self.problem.encode_action(action_obj)).long() 212 | 213 | def send_action(self, coded_action): 214 | self.problem.problem.initial = self.STATE 215 | self.STATE = self.problem.send_action(self.STATE, coded_action) 216 | return True 217 | 218 | def vec_to_one_hot(self, coded_ent): 219 | ents = [] 220 | for idx in range(1, 7, 2): 221 | e_type = coded_ent[idx] 222 | ent = coded_ent[idx + 1] 223 | if e_type == 0 and ent == 0: 224 | ents.append(torch.zeros(3 + self.one_hot_size).long()) 225 | else: 226 | ents.append(torch.cat([self.problem.types_o_h[e_type], self.problem.entity_o_h[ent]], 0)) 227 | return torch.cat(ents, 0) 228 | 229 | def expand_state_vec(self, coded_state): 230 | """Input target vec representing cross entropy loss target [1 0 2 0 0 0 0] 231 | Returns a one hot version of it as training input [01 00, 100, 000, 000, 000]""" 232 | ents = self.vec_to_one_hot(coded_state) 233 | phase = self.problem.phases_o_h[coded_state[0]] 234 | return torch.cat([phase, ents], 0) 235 | 236 | def expand_action_vec(self, coded_action): 237 | """Input target vec representing action 238 | [1 0 2 0 0 0 0] 239 | Returns a one hot version of it as training input 240 | [01 00, 100, 000, 000, 000] 241 | """ 242 | ents = self.vec_to_one_hot(coded_action) 243 | action = self.problem.action_o_h[coded_action[0]] 244 | return torch.cat([action, ents], 0).unsqueeze(0).float() 245 | 246 | def make_new_problem(self): 247 | acp, i, g, m = mac.arbitrary_ACP(self.num_airport, self.num_plane, 248 | self.num_cargo, one_hot_size=self.one_hot_size) 249 | problem = EncodedAirCargoProblem(acp, i, g, m, self.one_hot_size) 250 | 251 | # print(problem) 252 | # run the solution to determine how long to give dnc 253 | solution_node = problem.run_search(self.search_fn, self.search_param) 254 | 255 | word_len = len(problem.init_state[0]) 256 | len_init_phase = len(problem.init_state) 257 | len_goal_phase = len(problem.goal_state) 258 | len_plan_phase = 10 259 | len_resp_phase = len(solution_node.solution()) + 6 260 | 261 | # determine the number of iterations input will happen for 262 | mask_zero = torch.zeros(len_init_phase + len_goal_phase + len_plan_phase) 263 | mask_ones = torch.ones(len_resp_phase) 264 | masks = torch.cat([mask_zero, mask_ones], 0) 265 | 266 | init_phs_data = self.phase_vec(torch.from_numpy(problem.init_state), [0]) 267 | goal_phs_data = self.phase_vec(torch.from_numpy(problem.goal_state), [1]) 268 | # during planning and response phases, there is no inputs. 269 | plan_phs_data = self.phase_vec(torch.zeros(len_plan_phase, word_len), [2]) 270 | resp_phs_data = self.phase_vec(torch.zeros(len_resp_phase, word_len), [3]) 271 | 272 | inputs = torch.cat([init_phs_data, goal_phs_data, plan_phs_data, resp_phs_data], 0) 273 | self.current_problem = [inputs, masks] 274 | self.problem = problem 275 | self.STATE = problem.initial 276 | self.current_index = 0 277 | return problem.problem, solution_node, masks 278 | 279 | def len__(self): 280 | if self.current_index >= self.current_problem[0].size(0): 281 | self.make_new_problem() 282 | return len(self.current_problem[1]) 283 | 284 | def getitem(self, batch=1): 285 | """Returns a problem, [initial-state, goals] 286 | and a runnable solution object [problem, solution_node] 287 | 288 | Otherwise take the target one_hot class mask in form of 289 | [ent1-type, ent1 ....entN, channel] 290 | 291 | """ 292 | if self.current_index >= self.current_problem[0].size(0): 293 | self.make_new_problem() 294 | 295 | masks = self.current_problem[1][self.current_index:self.current_index+batch] 296 | inputs = self.current_problem[0][self.current_index:self.current_index + batch] 297 | inputs = torch.stack([self.expand_state_vec(i).float() for i in inputs], 0) 298 | self.current_index += batch 299 | 300 | return inputs, masks 301 | 302 | 303 | class RandomData(Dataset): 304 | def __init__(self, 305 | num_seq=10, 306 | seq_len=6, 307 | iters=1000, 308 | seq_width=4): 309 | self.seq_width = seq_width 310 | self.num_seq = num_seq 311 | self.seq_len = seq_len 312 | self.iters = iters 313 | 314 | def __getitem__(self, index): 315 | con = np.random.randint(0, self.seq_width, size=self.seq_len) 316 | seq = np.zeros((self.seq_len, self.seq_width)) 317 | seq[np.arange(self.seq_len), con] = 1 318 | end = torch.from_numpy(np.asarray([[-1] * self.seq_width])).float() 319 | zer = np.zeros((self.seq_len, self.seq_width)) 320 | return seq, zer 321 | 322 | def __len__(self): 323 | return self.iters 324 | 325 | class GraphData(Dataset): 326 | def __init__(self, 327 | num_seq=10, 328 | seq_len=6, 329 | iters=1000, 330 | domain=None, 331 | actions=None, 332 | start_state=None, 333 | seq_width=4): 334 | """ 335 | Each vector encoded a triple consisting of a source label, 336 | an edge label and a destination label. 337 | All labels were represented as numbers between 0 and 999, 338 | with each digit represented as a 10-way one-hot encoding. 339 | We reserved a special ‘blank’ label, represented by the all-zero vector 340 | for the three digits, to indicate an unspecified label. 341 | Each label required 30 input elements, and each triple required 90. 342 | The sequences were divided into multiple phases: 343 | 1) first a graph description phase, 344 | then a series of query (2 Q) and answer (3 A) phases; 345 | in some cases the Q and A were separated by an additional planning phase 346 | with no input, during which the network was given time to compute the answer. 347 | During the graph description phase, the triples defining the input graph were 348 | presented in random order. 349 | Target vectors were present during only the answer phases. 350 | 351 | Params from Paper: 352 | GRAPH 353 | input vectors were size 92 354 | 90 info 355 | binary chan for phase transition 356 | binary chan for when prediction is needed 357 | target vectors 90 358 | BABI 359 | input vector of 159 one-hot-vector of words (156 unique words + 3 tokens) 360 | 361 | Propositions: 362 | :: one-hot 0-9 363 | :: [At ( C1 , SFO )] 364 | :: [{0..1} ( {0..1} , {0..1} )] + [True/False, end-exp, pred-required?] 365 | Load({}, {}, {}) 366 | special tokens [ '(', ')', ',', ] 367 | 368 | Actions vs Propositions 369 | Action => precond_pos , precond_neg , effects_pos , effects_ned] 370 | Eat(Cake) => [[Have(Cake) ], [], , [Eaten(Cake)], [Not Have(Cake)]] 371 | +---------------------------------------------------------------------------+ 372 | | 1 9 3 9 8 4 9 2 9 3 5 9 2 9 3 7 4 9 2 9 | 373 | | 1 | 374 | +---------------------------------------------------------------------------+ 375 | 376 | input: T? Op ( Pred ) 377 | +----------------------------+ 378 | start-seq | 1| 379 | Eat(Cake) | 00 0001 1001 0011 1001 0 0| Action Name | h * h * args 380 | Have(Cake) | 11 0100 1001 0011 1001 1 0| } Pre-Conditions 381 | . | 0 0| 382 | . | 0 0| 383 | Eaten(Cake) | 11 0101 1001 0011 1001 1 0| } Post-Conditions 384 | ¬ Have(Cake) | 01 0100 1001 0011 1001 0 1| 385 | +----------------------------+ 386 | 387 | Final input is concat of all statements 388 | 389 | Goal -> At(C1, Place ) 390 | 391 | Paper (first block, adjacency relation, second block) 392 | (100000, 1000, 010000) 393 | “block 1 above block 2” 394 | 395 | let the goals be 1 of 26 possible letters designated by one-hot encodings; 396 | that is, A =​ (1, 0, ..., 0), Z =​ (0, 0, ..., 1) and so on 397 | 398 | -The board is represented as a set of place-coded representations, one for each square. 399 | Therefore, (000000, 100000, ...) designates that the bottom, 400 | left-hand square is empty, block 1 is in the bottom centre square, and so on 401 | 402 | The network also sees a binary flag that represents a ‘go cue’. 403 | While the go cue is active, a goal is selected from the list of goals that have 404 | been shown to the network, its label is retransmitted to the network for one 405 | time-step, and the network can begin to move the blocks on the board. 406 | 407 | All told, the policy observes at each time-step a vector with features 408 | 409 | Constraints 16 410 | (goal name, first block, adjacency relation, second block, go cue, board state). 411 | [26 ... (6 4 6 )x6 1 63] -> 186 ~state 412 | 413 | there are 7 possible actions so output is mapped to size 7 one_hot vector. 414 | 415 | 10 goals -> 250, 416 | 417 | Up to 10 goals with 6 constraints each can be sent to the network before action begins. 418 | 419 | Once the go cue arrives, it is possible for the policy network to move a block 420 | from one column to another or to pass at each turn. 421 | We parameterize these actions using another one-hot encoding so that, 422 | for a 3 ×​ 3 board, a move can be made from any column to any other; 423 | with the pass move, there are therefore 7 moves. 424 | 425 | 426 | 8 Airports , 4 Cargos, 4 Airplanes 427 | 0001 , 0000 , 0100 428 | 429 | Network - [2 x 250] 430 | 431 | 432 | 433 | Input at t: prev 434 | 435 | 436 | Types of tasks -> 437 | 1) given True Statements, we may want to generate negative statements 438 | 2) 439 | 440 | """ 441 | self.seq_width = seq_width 442 | self.num_seq = num_seq 443 | self.seq_len = seq_len 444 | self.iters = iters 445 | 446 | def __getitem__(self, index): 447 | #con = np.random.randint(0, self.seq_width, size=self.seq_len) 448 | 449 | 450 | 451 | return None #seq, zer 452 | 453 | def __len__(self): 454 | return self.iters 455 | 456 | -------------------------------------------------------------------------------- /old/notes.txt: -------------------------------------------------------------------------------- 1 | 2 | >>> rnn = nn.LSTM(10, 20, 2) 3 | input_size – The number of expected features in the input x 4 | hidden_size – The number of features in the hidden state h 5 | num_layers – Number of recurrent layers. 6 | 7 | input (seq_len, batch, input_size): tensor containing the features of the input sequence. The input can also be a packed variable length sequence. See torch.nn.utils.rnn.pack_padded_sequence() for details. 8 | h_0 (num_layers * num_directions, batch, hidden_size): tensor containing the initial hidden state for each element in the batch. 9 | c_0 (num_layers * num_directions, batch, hidden_size): tensor containing the initial cell state for each element in the batch. 10 | 11 | 12 | >>> input = Variable(torch.randn(5, 3, 10)) 13 | >>> h0 = Variable(torch.randn(2, 3, 20)) 14 | >>> c0 = Variable(torch.randn(2, 3, 20)) 15 | >>> output, hn = rnn(input, (h0, c0)) 16 | 17 | rnn = nn.LSTM(word_size, hidden_size, num_layers) 18 | h0 = Variable(torch.randn(num_layers, z, hidden_size)) 19 | c0 = Variable(torch.randn(q, z, y)) 20 | 21 | input = Variable(torch.randn(5, z, x)) 22 | 23 | output, hn = rnn(input, (h0, c0)) 24 | 25 | nn (x, y, q) 26 | in (a, z, x) 27 | hdn (q, z, y) 28 | hdn (q, z, y) -------------------------------------------------------------------------------- /old/train.py: -------------------------------------------------------------------------------- 1 | 2 | def train_qa(args, num_problems, data, Dnc, optimizer, dnc_state, save_=False): 3 | """ 4 | I am jacks liver. This is a sanity test 5 | 6 | 0 - describe state. 7 | 1 - describe goal. 8 | 2 - do actions. 9 | 3 - ask some questions 10 | :param args: 11 | :return: 12 | """ 13 | sl.log_step += args.log 14 | 15 | print(Dnc) 16 | criterion = nn.CrossEntropyLoss() 17 | 18 | cum_correct = [] 19 | cum_total_move = [] 20 | num_tests = 2 21 | num_repeats = 1 22 | for n in range(num_problems): 23 | masks = data.make_new_problem() 24 | num_correct = 0 25 | total_moves = 0 26 | # prev_action = None 27 | 28 | if args.repakge_each_step is False: 29 | dnc_state = dnc.repackage(dnc_state) 30 | optimizer.zero_grad() 31 | # repackage the state to zero grad at start of each problem 32 | 33 | for idx, mask in enumerate(masks): 34 | sl.global_step += 1 35 | inputs, mask_ = data.getitem() 36 | 37 | if mask == 0 or mask == 1: 38 | for nz in range(num_repeats): 39 | inputs1 = Variable(torch.cat([mask_, inputs], 1)) 40 | logits, dnc_state = Dnc(inputs1, dnc_state) 41 | 42 | # sl.log_state(dnc_state) 43 | else: 44 | targets_star = data.get_actions(mode='one') 45 | if targets_star == []: 46 | break 47 | final_move = targets_star[0] 48 | 49 | if args.repakge_each_step is True: 50 | dnc_state = dnc.repackage(dnc_state) 51 | optimizer.zero_grad() 52 | 53 | for nq in range(num_repeats): 54 | inputs2 = Variable(torch.cat([data.phase_oh[2].unsqueeze(0), data.vec_to_ix(final_move)], 1)) 55 | logits, dnc_state = Dnc(inputs2, dnc_state) 56 | 57 | data.send_action(final_move) 58 | # sl.log_state(dnc_state) 59 | 60 | for _ in range(num_tests): 61 | total_moves += 1 62 | state_expr = random.choice(data.pull_state()) 63 | state_vec = data.expr_to_vec(state_expr) 64 | mask_idx = 2 # random.randint(1, 2) 65 | mask_chunk = state_vec[mask_idx] 66 | 67 | zeros = 0 if type(mask_chunk) == int else tuple([0] * len(mask_chunk)) 68 | masked_state_vec = state_vec.copy() 69 | # masked_state_vec[mask_idx] = zeros 70 | masked_state_vec[2] = zeros 71 | 72 | inputs3 = Variable(torch.cat([data.phase_oh[3].unsqueeze(0), data.vec_to_ix(masked_state_vec)], 1)) 73 | logits, dnc_state = Dnc(inputs3, dnc_state) 74 | 75 | expanded_logits = data.ix_input_to_ixs(logits) 76 | idx1, idx2 = mask_idx * 2 - 1, mask_idx * 2 77 | 78 | target_chunk1 = Variable(torch.LongTensor([mask_chunk[0]])) 79 | target_chunk2 = Variable(torch.LongTensor([mask_chunk[1]])) 80 | 81 | loss1 = criterion(expanded_logits[idx1], target_chunk1) 82 | loss2 = criterion(expanded_logits[idx2], target_chunk2) 83 | loss = loss1 + loss2 84 | sl.log_state(dnc_state) 85 | 86 | loss.backward(retain_graph=args.ret_graph) 87 | optimizer.step() 88 | sl.log_state(dnc_state) 89 | 90 | resp1, pidx1 = expanded_logits[idx1].data.topk(1) 91 | resp2, pidx2 = expanded_logits[idx2].data.topk(1) 92 | pred_tuple = pidx1.squeeze()[0], pidx2.squeeze()[0] 93 | # pred_expr = data.lookup_ix_to_expr(pred_tuple) 94 | correct_ = mask_chunk == pred_tuple 95 | num_correct += 1 if mask_chunk == pred_tuple else 0 96 | # print("step {}.{}, loss: {:0.2f}, state {} actual {} pred: {}, {}".format( 97 | # n, idx, loss.data[0], state_expr, mask_chunk, pred_tuple, correct_)) 98 | sl.log_loss_qa(loss1, loss2, loss) 99 | 100 | cum_total_move.append(total_moves) 101 | cum_correct.append(num_correct) 102 | trial_acc = num_correct / total_moves 103 | sl.writer.add_scalar('recall.pct_correct', trial_acc, sl.global_step) 104 | sl.log_state(dnc_state) 105 | sl.log_model(Dnc) 106 | 107 | print("trial: {} accy {:0.4f}, cum_score {:0.4f}".format(n, trial_acc, sum(cum_correct[-100:]) / sum(cum_total_move[-100:]))) 108 | if save_ is not False: 109 | save(Dnc, dnc_state, start, args.save, sl.global_step) 110 | 111 | score = sum(cum_correct[-100:]) / sum(cum_total_move[-100:]) 112 | return Dnc, optimizer, dnc_state, score 113 | 114 | def setupDNC(args): 115 | if args.algo == 'lstm': 116 | return setupLSTM(args) 117 | data = gen.AirCargoData(**generate_data_spec(args)) 118 | dnc_args['output_size'] = data.nn_in_size # output has no phase component 119 | dnc_args['word_len'] = data.nn_out_size 120 | print(dnc_args) 121 | if args.load == '': 122 | Dnc = dnc.DNC(**dnc_args) 123 | o, m, r, w, l, lw, u, (ho, hc) = Dnc.init_state() 124 | else: 125 | model_path = './models/dnc_model_' + args.load 126 | state_path = './models/dnc_state_' + args.load 127 | print('loading', model_path, state_path) 128 | Dnc = dnc.DNC(**dnc_args) 129 | Dnc.load_state_dict(torch.load(model_path)) 130 | o, m, r, w, l, lw, u, (ho, hc) = torch.load(state_path) 131 | print(dnc_checksum([o, m, r, w, l, lw, u])) 132 | 133 | lr = 5e-5 if args.lr is None else args.lr 134 | if args.opt == 'adam': 135 | optimizer = optim.Adam([{'params': Dnc.parameters()}, {'params': o}, {'params': m}, {'params': r}, {'params': w}, 136 | {'params': l}, {'params': lw}, {'params': u}, {'params': ho}, {'params': hc}], 137 | lr=lr) 138 | else: 139 | optimizer = optim.SGD([{'params': Dnc.parameters()}, {'params': o}, {'params': m}, {'params': r}, {'params': w}, 140 | {'params': l}, {'params': lw}, {'params': u}, {'params': ho}, {'params': hc}], 141 | lr=lr) 142 | dnc_state = (o, m, r, w, l, lw, u, (ho, hc)) 143 | return data, Dnc, optimizer, dnc_state 144 | 145 | 146 | """ 147 | #target_chunk1 = Variable(torch.LongTensor([mask_chunk[0]])) 148 | #target_chunk2 = Variable(torch.LongTensor([mask_chunk[1]])) 149 | #loss1 = criterion(expanded_logits[idx1], target_chunk1) 150 | #loss2 = criterion(expanded_logits[idx2], target_chunk2) 151 | #lstep = loss1 + loss2 152 | # action_loss(logits, expanded_logits, criterion) 153 | """ -------------------------------------------------------------------------------- /problem/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/problem/__init__.py -------------------------------------------------------------------------------- /problem/copy_squence.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psavine42/fun-with-dnc/bdf110762f00347cae77bb3689d628d053893a4b/problem/copy_squence.py -------------------------------------------------------------------------------- /problem/generators_v2.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import torch 4 | from utils import flat 5 | from . import my_air_cargo_problems as mac 6 | from .search import * 7 | from .planning import Action 8 | import random 9 | 10 | 11 | EXPRESSIONS = ['Fly', 'Load', 'Unload', 'At', 'In'] 12 | PHASES = ['State', 'Goal', 'Plan', 'Solve'] 13 | 14 | phase_to_ix = {word: i for i, word in enumerate(PHASES)} 15 | ix_to_phase = {i: word for i, word in enumerate(PHASES)} 16 | exprs_to_ix = {exxp: i for i, exxp in enumerate(EXPRESSIONS)} 17 | ix_to_exprs = {i: exxp for i, exxp in enumerate(EXPRESSIONS)} 18 | 19 | 20 | default_encoding = {'expr', 'action', 'ent1-type', 'ent1'} 21 | ix_size = [9, 9, 9] 22 | ix2_size = [3, 6, 3, 6, 3, 6] 23 | split1 = [[0, 1], [0, 2], [0, 3]] 24 | split2 = [[0, 1], [1, 2], [0, 3], [1, 4], [0, 5], [1, 6]] 25 | 26 | permutations = {'At': {'insert': [0, 0], 'tst': [0, 'P'], 'idx': [2, 2], 'permute': [4, 5, 0, 1, 2, 3]}, 27 | 'Fly': {'insert': [0, 0], 'idx':0, 'permute': [0, 1, 2, 3, 4, 5,]}, 28 | 'Load': {'insert': [0, 0], 'idx':0, 'permute': [0, 1, 2, 3, 4, 5,]}, 29 | 'In': {'insert': [0, 0], 'idx':0, 'permute': [0, 1, 2, 3, 4, 5,]}, 30 | 'Unload': {'insert': [0, 0], 'idx':0, 'permute': [0, 1, 2, 3, 4, 5,]}, 31 | } 32 | 33 | perm_one = {'At': {'insert': [0], 'tst': [0, 'P'], 'idx': [2, 2], 'permute': [0, 1, 2]}, 34 | 'Fly': {'insert': [0], 'idx':0, 'permute': [0, 1, 2]}, 35 | 'Load': {'insert': [0], 'idx':0, 'permute': [0, 1, 2]}, 36 | 'In': {'insert': [0], 'idx':0, 'permute': [0, 1, 2]}, 37 | 'Unload': {'insert': [0], 'idx':0, 'permute': [0, 1, 2]}, 38 | } 39 | 40 | 41 | class BitData(): 42 | def __init__(self, 43 | num_bits=6, 44 | batch_size=1, 45 | min_length=1, 46 | max_length=1, 47 | min_repeats=1, 48 | max_repeats=2, 49 | norm_max=10, 50 | log_prob_in_bits=False, 51 | time_average_cost=False, 52 | name='repeat_copy',): 53 | self._batch_size = batch_size 54 | self._num_bits = num_bits 55 | self._min_length = min_length 56 | self._max_length = max_length 57 | self._min_repeats = min_repeats 58 | self._max_repeats = max_repeats 59 | self._norm_max = norm_max 60 | self._log_prob_in_bits = log_prob_in_bits 61 | self._time_average_cost = time_average_cost 62 | 63 | def _normilize(self, val): 64 | return val / self._norm_max 65 | 66 | def _unnormalize(self, val): 67 | return val * self._norm_max 68 | 69 | @property 70 | def time_average_cost(self): 71 | return self._time_average_cost 72 | 73 | @property 74 | def log_prob_in_bits(self): 75 | return self._log_prob_in_bits 76 | 77 | @property 78 | def num_bits(self): 79 | """The dimensionality of each random binary vector in a pattern.""" 80 | return self._num_bits 81 | 82 | @property 83 | def target_size(self): 84 | """The dimensionality of the target tensor.""" 85 | return self._num_bits + 1 86 | 87 | @property 88 | def batch_size(self): 89 | return self._batch_size 90 | 91 | def get_item(self): 92 | min_length, max_length = self._min_length, self._max_length 93 | min_reps, max_reps = self._min_repeats, self._max_repeats 94 | num_bits = self.num_bits 95 | batch_size = self.batch_size 96 | 97 | # We reserve one dimension for the num-repeats and one for the start-marker. 98 | full_obs_size = num_bits + 2 99 | # We reserve one target dimension for the end-marker. 100 | full_targ_size = num_bits + 1 101 | start_end_flag_idx = full_obs_size - 2 102 | num_repeats_channel_idx = full_obs_size - 1 103 | 104 | # Samples each batch index's sequence length and the number of repeats 105 | sub_seq_length_batch = torch.from_numpy( 106 | np.random.randint(min_length, max_length + 1, batch_size)).long() 107 | num_repeats_batch = torch.from_numpy( 108 | np.random.randint(min_reps, max_reps + 1, batch_size)).long() 109 | 110 | # Pads all the batches to have the same total sequence length. 111 | total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3 112 | max_length_batch = total_length_batch.max() 113 | residual_length_batch = max_length_batch - total_length_batch 114 | 115 | obs_batch_shape = [max_length_batch, batch_size, full_obs_size] 116 | targ_batch_shape = [max_length_batch, batch_size, full_targ_size] 117 | mask_batch_trans_shape = [batch_size, max_length_batch] 118 | 119 | obs_tensors = [] 120 | targ_tensors = [] 121 | mask_tensors = [] 122 | 123 | # Generates patterns for each batch element independently. 124 | for batch_index in range(batch_size): 125 | sub_seq_len = sub_seq_length_batch[batch_index] 126 | num_reps = num_repeats_batch[batch_index] 127 | 128 | # The observation pattern is a sequence of random binary vectors. 129 | obs_pattern_shape = [sub_seq_len, num_bits] 130 | # obs_pattern = torch.LongTensor(obs_pattern_shape).uniform_(0, 2).float() 131 | obs_pattern = torch.from_numpy(np.random.randint(0, 2, obs_pattern_shape)).float() 132 | print(obs_pattern.size()) 133 | # The target pattern is the observation pattern repeated n times. 134 | # Some reshaping is required to accomplish the tiling. 135 | targ_pattern_shape = [sub_seq_len * num_reps, num_bits] 136 | flat_obs_pattern = obs_pattern.view(-1) 137 | print(flat_obs_pattern, targ_pattern_shape) 138 | flat_targ_pattern = flat_obs_pattern.expand(num_reps) 139 | targ_pattern = torch.reshape(flat_targ_pattern, targ_pattern_shape) 140 | 141 | # Expand the obs_pattern to have two extra channels for flags. 142 | # Concatenate start flag and num_reps flag to the sequence. 143 | obs_flag_channel_pad = torch.zeros([sub_seq_len, 2]) 144 | obs_start_flag = torch.eye([start_end_flag_idx], full_obs_size).float() 145 | num_reps_flag = torch.eye(num_repeats_channel_idx, full_obs_size) 146 | 147 | # note the concatenation dimensions. 148 | obs = torch.cat([obs_pattern, obs_flag_channel_pad], 1) 149 | obs = torch.cat([obs_start_flag, obs], 0) 150 | obs = torch.cat([obs, num_reps_flag], 0) 151 | 152 | # Now do the same for the targ_pattern (it only has one extra channel). 153 | targ_flag_channel_pad = torch.zeros([sub_seq_len * num_reps, 1]) 154 | targ_end_flag = torch.eye(start_end_flag_idx, full_targ_size) 155 | targ = torch.cat([targ_pattern, targ_flag_channel_pad], 1) 156 | targ = torch.cat([targ, targ_end_flag], 0) 157 | 158 | # Concatenate zeros at end of obs and begining of targ. 159 | # This aligns them s.t. the target begins as soon as the obs ends. 160 | obs_end_pad = torch.zeros([sub_seq_len * num_reps + 1, full_obs_size]) 161 | targ_start_pad = torch.zeros([sub_seq_len + 2, full_targ_size]) 162 | 163 | # The mask is zero during the obs and one during the targ. 164 | mask_off = torch.zeros([sub_seq_len + 2]) 165 | mask_on = torch.ones([sub_seq_len * num_reps + 1]) 166 | 167 | obs = torch.cat([obs, obs_end_pad], 0) 168 | targ = torch.cat([targ_start_pad, targ], 0) 169 | mask = torch.cat([mask_off, mask_on], 0) 170 | 171 | obs_tensors.append(obs) 172 | targ_tensors.append(targ) 173 | mask_tensors.append(mask) 174 | 175 | # End the loop over batch index. 176 | # Compute how much zero padding is needed to make tensors sequences 177 | # the same length for all batch elements. 178 | residual_obs_pad = [ 179 | torch.zeros(residual_length_batch[i], full_obs_size) for i in range(batch_size)] 180 | residual_targ_pad = [ 181 | torch.zeros(residual_length_batch[i], full_targ_size) for i in range(batch_size)] 182 | 183 | residual_mask_pad = [torch.zeros(residual_length_batch[i]) for i in range(batch_size)] 184 | 185 | # Concatenate the pad to each batch element. 186 | obs_tensors = [ 187 | torch.cat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad) 188 | ] 189 | targ_tensors = [ 190 | torch.cat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad) 191 | ] 192 | mask_tensors = [ 193 | torch.cat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad) 194 | ] 195 | 196 | # Concatenate each batch element into a single tensor. 197 | obs = torch.cat(obs_tensors, 1).view(obs_batch_shape) 198 | targ = torch.cat(targ_tensors, 1).view(targ_batch_shape) 199 | mask = torch.cat(mask_tensors, 0).reshape(mask_batch_trans_shape).t() 200 | return obs, targ, mask 201 | 202 | 203 | class AirCargoData(): 204 | """ 205 | Flags for 206 | """ 207 | def __init__(self, num_plane=10, num_cargo=6, batch_size=6, 208 | num_airport=1000, plan_phase=1, cuda=False, 209 | one_hot_size=10, encoding=2, mapping=None, 210 | search_function=astar_search, solve=True): 211 | self.n_plane, self.n_cargo, self.n_airport = num_plane, num_cargo, num_airport 212 | self.plan_len = plan_phase 213 | self.batch_size = batch_size 214 | self.mapping = mapping 215 | self.solve = solve 216 | self.encoding = encoding 217 | self.one_hot_size = [one_hot_size] if type(one_hot_size) == int else one_hot_size 218 | self.search_fn = search_function 219 | 220 | self.search_param = 'h_ignore_preconditions' 221 | self.ents_to_ix, self.ix_to_ents = None, None 222 | 223 | self.STATE, self.INIT_STATE = '', '' 224 | self.current_index, self.cuda = 0, cuda 225 | self.current_problem, self.goals, self.state = None, None, None 226 | 227 | self.phase_oh = torch.eye(len(PHASES)) 228 | self.blnk_vec = torch.zeros(len(EXPRESSIONS) * 2).float() # 229 | self.expr_o_h = torch.cat([torch.zeros([1, len(EXPRESSIONS)]), torch.eye(len(EXPRESSIONS))], 0).float() 230 | self.ents_o_h = None 231 | self.masks = [] 232 | 233 | # new indices 234 | self.goals_idx, self.cargo_in_idx = {}, {} 235 | self.cache, self.encodings = {}, {} 236 | self.make_new_problem() 237 | print(self.plan_len) 238 | 239 | @property 240 | def nn_in_size(self): 241 | return self.blnk_vec.size(-1) + self.phase_oh.size(-1) 242 | 243 | @property 244 | def nn_out_size(self): 245 | return self.blnk_vec.size(-1) 246 | 247 | def lookup_expr_to_ix(self, _expr): 248 | return self.ents_to_ix[str(_expr)] 249 | 250 | def lookup_ix_to_expr(self, ix) -> str: 251 | if ix in self.ix_to_ents: 252 | return self.ix_to_ents[ix] 253 | else: 254 | return "NA" 255 | 256 | def masked_input(self): 257 | state_expr = random.choice(self.pull_state()) 258 | state_vec = self.expr_to_vec(state_expr) 259 | mask_idx = 2 260 | mask_chunk = state_vec[mask_idx] 261 | 262 | zeros = 0 if type(mask_chunk) == int else tuple([0] * len(mask_chunk)) 263 | masked_state_vec = state_vec.copy() 264 | masked_state_vec[2] = zeros 265 | inputs = torch.cat([self.phase_oh[3].unsqueeze(0), self.vec_to_ix(masked_state_vec)], 1) 266 | return inputs, mask_chunk, state_vec 267 | 268 | def generate_encodings(self): 269 | """ 270 | 271 | :param ix_to_ents: 272 | :return: 273 | """ 274 | noops = [torch.zeros(1, i) for i in self.one_hot_size] 275 | self.ents_o_h = [] 276 | for noop, ix_size in zip(noops, self.one_hot_size): 277 | ix_enc = torch.cat([noop, torch.eye(ix_size)], 0).float() 278 | self.ents_o_h.append(ix_enc) 279 | self.blnk_vec = torch.cat([torch.zeros([1, len(EXPRESSIONS)]), torch.cat(noops, 1)], 1) 280 | 281 | def print_solution(self, node): 282 | for action in node.solution(): 283 | print("{}{}".format(action.name, action.args)) 284 | 285 | def best_logic(self, action_exprs): 286 | """ 287 | Astar search takes forever for this problem. 288 | best logical thing to do is handcode the ops needed. 289 | :param action_exprs: 290 | :return: 291 | """ 292 | if self.goals_idx == {}: # the problem has no remaining goals 293 | return [] 294 | best_actions, at_goal = [], [] 295 | for action in action_exprs: 296 | op = action.name 297 | if op == 'Unload': 298 | cargo = action.args[0] 299 | if cargo in self.goals_idx and self.goals_idx[cargo] == action.args[2]: 300 | at_goal.append(action) 301 | continue 302 | elif op == 'Load': 303 | cargo = action.args[0] 304 | # load cargo if it is not in its home 305 | if cargo in self.goals_idx: 306 | at_goal.append(action) 307 | continue 308 | elif op == 'Fly': 309 | plane, airpt, dest_n = action.args 310 | cargos_in_plane = [cargo for cargo, _in in self.cargo_in_idx.items() if _in == plane] 311 | # if there is a cargo in the plane 312 | # and desitination is the cargo's goal 313 | if cargos_in_plane != []: 314 | cargo = cargos_in_plane[0] # assumes one cargo 315 | if cargo in self.goals_idx and self.goals_idx[cargo] == dest_n: 316 | at_goal.append(action) 317 | continue 318 | # if there are no cargos at the airport, fly to destination with GOAL cargo 319 | cargos_in_airpt = [cargo for cargo, _in in self.cargo_in_idx.items() if _in == airpt] 320 | if cargos_in_airpt == []: 321 | cargos_in_dest = [cargo for cargo, _in in self.cargo_in_idx.items() if _in == dest_n] 322 | if cargos_in_dest != []: 323 | best_actions.append(action) 324 | if not at_goal: 325 | return best_actions 326 | else: 327 | return at_goal 328 | 329 | def get_raw_actions(self, mode='best'): 330 | self.current_problem.initial = self.STATE 331 | if mode == 'all': 332 | actions_exprs = self.current_problem.actions(self.STATE) 333 | elif mode == 'one': 334 | actions_exprs = [self.current_problem.one_action(self.STATE)] 335 | elif mode == 'both': 336 | all_actions = self.current_problem.actions(self.STATE) 337 | best_actions = self.best_logic(all_actions) 338 | return best_actions, all_actions 339 | else: 340 | if self.search_param is not None: 341 | prm = getattr(self.current_problem, self.search_param) 342 | solution = self.search_fn(self.current_problem, prm) 343 | actions_exprs = solution.solution()[0:len(self.current_problem.planes)] 344 | else: 345 | solution = self.search_fn(self.current_problem) 346 | actions_exprs = solution.solution()[0:len(self.current_problem.planes)] 347 | return actions_exprs 348 | 349 | def get_actions(self, mode='best'): 350 | if mode == 'both': 351 | best, all = self.get_raw_actions(mode=mode) 352 | return [self.expr_to_vec(a) for a in best], [self.expr_to_vec(a) for a in all] 353 | else: 354 | return [self.expr_to_vec(a) for a in self.get_raw_actions(mode=mode)] 355 | 356 | def encode_action(self, action_obj): 357 | return torch.from_numpy(self.current_problem.encode_action(action_obj)).long() 358 | 359 | def send_action(self, action_vec): 360 | """ 361 | Transform vector into an expression compatible with the problem code. 362 | """ 363 | self.current_problem.initial = self.STATE 364 | sym, args = self.vec_to_expr(action_vec) 365 | actions_ = self.current_problem.actions(self.STATE) 366 | final_act = [] 367 | for a in actions_: 368 | if a.name == sym and all((str(ar) == at) for ar, at in zip(a.args, args)): 369 | final_act = a 370 | break 371 | assert final_act != [] 372 | if final_act.name == 'Load': 373 | cargo, plane, _ = final_act.args 374 | self.cargo_in_idx[cargo] = plane 375 | elif final_act.name == 'Unload': 376 | cargo, _, airpt = final_act.args 377 | if cargo in self.goals_idx and self.goals_idx[cargo] == airpt: 378 | # if this is the goal state, remove from running index 379 | del self.goals_idx[cargo] 380 | del self.cargo_in_idx[cargo] 381 | else: # else add to new airport 382 | self.cargo_in_idx[cargo] = airpt 383 | self.STATE = self.current_problem.result(self.STATE, final_act) 384 | return self.STATE, final_act 385 | 386 | def vec_to_expr(self, _vec): 387 | action_str = EXPRESSIONS[_vec[0]] 388 | args_vec = [] 389 | for idx, value in enumerate(_vec[1:]): 390 | args_vec.append(self.lookup_ix_to_expr(value)) 391 | return action_str, args_vec 392 | 393 | def pull_state(self): 394 | res = [] 395 | for idx, char in enumerate(self.STATE): 396 | if char == 'T': 397 | res.append(self.current_problem.state_map[idx]) 398 | return res 399 | 400 | def vec_to_ix(self, _vec): 401 | """ 402 | Input target vec representing cross entropy loss target [1 0 2 0 0 0 0] 403 | Returns a one hot version of it as training input [01 00, 100, 000, 000, 000] 404 | :param _vec: 405 | :return: 406 | """ 407 | merged = flat(_vec) 408 | action = self.expr_o_h[merged[0]].unsqueeze(0) 409 | ix_ent = [] 410 | merged = merged[1:] 411 | if self.mapping is not None: 412 | expr_str = EXPRESSIONS[_vec[0]] 413 | # print(expr_str) 414 | mp = self.mapping[expr_str] 415 | permute = mp['permute'] if 'permute' in mp else None 416 | insert_ = mp['insert'] if 'insert' in mp else None 417 | test = mp['tst'] if 'tst' in mp else None 418 | idx = mp['idx'] if 'idx' in mp else None 419 | if insert_ is not None and test is not None and idx is not None: 420 | text_ix = _vec[1:][test[0]] 421 | for i, insrt_ent in zip(idx, insert_): 422 | if test[1] in self.lookup_ix_to_expr(text_ix): 423 | merged.insert(i, insrt_ent) 424 | else: 425 | merged.append(insrt_ent) 426 | elif insert_ is not None and idx is None: 427 | for insrt_ent in insert_: 428 | merged.append(insrt_ent) 429 | 430 | if permute is not None: 431 | nperm = np.argsort(permute) 432 | merged = np.asarray(merged)[nperm] 433 | else: 434 | # print(len(self.ents_o_h)) 435 | for idx in range(len(self.ents_o_h) - len(merged)): 436 | merged.append(0) 437 | 438 | for idx, value in enumerate(merged): 439 | ix_ent.append(self.ents_o_h[idx][value].unsqueeze(0)) 440 | return torch.cat([action, torch.cat(ix_ent, 1)], 1) 441 | 442 | def ix_to_vec(self, ix_ent): 443 | action_l = self.expr_o_h.size(-1) 444 | ent_vec = [ix_ent[0:action_l].index(1) + 1] 445 | start = action_l 446 | for idx, ix_size in enumerate(self.one_hot_size): 447 | expr_ix = ix_ent[start:start+ix_size] 448 | if 1 in expr_ix: 449 | ent_vec.append(expr_ix.index(1) + 1) 450 | else: 451 | ent_vec.append(0) 452 | start += ix_size 453 | return ent_vec 454 | 455 | def ix_to_ixs(self, ix_ent, grouping=None): 456 | action_l = self.expr_o_h.size(-1) 457 | ixs_vec = [ix_ent[:, 0:action_l]] 458 | start = action_l 459 | if grouping is None: 460 | grouping = self.one_hot_size 461 | for idx, ix_size in enumerate(grouping): 462 | ixs_vec.append(ix_ent[:, start:start + ix_size]) 463 | start += ix_size 464 | return ixs_vec 465 | 466 | def strip_ix_mask(self, ix_input_vec): 467 | phase_size = self.phase_oh.size(-1) 468 | ixs_vec = ix_input_vec[:, phase_size:] 469 | phase_vec = ix_input_vec[:, :phase_size] 470 | return phase_vec, ixs_vec 471 | 472 | def ix_input_to_ixs(self, ix_input_vec, grouping=None): 473 | """ 474 | 475 | :param ix_input_vec: 476 | :param grouping: 477 | :return: 478 | """ 479 | phase_size = self.phase_oh.size(-1) 480 | ixs_vec = ix_input_vec[:, phase_size:] 481 | return self.ix_to_ixs(ixs_vec, grouping) 482 | 483 | def ix_to_expr(self, ix_input_vec): 484 | vec = self.ix_to_vec(ix_input_vec) 485 | return self.vec_to_expr(vec) 486 | 487 | def expr_to_vec(self, expr_obj): 488 | """ 489 | :param expr_obj: Action Expr object Fly(P0, A1, A2) 490 | :return: action vec [0, 0, 1, 2], and argument permutation 491 | """ 492 | if type(expr_obj) == Action: 493 | exp_name = expr_obj.name 494 | else: 495 | exp_name = expr_obj.op 496 | ent_vec = [exprs_to_ix[exp_name]] 497 | for arg in expr_obj.args: 498 | ent_vec.append(self.lookup_expr_to_ix(arg)) 499 | 500 | return ent_vec 501 | 502 | def gen_state_vec(self, index): 503 | state_expr = self.state[index] 504 | return self.expr_to_vec(state_expr) 505 | 506 | def gen_input_ix(self, _exprs, index): 507 | """ 508 | Generate a one_hot vector at a given index 509 | :param vecs: 510 | :param index: 511 | :return: 512 | """ 513 | _expr = _exprs[index] 514 | ent_vec = self.expr_to_vec(_expr) 515 | return self.vec_to_ix(ent_vec) 516 | 517 | def human_readable(self, inputs, mask=None) -> str: 518 | if mask is None: 519 | phase = PHASES[self.phase_oh[self.masks[self.current_index]]] 520 | elif mask is False: 521 | phase = '' 522 | elif mask is True: 523 | phase_size = self.phase_oh.size(-1) 524 | phase_ix = inputs[:, :phase_size].squeeze() 525 | 526 | inputs = inputs[:, phase_size:] 527 | phase = '' 528 | 529 | elif isinstance(mask, int): 530 | phase = PHASES[mask] 531 | else: 532 | phase = PHASES[mask.squeeze()[0]] 533 | args = [] 534 | txt = '' 535 | if phase: 536 | args.append(phase) 537 | txt += 'Phase {}, ' 538 | 539 | if isinstance(inputs, torch.Tensor): 540 | expr = self.ix_to_expr(inputs) 541 | else: 542 | expr = self.vec_to_expr(vec) 543 | args.append(expr) 544 | txt += 'expr {}' 545 | return txt.format(args) 546 | 547 | def make_new_problem(self): 548 | """ 549 | Set up new problem object 550 | :return: 551 | """ 552 | problem, (e_ix, ix_e), (s, g) = \ 553 | mac.arbitrary_ACP2(self.n_airport, self.n_cargo, self.n_plane, encoding=self.encoding) 554 | self.current_problem = problem 555 | 556 | self.STATE = problem.initial 557 | self.INIT_STATE = problem.initial 558 | self.generate_encodings() 559 | self.state, self.goals = s, g 560 | self.goals_idx = {goal.args[0]: goal.args[1] for goal in self.goals} 561 | self.cargo_in_idx = {st.args[0]: st.args[1] for st in s if st.args[0].op.startswith('C')} 562 | 563 | # if self.solve is True: 564 | # todo something about this maybe deepcopy self and fastsolve? 565 | # prm = getattr(problem, self.search_param) 566 | # solution_node = len(self.search_fn(problem, prm).solution()) 567 | # else: 568 | # solution_node = self.solve 569 | # len = num_ents * 2 + num_ents + num_ents * 2 + num_ents * 4 570 | # 9 * 6 571 | 572 | state = torch.zeros(len(self.state)) 573 | goal = torch.ones(len(self.goals)) 574 | plan = torch.ones(self.plan_len) * 2 575 | resp = torch.ones(self.n_cargo * 4) * 3 576 | 577 | self.masks = torch.cat([state, goal, plan, resp], 0).long() 578 | self.ents_to_ix, self.ix_to_ents = e_ix, ix_e 579 | self.current_index = 0 580 | return self.masks 581 | 582 | def len__(self): 583 | if self.current_index >= self.current_problem[0].size(0): 584 | self.make_new_problem() 585 | return len(self.current_problem[1]) 586 | 587 | def getitem(self, batch=1): 588 | """Returns a problem, [initial-state, goals] 589 | and a runnable solution object [problem, solution_node] 590 | 591 | Otherwise take the target one_hot class mask in form of 592 | [ent1-type, ent1 ....entN, channel] 593 | 594 | """ 595 | if self.current_index >= len(self.masks): 596 | self.make_new_problem() 597 | 598 | phase = self.masks[self.current_index] 599 | if phase == 0: 600 | inputs = self.gen_input_ix(self.state, self.current_index) 601 | elif phase == 1: 602 | inputs = self.gen_input_ix(self.goals, self.current_index - len(self.state)) 603 | elif phase == 2: 604 | inputs = self.blnk_vec 605 | else: 606 | inputs = self.blnk_vec 607 | 608 | self.current_index += batch 609 | mask = self.phase_oh[phase].unsqueeze(0) 610 | return inputs, mask 611 | 612 | def getmask(self, batch=1): 613 | if self.current_index >= len(self.masks): 614 | self.make_new_problem() 615 | phase = self.masks[self.current_index] 616 | self.current_index += batch 617 | mask = self.phase_oh[phase].unsqueeze(0) 618 | return mask.cuda() if self.cuda is True else mask 619 | 620 | def getitem_combined(self, batch=1): 621 | inputs, mask = self.getitem(batch) 622 | combined = torch.cat([mask, inputs], 1) 623 | return combined.cuda() if self.cuda is True else combined 624 | 625 | 626 | 627 | -------------------------------------------------------------------------------- /problem/lp_utils.py: -------------------------------------------------------------------------------- 1 | from .logic import associate 2 | from .utils import expr 3 | 4 | 5 | class FluentState(): 6 | """ state object for planning problems as positive and negative fluents 7 | 8 | """ 9 | 10 | def __init__(self, pos_list, neg_list): 11 | self.pos = pos_list 12 | self.neg = neg_list 13 | 14 | def sentence(self): 15 | return expr(conjunctive_sentence(self.pos, self.neg)) 16 | 17 | def pos_sentence(self): 18 | return expr(conjunctive_sentence(self.pos, [])) 19 | 20 | 21 | def conjunctive_sentence(pos_list, neg_list): 22 | """ returns expr conjuntive sentence given positive and negative fluent lists 23 | 24 | :param pos_list: list of fluents 25 | :param neg_list: list of fluents 26 | :return: expr sentence of fluent conjunction 27 | e.g. "At(C1, SFO) ∧ ~At(P1, SFO)" 28 | """ 29 | clauses = [] 30 | for f in pos_list: 31 | clauses.append(expr("{}".format(f))) 32 | for f in neg_list: 33 | clauses.append(expr("~{}".format(f))) 34 | return associate('&', clauses) 35 | 36 | 37 | def encode_state(fs: FluentState, fluent_map: list) -> str: 38 | """ encode fluents to a string of T/F using mapping 39 | 40 | :param fs: FluentState object 41 | :param fluent_map: ordered list of possible fluents for the problem 42 | :return: str eg. "TFFTFT" string of mapped positive and negative fluents 43 | """ 44 | state_tf = [] 45 | for fluent in fluent_map: 46 | if fluent in fs.pos: 47 | state_tf.append('T') 48 | else: 49 | state_tf.append('F') 50 | return "".join(state_tf) 51 | 52 | 53 | def decode_state(state: str, fluent_map: list) -> FluentState: 54 | """ decode string of T/F as fluent per mapping 55 | 56 | :param state: str eg. "TFFTFT" string of mapped positive and negative fluents 57 | :param fluent_map: ordered list of possible fluents for the problem 58 | :return: fs: FluentState object 59 | 60 | lengths of state string and fluent_map list must be the same 61 | """ 62 | fs = FluentState([], []) 63 | for idx, char in enumerate(state): 64 | if char == 'T': 65 | fs.pos.append(fluent_map[idx]) 66 | else: 67 | fs.neg.append(fluent_map[idx]) 68 | return fs 69 | -------------------------------------------------------------------------------- /problem/my_air_cargo_problems.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .logic import PropKB 4 | from .planning import Action 5 | from .search import ( 6 | Node, Problem 7 | ) 8 | from .utils import expr 9 | from .lp_utils import ( 10 | FluentState, encode_state, decode_state, 11 | ) 12 | from .my_planning_graph import PlanningGraph 13 | from functools import lru_cache 14 | from random import shuffle 15 | 16 | 17 | class AirCargoProblem(Problem): 18 | def __init__(self, cargos, planes, airports, initial: FluentState, goal: list): 19 | """ 20 | 21 | :param cargos: list of str 22 | cargos in the problem 23 | :param planes: list of str 24 | planes in the problem 25 | :param airports: list of str 26 | airports in the problem 27 | :param initial: FluentState object 28 | positive and negative literal fluents (as expr) describing initial state 29 | :param goal: list of expr 30 | literal fluents required for goal test 31 | """ 32 | self.state_map = initial.pos + initial.neg 33 | self.initial_state_TF = encode_state(initial, self.state_map) 34 | Problem.__init__(self, self.initial_state_TF, goal=goal) 35 | self.cargos = cargos 36 | self.planes = planes 37 | self.airports = airports 38 | self.actions_list = self.get_actions() 39 | 40 | def get_actions(self): 41 | """ 42 | This method creates concrete actions (no variables) for all actions in the problem 43 | domain action schema and turns them into complete Action objects as defined in the 44 | aimacode.planning module. It is computationally expensive to call this method directly; 45 | however, it is called in the constructor and the results cached in the `actions_list` property. 46 | 47 | Returns: 48 | ---------- 49 | list 50 | list of Action objects 51 | """ 52 | def load_actions(): 53 | """Create all concrete Load actions and return a list 54 | 55 | :return: list of Action objects 56 | """ 57 | loads = [] 58 | for p in self.planes: 59 | for a in self.airports: 60 | for c in self.cargos: 61 | precond_pos = [expr("At({}, {})".format(c, a)), 62 | expr("At({}, {})".format(p, a))] 63 | # remove this for tests where plane can load more than cargo 64 | precond_neg = [expr("In({}, {})".format(c1, p)) for c1 in self.cargos] 65 | effect_add = [expr("In({}, {})".format(c, p))] 66 | effect_rem = [expr("At({}, {})".format(c, a))] 67 | act = Action(expr("Load({}, {}, {})".format(c, p, a)), 68 | [precond_pos, precond_neg], 69 | [effect_add, effect_rem]) 70 | loads.append(act) 71 | 72 | return loads 73 | 74 | def unload_actions(): 75 | """Create all concrete Unload actions and return a list 76 | 77 | :return: list of Action objects 78 | """ 79 | unloads = [] 80 | # create all Unload ground actions from the domain Unload action 81 | for p in self.planes: 82 | for a in self.airports: 83 | for c in self.cargos: 84 | precond_pos = [expr("In({}, {})".format(c, p)), 85 | expr("At({}, {})".format(p, a))] 86 | 87 | effect_add = [expr("At({}, {})".format(c, a))] 88 | effect_rem = [expr("In({}, {})".format(c, p))] 89 | act = Action(expr("Unload({}, {}, {})".format(c, p, a)), 90 | [precond_pos, []], 91 | [effect_add, effect_rem]) 92 | unloads.append(act) 93 | 94 | return unloads 95 | 96 | def fly_actions(): 97 | """Create all concrete Fly actions and return a list 98 | 99 | :return: list of Action objects 100 | """ 101 | flys = [] 102 | 103 | for fr in self.airports: 104 | for to in self.airports: 105 | if fr != to: 106 | for p in self.planes: 107 | precond_pos = [expr("At({}, {})".format(p, fr))] 108 | # precond_neg = [] 109 | effect_add = [expr("At({}, {})".format(p, to))] 110 | effect_rem = [expr("At({}, {})".format(p, fr))] 111 | fly = Action(expr("Fly({}, {}, {})".format(p, fr, to)), 112 | [precond_pos, []], 113 | [effect_add, effect_rem]) 114 | flys.append(fly) 115 | return flys 116 | data = load_actions() + unload_actions() + fly_actions() 117 | shuffle(data) 118 | return data 119 | 120 | def one_action(self, state: str) -> list: 121 | """ Return the actions that can be executed in the given state. 122 | 123 | :param state: str 124 | state represented as T/F string of mapped fluents (state variables) 125 | e.g. 'FTTTFF' 126 | :return: list of Action objects 127 | """ 128 | kb = PropKB() 129 | kb.tell(decode_state(state, self.state_map).pos_sentence()) 130 | 131 | for action in self.actions_list: 132 | is_possible = True 133 | for clause in action.precond_pos: 134 | if clause not in kb.clauses: 135 | is_possible = False 136 | break 137 | for clause in action.precond_neg: 138 | if clause in kb.clauses: 139 | is_possible = False 140 | break 141 | if is_possible: 142 | return action 143 | return [] 144 | 145 | def actions(self, state: str) -> list: 146 | possible_actions = [] 147 | kb = PropKB() 148 | kb.tell(decode_state(state, self.state_map).pos_sentence()) 149 | for action in self.actions_list: 150 | is_possible = True 151 | for clause in action.precond_pos: 152 | if clause not in kb.clauses: 153 | is_possible = False 154 | break 155 | for clause in action.precond_neg: 156 | if clause in kb.clauses: 157 | is_possible = False 158 | break 159 | if is_possible: 160 | possible_actions.append(action) 161 | return possible_actions 162 | 163 | def result(self, state: str, action: Action): 164 | """ Return the state that results from executing the given 165 | action in the given state. The action must be one of 166 | self.actions(state). 167 | 168 | :param state: state entering node 169 | :param action: Action applied 170 | :return: resulting state after action 171 | """ 172 | new_state = FluentState([], []) 173 | old_state = decode_state(state, self.state_map) 174 | 175 | for fluent in old_state.pos: 176 | if fluent not in action.effect_rem: 177 | new_state.pos.append(fluent) 178 | 179 | for fluent in action.effect_add: 180 | if fluent not in new_state.pos: 181 | new_state.pos.append(fluent) 182 | 183 | for fluent in old_state.neg: 184 | if fluent not in action.effect_add: 185 | new_state.neg.append(fluent) 186 | 187 | for fluent in action.effect_rem: 188 | if fluent not in new_state.neg: 189 | new_state.neg.append(fluent) 190 | return encode_state(new_state, self.state_map) 191 | 192 | def goal_test(self, state: str) -> bool: 193 | """ Test the state to see if goal is reached 194 | 195 | :param state: str representing state 196 | :return: bool 197 | """ 198 | kb = PropKB() 199 | kb.tell(decode_state(state, self.state_map).pos_sentence()) 200 | for clause in self.goal: 201 | if clause not in kb.clauses: 202 | return False 203 | return True 204 | 205 | def h_1(self, node: Node): 206 | # note that this is not a true heuristic 207 | h_const = 1 208 | return h_const 209 | 210 | @lru_cache(maxsize=8192) 211 | def h_pg_levelsum(self, node: Node): 212 | """This heuristic uses a planning graph representation of the problem 213 | state space to estimate the sum of all actions that must be carried 214 | out from the current state in order to satisfy each individual goal 215 | condition. 216 | """ 217 | pg = PlanningGraph(self, node.state) 218 | pg_levelsum = pg.h_levelsum() 219 | return pg_levelsum 220 | 221 | @lru_cache(maxsize=8192) 222 | def h_ignore_preconditions(self, node: Node): 223 | """This heuristic estimates the minimum number of actions that must be 224 | carried out from the current state in order to satisfy all of the goal 225 | conditions by ignoring the preconditions required for an action to be 226 | executed. 227 | """ 228 | kb = PropKB() 229 | kb.tell(decode_state(node.state, self.state_map).pos_sentence()) 230 | return len([c for c in self.goal if c not in kb.clauses]) 231 | 232 | 233 | def posneg_helper(cargos, planes, airports, pos_exprs): 234 | """given the positive expressions and entities, generate the negative expressions""" 235 | pos_set = set(pos_exprs) 236 | neg = [] 237 | for c in cargos: 238 | for p in planes: 239 | neg.append(expr("In({}, {})".format(c, p))) 240 | for a in airports: 241 | neg.append(expr("At({}, {})".format(c, a))) 242 | for a in airports: 243 | for p in planes: 244 | neg.append(expr("At({}, {})".format(p, a))) 245 | return list(set(neg).difference(pos_set)) 246 | 247 | #def ¬(x): 248 | # print(x) well, python does not allow arbitrary symbols is guess. 249 | #pass 250 | 251 | def air_cargo_p1() -> AirCargoProblem: 252 | cargos = ['C1', 'C2'] 253 | planes = ['P1', 'P2'] 254 | airports = ['JFK', 'SFO'] 255 | pos = [expr('At(C1, SFO)'), 256 | expr('At(C2, JFK)'), 257 | expr('At(P1, SFO)'), 258 | expr('At(P2, JFK)')] 259 | neg = posneg_helper(cargos, planes, airports, pos) 260 | init = FluentState(pos, neg) 261 | goal = [expr('At(C1, JFK)'), 262 | expr('At(C2, SFO)')] 263 | return AirCargoProblem(cargos, planes, airports, init, goal) 264 | 265 | 266 | def air_cargo_p2() -> AirCargoProblem: 267 | cargos = ['C1', 'C2', 'C3'] 268 | planes = ['P1', 'P2', 'P3'] 269 | airports = ['JFK', 'SFO', 'ATL'] 270 | pos = [expr('At(C1, SFO)'), 271 | expr('At(C2, JFK)'), 272 | expr('At(C3, ATL)'), 273 | expr('At(P1, SFO)'), 274 | expr('At(P2, JFK)'), 275 | expr('At(P3, ATL)')] 276 | neg = posneg_helper(cargos, planes, airports, pos) 277 | init = FluentState(pos, neg) 278 | goal = [expr('At(C1, JFK)'), 279 | expr('At(C2, SFO)'), 280 | expr('At(C3, SFO)')] 281 | return AirCargoProblem(cargos, planes, airports, init, goal) 282 | 283 | 284 | def air_cargo_p3() -> AirCargoProblem: 285 | cargos = ['C1', 'C2', 'C3', 'C4'] 286 | planes = ['P1', 'P2'] 287 | airports = ['JFK', 'SFO', 'ATL', 'ORD'] 288 | pos = [expr('At(C1, SFO)'), 289 | expr('At(C2, JFK)'), 290 | expr('At(C3, ATL)'), 291 | expr('At(C4, ORD)'), 292 | expr('At(P1, SFO)'), 293 | expr('At(P2, JFK)')] 294 | neg = posneg_helper(cargos, planes, airports, pos) 295 | init = FluentState(pos, neg) 296 | goal = [expr('At(C1, JFK)'), 297 | expr('At(C2, SFO)'), 298 | expr('At(C3, JFK)'), 299 | expr('At(C4, SFO)')] 300 | return AirCargoProblem(cargos, planes, airports, init, goal) 301 | 302 | ##################################################################### 303 | ### GENERATORS 304 | ##################################################################### 305 | # this is all for generating DNC data 306 | 307 | import numpy as np 308 | import random 309 | from random import randint 310 | random.seed() 311 | 312 | 313 | def bin_gen(num, one_hot_size): 314 | """In case I need to do a binary encoding vs one_hot 315 | """ 316 | int(bin(num)[2:].zfill(one_hot_size)) 317 | 318 | 319 | 320 | def air_cargo_generator_v1(num_plane, num_airport, num_cargo, 321 | one_hot_size=None, vec_fn=None): 322 | """Init: 323 | each plane must be at airport 324 | each cargo must be at an airport 325 | each cargo has a destination 326 | 327 | Returned tuples of as one hot vecs of equal size 328 | zeros vector represents nothing at that slot. 329 | so 'At(C1, SFO)' with one_hot_size = 4 -> 330 | S, Act, A, P, C 331 | +---------------+ 332 | | 1-0 2-0 3-0 | 333 | +---------------+ 334 | 335 | 'state' size becomes max A * P * 1h + A * C * 1h (4, 4, 4, 10) 336 | """ 337 | if one_hot_size is None: 338 | one_hot_size = max([num_airport, num_cargo, num_plane]) + 3 339 | if vec_fn is None: 340 | vec_fn = np.eye 341 | 342 | assert num_airport <= one_hot_size 343 | encoding_vec = vec_fn(one_hot_size) 344 | 345 | o_h_planes = encoding_vec[:num_plane].astype(int) 346 | o_h_airprt = encoding_vec[:num_airport].astype(int) 347 | o_h_cargos = encoding_vec[:num_cargo].astype(int) 348 | o_h_types = vec_fn(3).astype(int) 349 | 350 | zeros = np.zeros(one_hot_size, dtype=int) 351 | zero_type = np.zeros(3, dtype=int) 352 | zero_vec = np.concatenate([zero_type, zeros], axis=0) 353 | 354 | o_h_state, o_h_goals = [], [] 355 | idx_state, idx_goals = [], [] 356 | airprt_dict, cargos_dict, planes_dict = {}, {}, {} 357 | 358 | for idx, o_h_airpr in enumerate(o_h_airprt): 359 | o_h_airpr = np.concatenate([o_h_types[0], o_h_airpr], axis=0) 360 | airprt_dict[str(o_h_airpr)] = 'A{}'.format(idx) 361 | 362 | #planes can go wherever 363 | for idx, o_h_plane in enumerate(o_h_planes): 364 | #generate random 365 | airprt_idx = randint(0, num_airport - 1) 366 | #make one_hot vecs 367 | airprt_vec = np.concatenate([o_h_types[0], o_h_airprt[airprt_idx]], axis=0) 368 | plane_vec = np.concatenate([o_h_types[1], o_h_plane], axis=0) 369 | #make final state vecs and add to states. 370 | oh_plane = np.concatenate([airprt_vec, plane_vec, zero_vec], axis=0) 371 | o_h_state.append(oh_plane) 372 | #lookup idxs 373 | pl_ent = 'P{}'.format(idx) 374 | planes_dict[str(plane_vec)] = pl_ent 375 | #make state dicts and add to states. 376 | ar_ent = airprt_dict[str(airprt_vec)] 377 | idx_state.append('At({}, {})'.format(pl_ent, ar_ent)) 378 | 379 | # cargos start end end are mutually exclusive 380 | # for porpises of this problem. Erp erp. 381 | for idx, o_h_cargo in enumerate(o_h_cargos): 382 | # generate random 383 | init_idx = random.randint(0, num_airport-1) 384 | allowed_values = set(range(num_airport)) 385 | allowed_values.discard(init_idx) 386 | goal_idx = random.choice(list(allowed_values)) 387 | # make one_hot vecs 388 | air_in_vec = np.concatenate([o_h_types[0], o_h_airprt[init_idx]], axis=0) 389 | air_gl_vec = np.concatenate([o_h_types[0], o_h_airprt[goal_idx]], axis=0) 390 | cargo_vec = np.concatenate([o_h_types[2], o_h_cargo], axis=0) 391 | # make final state vecs and add to states. 392 | oh_init = np.concatenate([air_in_vec, zero_vec, cargo_vec], axis=0) 393 | oh_goal = np.concatenate([air_gl_vec, zero_vec, cargo_vec], axis=0) 394 | o_h_state.append(oh_init) 395 | o_h_goals.append(oh_goal) 396 | # lookup idxs 397 | cr_ent = 'C{}'.format(idx) 398 | cargos_dict[str(cargo_vec)] = cr_ent 399 | state_ar_ent = airprt_dict[str(air_in_vec)] 400 | goals_ar_ent = airprt_dict[str(air_gl_vec)] 401 | # make state dicts and add to states. 402 | idx_state.append('At({}, {})'.format(cr_ent, state_ar_ent)) 403 | idx_goals.append('At({}, {})'.format(cr_ent, goals_ar_ent)) 404 | 405 | return [[np.asarray(o_h_state), np.asarray(o_h_goals)], 406 | [idx_state, idx_goals], 407 | [airprt_dict, cargos_dict, planes_dict]] 408 | 409 | def reverse_lookup(problem_ents, o_h_idx): 410 | key = next(key for key, value in problem_ents.items() if np.array_equal(value, o_h_idx)) 411 | return key 412 | 413 | noop = np.asarray([0, 0], dtype=int) 414 | 415 | 416 | def air_cargo_generator_v2(num_airport, num_cargo, num_plane, 417 | one_hot_size=None, vec_fn=None): 418 | """Init: 419 | each plane must be at airport 420 | each cargo must be at an airport 421 | each cargo has a destination 422 | 423 | Returned typed state expressions with type and instance of each object 424 | 0-0 indicates 425 | A P C 426 | +---------------+ 427 | | 1-0 2-0 3-0 | 428 | +---------------+ 429 | """ 430 | if one_hot_size is None: 431 | one_hot_size = max([num_airport, num_cargo, num_plane]) + 3 432 | if vec_fn is None: 433 | vec_fn = np.eye 434 | 435 | assert num_airport <= one_hot_size 436 | print(num_airport) 437 | o_h_airprt = np.asarray([[0, i] for i in range(num_airport)], dtype=int) 438 | o_h_planes = np.asarray([[1, i] for i in range(num_plane)], dtype=int) 439 | o_h_cargos = np.asarray([[2, i] for i in range(num_cargo)], dtype=int) 440 | 441 | o_h_state, o_h_goals = [], [] 442 | idx_state, idx_goals = [], [] 443 | airprt_dict, cargos_dict, planes_dict = {}, {}, {} 444 | 445 | for idx, o_h_airpr in enumerate(o_h_airprt): 446 | airprt_dict['A{}'.format(idx)] = o_h_airpr 447 | 448 | # planes can go wherever 449 | for idx, o_h_plane in enumerate(o_h_planes): 450 | # generate random 451 | airprt_idx = randint(0, num_airport - 1) 452 | # make final state vecs and add to states. 453 | airport_vec = o_h_airprt[airprt_idx] 454 | oh_plane_state = np.concatenate([airport_vec, o_h_plane, noop], axis=0) 455 | o_h_state.append(oh_plane_state) 456 | # lookup idxs 457 | pl_ent = 'P{}'.format(idx) 458 | planes_dict[pl_ent] = o_h_plane 459 | # make state dicts and add to states. 460 | ar_ent = reverse_lookup(airprt_dict, airport_vec) 461 | idx_state.append('At({}, {})'.format(pl_ent, ar_ent)) 462 | 463 | # cargos start end end are mutually exclusive 464 | # for porpises of this problem. Erp erp. 465 | for idx, o_h_cargo in enumerate(o_h_cargos): 466 | # generate random 467 | init_idx = random.randint(0, num_airport - 1) 468 | allowed_values = set(range(num_airport)) 469 | # 470 | allowed_values.discard(init_idx) 471 | goal_idx = random.choice(list(allowed_values)) 472 | # make one_hot vecs 473 | air_in_vec = o_h_airprt[init_idx] 474 | air_gl_vec = o_h_airprt[goal_idx] 475 | # make final state vecs and add to states. 476 | oh_init = np.concatenate([air_in_vec, noop, o_h_cargo], axis=0) 477 | oh_goal = np.concatenate([air_gl_vec, noop, o_h_cargo], axis=0) 478 | o_h_state.append(oh_init) 479 | o_h_goals.append(oh_goal) 480 | # lookup idxs 481 | cr_ent = 'C{}'.format(idx) 482 | cargos_dict[cr_ent] = o_h_cargo 483 | # make state dicts and add to states. 484 | idx_state.append('At({}, {})'.format(cr_ent, reverse_lookup(airprt_dict, air_in_vec))) 485 | idx_goals.append('At({}, {})'.format(cr_ent, reverse_lookup(airprt_dict, air_gl_vec))) 486 | 487 | return [[np.asarray(o_h_state), np.asarray(o_h_goals)], 488 | [idx_state, idx_goals], 489 | [airprt_dict, cargos_dict, planes_dict]] 490 | 491 | 492 | def ent_(label, num): 493 | return '{}{}'.format(label, num) 494 | 495 | 496 | def entity_ix_generator(label, num_ents, start=0): 497 | start += 1 498 | ents_to_ix = {ent_(label, idx): i for idx, i in enumerate(range(start, start + num_ents))} 499 | ix_to_ents = {i: ent_(label, idx) for idx, i in enumerate(range(start, start + num_ents))} 500 | return ents_to_ix, ix_to_ents 501 | 502 | 503 | def entity_2ix_generator(label, type_num, num_ents, start=0): 504 | start += 1 505 | ents_to_ix = {ent_(label, idx): (type_num, i) for idx, i in enumerate(range(start, start + num_ents))} 506 | ix_to_ents = {(type_num, i): ent_(label, idx) for idx, i in enumerate(range(start, start + num_ents))} 507 | return ents_to_ix, ix_to_ents 508 | 509 | 510 | def air_cargo_generator_v3(num_airport, num_cargo, num_plane, encoding=1): 511 | """Init: 512 | each plane must be at airport 513 | each cargo must be at an airport 514 | each cargo has a destination 515 | 516 | Returned typed state expressions with type and instance of each object 517 | 0-0 indicates 518 | A P C 519 | +---------------+ 520 | | 1-0 2-0 3-0 | 521 | +---------------+ 522 | """ 523 | if encoding == 1: 524 | airpt_to_ix, ix_to_airpt = entity_ix_generator("A", num_airport) 525 | cargo_to_ix, ix_to_cargo = entity_ix_generator("C", num_cargo, start=num_airport) 526 | plane_to_ix, ix_to_plane = entity_ix_generator("P", num_plane, start=num_cargo + num_airport) 527 | else: 528 | airpt_to_ix, ix_to_airpt = entity_2ix_generator("A", 1, num_airport) 529 | cargo_to_ix, ix_to_cargo = entity_2ix_generator("C", 2, num_cargo) 530 | plane_to_ix, ix_to_plane = entity_2ix_generator("P", 3, num_plane) 531 | 532 | state_exprs, goal_exprs = [], [] 533 | ents_to_ix = {**airpt_to_ix, **cargo_to_ix, **plane_to_ix} 534 | ix_to_ents = {**ix_to_airpt, **ix_to_cargo, **ix_to_plane} 535 | 536 | # planes can go wherever 537 | for idx, plane in ix_to_plane.items(): 538 | # find an airport to put the plane at 539 | airprt_idx = random.choice(list(airpt_to_ix.keys())) 540 | state_exprs.append('At({}, {})'.format(plane, airprt_idx)) 541 | 542 | # cargos start end end are mutually exclusive 543 | # for porpises of this problem. Erp erp. 544 | for idx, cargo in ix_to_cargo.items(): 545 | 546 | # generate random 547 | allowed_values = set(list(airpt_to_ix.keys())) 548 | init_idx = random.choice(list(allowed_values)) 549 | 550 | # set a goal airport 551 | allowed_values.discard(init_idx) 552 | goal_idx = random.choice(list(allowed_values)) 553 | 554 | # make state dicts and add to states. 555 | state_exprs.append('At({}, {})'.format(cargo, init_idx)) 556 | goal_exprs.append('At({}, {})'.format(cargo, goal_idx)) 557 | 558 | return [[ents_to_ix, ix_to_ents], 559 | [state_exprs, goal_exprs]] 560 | 561 | 562 | 563 | def arbitrary_ACP(n_airport, n_cargo, n_plane, one_hot_size=None): 564 | """ 565 | Generate ACP of arbitrary specified size in a roundabout way. 566 | Problem Object goes to engine, 567 | one_hot vecs go to dnc for training. 568 | """ 569 | # create one hot and entity and state dictionaries 570 | o_h, k_b, ent_dic = air_cargo_generator_v2(n_airport, n_cargo, n_plane, 571 | one_hot_size=one_hot_size) 572 | 573 | airprt_dict, cargos_dict, planes_dict = ent_dic 574 | mp_merged = {**airprt_dict, **cargos_dict, **planes_dict} 575 | 576 | airports = list(airprt_dict.keys()) 577 | cargos = list(cargos_dict.keys()) 578 | planes = list(planes_dict.keys()) 579 | 580 | init_exprs, goal_exprs = k_b 581 | o_h_state, o_h_goals = o_h 582 | 583 | pos = [expr(x) for x in init_exprs] 584 | neg = posneg_helper(cargos, planes, airports, pos) 585 | init = FluentState(pos, neg) 586 | 587 | goal = [expr(x) for x in goal_exprs] 588 | acp = AirCargoProblem(cargos, planes, airports, init, goal) 589 | return (acp, o_h_state, o_h_goals, mp_merged) 590 | 591 | 592 | def arbitrary_ACP2(n_airport, n_cargo, n_plane, encoding=1): 593 | """ 594 | Generate ACP of arbitrary specified size in a roundabout way. 595 | Problem Object goes to engine, 596 | one_hot vecs go to dnc for training. 597 | """ 598 | # create one hot and entity and state dictionaries 599 | (e_ix, ix_e), (state, goals) = air_cargo_generator_v3(n_airport, n_cargo, n_plane, encoding) 600 | 601 | airports = [e for e in e_ix.keys() if 'A' in e] 602 | cargos = [e for e in e_ix.keys() if 'C' in e] 603 | planes = [e for e in e_ix.keys() if 'P' in e] 604 | 605 | pos = [expr(x) for x in state] 606 | neg = posneg_helper(cargos, planes, airports, pos) 607 | init = FluentState(pos, neg) 608 | 609 | goal = [expr(x) for x in goals] 610 | acp = AirCargoProblem(cargos, planes, airports, init, goal) 611 | return acp, (e_ix, ix_e), (pos, goal) -------------------------------------------------------------------------------- /problem/my_planning_graph.py: -------------------------------------------------------------------------------- 1 | from .planning import Action 2 | from .search import Problem 3 | from .utils import expr 4 | from .lp_utils import decode_state 5 | #import svgwrite 6 | 7 | class PgNode(): 8 | """Base class for planning graph nodes. 9 | 10 | includes instance sets common to both types of nodes used in a planning graph 11 | parents: the set of nodes in the previous level 12 | children: the set of nodes in the subsequent level 13 | mutex: the set of sibling nodes that are mutually exclusive with this node 14 | """ 15 | 16 | def __init__(self): 17 | self.parents = set() 18 | self.children = set() 19 | self.mutex = set() 20 | 21 | def is_mutex(self, other) -> bool: 22 | """Boolean test for mutual exclusion 23 | 24 | :param other: PgNode 25 | the other node to compare with 26 | :return: bool 27 | True if this node and the other are marked mutually exclusive (mutex) 28 | """ 29 | if other in self.mutex: 30 | return True 31 | return False 32 | 33 | def show(self): 34 | """helper print for debugging shows counts of parents, children, siblings 35 | 36 | :return: 37 | print only 38 | """ 39 | print("{} parents".format(len(self.parents))) 40 | print("{} children".format(len(self.children))) 41 | print("{} mutex".format(len(self.mutex))) 42 | 43 | 44 | class PgNode_s(PgNode): 45 | """A planning graph node representing a state (literal fluent) from a 46 | planning problem. 47 | 48 | Args: 49 | ---------- 50 | symbol : str 51 | A string representing a literal expression from a planning problem 52 | domain. 53 | 54 | is_pos : bool 55 | Boolean flag indicating whether the literal expression is positive or 56 | negative. 57 | """ 58 | 59 | def __init__(self, symbol: str, is_pos: bool): 60 | """S-level Planning Graph node constructor 61 | 62 | :param symbol: expr 63 | :param is_pos: bool 64 | Instance variables calculated: 65 | literal: expr 66 | fluent in its literal form including negative operator if applicable 67 | Instance variables inherited from PgNode: 68 | parents: set of nodes connected to this node in previous A level; initially empty 69 | children: set of nodes connected to this node in next A level; initially empty 70 | mutex: set of sibling S-nodes that this node has mutual exclusion with; initially empty 71 | """ 72 | PgNode.__init__(self) 73 | self.symbol = symbol 74 | self.is_pos = is_pos 75 | self.__hash = None 76 | 77 | def show(self): 78 | """helper print for debugging shows literal plus counts of parents, 79 | children, siblings 80 | 81 | :return: 82 | print only 83 | """ 84 | if self.is_pos: 85 | print("\n*** {}".format(self.symbol)) 86 | else: 87 | print("\n*** ~{}".format(self.symbol)) 88 | PgNode.show(self) 89 | 90 | def __eq__(self, other): 91 | """equality test for nodes - compares only the literal for equality 92 | 93 | :param other: PgNode_s 94 | :return: bool 95 | """ 96 | return (isinstance(other, self.__class__) and 97 | self.is_pos == other.is_pos and 98 | self.symbol == other.symbol) 99 | 100 | def __hash__(self): 101 | self.__hash = self.__hash or hash(self.symbol) ^ hash(self.is_pos) 102 | return self.__hash 103 | 104 | 105 | class PgNode_a(PgNode): 106 | """A-type (action) Planning Graph node - inherited from PgNode """ 107 | 108 | 109 | def __init__(self, action: Action): 110 | """A-level Planning Graph node constructor 111 | 112 | :param action: Action 113 | a ground action, i.e. this action cannot contain any variables 114 | Instance variables calculated: 115 | An A-level will always have an S-level as its parent and an S-level as its child. 116 | The preconditions and effects will become the parents and children of the A-level node 117 | However, when this node is created, it is not yet connected to the graph 118 | prenodes: set of *possible* parent S-nodes 119 | effnodes: set of *possible* child S-nodes 120 | is_persistent: bool True if this is a persistence action, i.e. a no-op action 121 | Instance variables inherited from PgNode: 122 | parents: set of nodes connected to this node in previous S level; initially empty 123 | children: set of nodes connected to this node in next S level; initially empty 124 | mutex: set of sibling A-nodes that this node has mutual exclusion with; initially empty 125 | """ 126 | PgNode.__init__(self) 127 | self.action = action 128 | self.prenodes = self.precond_s_nodes() 129 | self.effnodes = self.effect_s_nodes() 130 | self.is_persistent = self.prenodes == self.effnodes 131 | self.__hash = None 132 | 133 | def show(self): 134 | """helper print for debugging shows action plus counts of parents, children, siblings 135 | 136 | :return: 137 | print only 138 | """ 139 | print("\n*** {!s}".format(self.action)) 140 | PgNode.show(self) 141 | 142 | def precond_s_nodes(self): 143 | """precondition literals as S-nodes (represents possible parents for this node). 144 | It is computationally expensive to call this function; it is only called by the 145 | class constructor to populate the `prenodes` attribute. 146 | 147 | :return: set of PgNode_s 148 | """ 149 | nodes = set() 150 | for p in self.action.precond_pos: 151 | nodes.add(PgNode_s(p, True)) 152 | for p in self.action.precond_neg: 153 | nodes.add(PgNode_s(p, False)) 154 | return nodes 155 | 156 | def effect_s_nodes(self): 157 | """effect literals as S-nodes (represents possible children for this node). 158 | It is computationally expensive to call this function; it is only called by the 159 | class constructor to populate the `effnodes` attribute. 160 | 161 | :return: set of PgNode_s 162 | """ 163 | nodes = set() 164 | for e in self.action.effect_add: 165 | nodes.add(PgNode_s(e, True)) 166 | for e in self.action.effect_rem: 167 | nodes.add(PgNode_s(e, False)) 168 | return nodes 169 | 170 | def __eq__(self, other): 171 | """equality test for nodes - compares only the action name for equality 172 | 173 | :param other: PgNode_a 174 | :return: bool 175 | """ 176 | return (isinstance(other, self.__class__) and 177 | self.is_persistent == other.is_persistent and 178 | self.action.name == other.action.name and 179 | self.action.args == other.action.args) 180 | 181 | def __hash__(self): 182 | self.__hash = self.__hash or hash(self.action.name) ^ hash(self.action.args) 183 | return self.__hash 184 | 185 | 186 | 187 | 188 | def mutexify(node1: PgNode, node2: PgNode): 189 | """ adds sibling nodes to each other's mutual exclusion (mutex) set. These should be sibling nodes! 190 | 191 | :param node1: PgNode (or inherited PgNode_a, PgNode_s types) 192 | :param node2: PgNode (or inherited PgNode_a, PgNode_s types) 193 | :return: 194 | node mutex sets modified 195 | """ 196 | if type(node1) != type(node2): 197 | raise TypeError('Attempted to mutex two nodes of different types') 198 | node1.mutex.add(node2) 199 | node2.mutex.add(node1) 200 | 201 | def crude_inspect(obj): 202 | for prop, value in vars(obj).iteritems(): 203 | print(prop) 204 | print(value) 205 | 206 | 207 | class PlanningGraph(): 208 | """ 209 | A planning graph as described in chapter 10 of the AIMA text. The planning 210 | graph can be used to reason about 211 | """ 212 | 213 | def __init__(self, problem: Problem, state: str, serial_planning=True): 214 | """ 215 | :param problem: PlanningProblem (or subclass such as AirCargoProblem or HaveCakeProblem) 216 | :param state: str (will be in form TFTTFF... representing fluent states) 217 | :param serial_planning: bool (whether or not to assume that only one 218 | action can occur at a time) 219 | Instance variable calculated: 220 | fs: FluentState 221 | the state represented as positive and negative fluent literal lists 222 | all_actions: list of the PlanningProblem valid ground actions 223 | combined with calculated no-op actions 224 | s_levels: list of sets of PgNode_s, where each set in the list 225 | represents an S-level in the planning graph 226 | a_levels: list of sets of PgNode_a, where each set in the list 227 | represents an A-level in the planning graph 228 | """ 229 | self.problem = problem 230 | self.fs = decode_state(state, problem.state_map) 231 | self.serial = serial_planning 232 | self.all_actions = self.problem.actions_list + self.noop_actions(self.problem.state_map) 233 | self.s_levels = [] 234 | self.a_levels = [] 235 | self.create_graph() 236 | 237 | def show_all(self): 238 | print("\n------problem-----") 239 | print(self.problem) 240 | print("\n------FluenStates------") 241 | print("------ pos------") 242 | print(self.fs.pos) 243 | print("------ neg------") 244 | print(self.fs.neg) 245 | 246 | print("\n------S_levels------") 247 | for sl in self.s_levels: 248 | for s in sl: 249 | s.show() 250 | 251 | print("\n------A_levels------") 252 | for a in self.a_levels: 253 | print(a) 254 | 255 | print("\n------actions------") 256 | for action in self.all_actions: 257 | print(action) 258 | print("-------------------------------------------------------") 259 | 260 | def noop_actions(self, literal_list): 261 | """create persistent action for each possible fluent 262 | 263 | "No-Op" actions are virtual actions (i.e., actions that only exist in 264 | the planning graph, not in the planning problem domain) that operate 265 | on each fluent (literal expression) from the problem domain. No op 266 | actions "pass through" the literal expressions from one level of the 267 | planning graph to the next. 268 | 269 | The no-op action list requires both a positive and a negative action 270 | for each literal expression. Positive no-op actions require the literal 271 | as a positive precondition and add the literal expression as an effect 272 | in the output, and negative no-op actions require the literal as a 273 | negative precondition and remove the literal expression as an effect in 274 | the output. 275 | 276 | This function should only be called by the class constructor. 277 | 278 | :param literal_list: 279 | :return: list of Action 280 | """ 281 | action_list = [] 282 | for fluent in literal_list: 283 | act1 = Action(expr("Noop_pos({})".format(fluent)), ([fluent], []), ([fluent], [])) 284 | action_list.append(act1) 285 | act2 = Action(expr("Noop_neg({})".format(fluent)), ([], [fluent]), ([], [fluent])) 286 | action_list.append(act2) 287 | return action_list 288 | 289 | def create_graph(self): 290 | """ build a Planning Graph as described in Russell-Norvig 3rd Ed 10.3 or 2nd Ed 11.4 291 | 292 | The S0 initial level has been implemented for you. It has no parents and includes all of 293 | the literal fluents that are part of the initial state passed to the constructor. At the 294 | start of a problem planning search, this will be the same as the initial state of the prob. 295 | However, the planning graph can be built from any state in the Planning Problem 296 | 297 | This function should only be called by the class constructor. 298 | 299 | :return: 300 | builds the graph by filling s_levels[] and a_levels[] 301 | lists with node sets for each level 302 | """ 303 | # the graph should only be built during class construction 304 | if (len(self.s_levels) != 0) or (len(self.a_levels) != 0): 305 | raise Exception( 306 | 'Planning Graph already created; construct a new planning graph for each new state in the planning sequence') 307 | 308 | # initialize S0 to literals in initial state provided. 309 | leveled = False 310 | level = 0 311 | self.s_levels.append(set()) # S0 set of s_nodes - empty to start 312 | # for each fluent in the initial state, add the correct literal PgNode_s 313 | for literal in self.fs.pos: 314 | self.s_levels[level].add(PgNode_s(literal, True)) 315 | for literal in self.fs.neg: 316 | self.s_levels[level].add(PgNode_s(literal, False)) 317 | 318 | # continue to build the graph alternating A, S levels until 319 | # last two S levels contain the same literals, i.e. until it is "leveled" 320 | while not leveled: 321 | self.add_action_level(level) 322 | self.update_a_mutex(self.a_levels[level]) 323 | 324 | level += 1 325 | self.add_literal_level(level) 326 | self.update_s_mutex(self.s_levels[level]) 327 | 328 | if self.s_levels[level] == self.s_levels[level - 1]: 329 | leveled = True 330 | 331 | def get_levels(self): 332 | return [[s.symbol for s in s_level] for s_level in self.s_levels] 333 | 334 | def add_action_level(self, level): 335 | """ add an A (action) level to the Planning Graph 336 | 337 | :param level: int 338 | the level number alternates S0, A0, S1, A1, S2, .... etc 339 | the level number is also used as the 340 | index for the node set lists self.a_levels[] and self.s_levels[] 341 | :return: 342 | adds A nodes to the current level in self.a_levels[level] 343 | """ 344 | self.a_levels.append(set()) 345 | for s_literal in self.s_levels[level]: 346 | for action in self.all_actions: 347 | node_a = PgNode_a(action) 348 | if s_literal in node_a.prenodes: 349 | node_a.parents.add(s_literal) 350 | s_literal.children.add(node_a) 351 | self.a_levels[level].add(node_a) 352 | 353 | def add_literal_level(self, level): 354 | """ add an S (literal) level to the Planning Graph 355 | 356 | :param level: int 357 | the level number alternates S0, A0, S1, A1, S2, .... etc the level number is also used as the 358 | index for the node set lists self.a_levels[] and self.s_levels[] 359 | :return: 360 | adds S nodes to the current level in self.s_levels[level] 361 | """ 362 | self.s_levels.append(set()) 363 | for a_node in self.a_levels[level-1]: 364 | for eff_literal in a_node.effnodes: 365 | a_node.children.add(eff_literal) 366 | eff_literal.parents.add(a_node) 367 | self.s_levels[level].add(eff_literal) 368 | 369 | def update_a_mutex(self, nodeset): 370 | """ Determine and update sibling mutual exclusion for A-level nodes 371 | 372 | Mutex action tests section from 3rd Ed. 10.3 or 2nd Ed. 11.4 373 | A mutex relation holds between two actions a given level 374 | if the planning graph is a serial planning graph and the pair are nonpersistence actions 375 | or if any of the three conditions hold between the pair: 376 | Inconsistent Effects 377 | Interference 378 | Competing needs 379 | 380 | :param nodeset: set of PgNode_a (siblings in the same level) 381 | :return: 382 | mutex set in each PgNode_a in the set is appropriately updated 383 | """ 384 | nodelist = list(nodeset) 385 | for i, n1 in enumerate(nodelist[:-1]): 386 | for n2 in nodelist[i + 1:]: 387 | if (self.serialize_actions(n1, n2) or 388 | self.inconsistent_effects_mutex(n1, n2) or 389 | self.interference_mutex(n1, n2) or 390 | self.competing_needs_mutex(n1, n2)): 391 | mutexify(n1, n2) 392 | 393 | def serialize_actions(self, node_a1: PgNode_a, node_a2: PgNode_a) -> bool: 394 | """ 395 | Test a pair of actions for mutual exclusion, returning True if the 396 | planning graph is serial, and if either action is persistent; otherwise 397 | return False. Two serial actions are mutually exclusive if they are 398 | both non-persistent. 399 | 400 | :param node_a1: PgNode_a 401 | :param node_a2: PgNode_a 402 | :return: bool 403 | """ 404 | if not self.serial: 405 | return False 406 | if node_a1.is_persistent or node_a2.is_persistent: 407 | return False 408 | return True 409 | 410 | def inconsistent_effects_mutex(self, node_a1: PgNode_a, node_a2: PgNode_a) -> bool: 411 | """ 412 | Test a pair of actions for inconsistent effects, returning True if 413 | one action negates an effect of the other, and False otherwise. 414 | 415 | HINT: The Action instance associated with an action node is accessible 416 | through the PgNode_a.action attribute. See the Action class 417 | documentation for details on accessing the effects and preconditions of 418 | an action. 419 | 420 | :param node_a1: PgNode_a 421 | :param node_a2: PgNode_a 422 | :return: bool 423 | """ 424 | a1 = node_a1.action 425 | a2 = node_a2.action 426 | 427 | if (any(set(a1.effect_add).intersection(a2.effect_rem)) or 428 | any(set(a2.effect_add).intersection(a1.effect_rem)) or 429 | any(set(a1.effect_rem).intersection(a2.effect_add)) or 430 | any(set(a2.effect_rem).intersection(a1.effect_add))): 431 | return True 432 | return False 433 | 434 | def interference_mutex(self, node_a1: PgNode_a, node_a2: PgNode_a) -> bool: 435 | """ 436 | Test a pair of actions for mutual exclusion, returning True if the 437 | effect of one action is the negation of a precondition of the other. 438 | 439 | HINT: The Action instance associated with an action node is accessible 440 | through the PgNode_a.action attribute. See the Action class 441 | documentation for details on accessing the effects and preconditions of 442 | an action. 443 | 444 | :param node_a1: PgNode_a 445 | :param node_a2: PgNode_a 446 | :return: bool 447 | """ 448 | a1 = node_a1.action 449 | a2 = node_a2.action 450 | 451 | if (any(set(a1.effect_add).intersection(a2.precond_neg)) or 452 | any(set(a2.effect_add).intersection(a1.precond_neg)) or 453 | any(set(a1.effect_rem).intersection(a2.precond_pos)) or 454 | any(set(a2.effect_rem).intersection(a1.precond_pos))): 455 | return True 456 | return False 457 | 458 | def competing_needs_mutex(self, node_a1: PgNode_a, node_a2: PgNode_a) -> bool: 459 | """ 460 | Test a pair of actions for mutual exclusion, returning True if one of 461 | the precondition of one action is mutex with a precondition of the 462 | other action. 463 | 464 | :param node_a1: PgNode_a 465 | :param node_a2: PgNode_a 466 | :return: bool 467 | """ 468 | for pre_a1 in node_a1.parents: 469 | for pre_a2 in node_a2.parents: 470 | if pre_a1.is_mutex(pre_a2): 471 | return True 472 | return False 473 | 474 | def update_s_mutex(self, nodeset: set): 475 | """ Determine and update sibling mutual exclusion for S-level nodes 476 | 477 | Mutex action tests section from 3rd Ed. 10.3 or 2nd Ed. 11.4 478 | A mutex relation holds between literals at a given level 479 | if either of the two conditions hold between the pair: 480 | Negation 481 | Inconsistent support 482 | 483 | :param nodeset: set of PgNode_a (siblings in the same level) 484 | :return: 485 | mutex set in each PgNode_a in the set is appropriately updated 486 | """ 487 | nodelist = list(nodeset) 488 | for i, n1 in enumerate(nodelist[:-1]): 489 | for n2 in nodelist[i + 1:]: 490 | if self.negation_mutex(n1, n2) or self.inconsistent_support_mutex(n1, n2): 491 | mutexify(n1, n2) 492 | 493 | def negation_mutex(self, node_s1: PgNode_s, node_s2: PgNode_s) -> bool: 494 | """ 495 | Test a pair of state literals for mutual exclusion, returning True if 496 | one node is the negation of the other, and False otherwise. 497 | 498 | HINT: Look at the PgNode_s.__eq__ defines the notion of equivalence for 499 | literal expression nodes, and the class tracks whether the literal is 500 | positive or negative. 501 | 502 | :param node_s1: PgNode_s 503 | :param node_s2: PgNode_s 504 | :return: bool 505 | """ 506 | same_symbol = node_s1.symbol == node_s2.symbol 507 | not_negated = node_s1.is_pos == node_s1.is_pos 508 | return same_symbol and not_negated 509 | 510 | def inconsistent_support_mutex(self, node_s1: PgNode_s, node_s2: PgNode_s): 511 | """ 512 | Test a pair of state literals for mutual exclusion, returning True if 513 | there are no actions that could achieve the two literals at the same 514 | time, and False otherwise. 515 | In other words, the two literal nodes are 516 | mutex if all of the actions that could achieve the first literal node 517 | are pairwise mutually exclusive with all of the actions that could 518 | achieve the second literal node. 519 | 520 | HINT: The PgNode.is_mutex method can be used to test whether two nodes 521 | are mutually exclusive. 522 | 523 | :param node_s1: PgNode_s 524 | :param node_s2: PgNode_s 525 | :return: bool 526 | """ 527 | if not any(node_s1.parents.intersection(node_s2.parents)): 528 | for pre_a1 in node_s1.parents: 529 | if all((pre_a2.is_mutex(pre_a1) for pre_a2 in node_s2.parents)): 530 | return True 531 | 532 | return False 533 | 534 | def h_levelsum(self) -> int: 535 | """The sum of the level costs of the individual goals 536 | (admissible if goals independent) 537 | for each goal in the problem, determine the level cost, 538 | then add them together 539 | :return: int 540 | """ 541 | level_sum = 0 542 | for goal in self.problem.goal: 543 | for idx, level in enumerate(self.s_levels): 544 | if any(((s.symbol == goal and s.is_pos is True) for s in level)): 545 | level_sum += idx 546 | break 547 | return level_sum 548 | 549 | 550 | 551 | def visualize_plan(self): 552 | """ 553 | {symbol: {h_top }} 554 | 555 | :return: 556 | """ 557 | outpath = 'graph.svg' 558 | h = 500 559 | w = 1000 560 | def swap_rows(r1, r2): 561 | pass 562 | 563 | def add_node(s_node): 564 | pass 565 | 566 | num_levels = len(self.s_levels) 567 | #max nodes are always at lest level since noops 568 | #cause any possible state to remain so total monotonically increases 569 | max_nodes = len(self.s_levels[-1]) 570 | 571 | all_levels = zip(self.s_levels, self.a_levels) 572 | 573 | level_h = h // max_nodes 574 | level_w = w // num_levels 575 | 576 | bx_h = level_h // 2 577 | bx_w = level_w // 2 578 | 579 | # dwg = svgwrite.Drawing(h, w) 580 | 581 | sv_struc = {} 582 | 583 | for idx, (s_level, a_level) in enumerate(self.s_levels): 584 | num_elements = len(s_level) 585 | 586 | #start with last level of graph: 587 | for idx, s_level in enumerate(reversed(self.s_levels)): 588 | num_elements = len(s_level) 589 | for s_node in s_level: 590 | if s_node.symbol in sv_struc: 591 | if s_node.is_pos in sv_struc[s_node.symbol]: 592 | sv_node_h = sv_struc[s_node.symbol][s_node.is_pos] 593 | # create rectangle at x = level_idx * 594 | # svgwrite.shapes.Rect(insert=( ,idx* ), size=(bx_w, bx_h)) 595 | 596 | pass 597 | 598 | 599 | 600 | 601 | 602 | 603 | 604 | 605 | 606 | 607 | 608 | 609 | -------------------------------------------------------------------------------- /problem/planning.py: -------------------------------------------------------------------------------- 1 | """Planning (Chapters 10-11) 2 | """ 3 | 4 | from .utils import Expr 5 | 6 | 7 | class Action: 8 | """ 9 | Defines an action schema using preconditions and effects 10 | Use this to describe actions in PDDL 11 | action is an Expr where variables are given as arguments(args) 12 | Precondition and effect are both lists with positive and negated literals 13 | Example: 14 | precond_pos = [expr("Human(person)"), expr("Hungry(Person)")] 15 | precond_neg = [expr("Eaten(food)")] 16 | effect_add = [expr("Eaten(food)")] 17 | effect_rem = [expr("Hungry(person)")] 18 | eat = Action(expr("Eat(person, food)"), [precond_pos, precond_neg], [effect_add, effect_rem]) 19 | """ 20 | 21 | def __init__(self, action, precond, effect): 22 | self.name = action.op 23 | self.args = action.args 24 | self.precond_pos = precond[0] 25 | self.precond_neg = precond[1] 26 | self.effect_add = effect[0] 27 | self.effect_rem = effect[1] 28 | 29 | def __call__(self, kb, args): 30 | return self.act(kb, args) 31 | 32 | def __str__(self): 33 | return "{}{!s}".format(self.name, self.args) 34 | 35 | def substitute(self, e, args): 36 | """Replaces variables in expression with their 37 | respective Propostional symbol""" 38 | new_args = list(e.args) 39 | for num, x in enumerate(e.args): 40 | for i in range(len(self.args)): 41 | if self.args[i] == x: 42 | new_args[num] = args[i] 43 | return Expr(e.op, *new_args) 44 | 45 | def check_precond(self, kb, args): 46 | """Checks if the precondition is satisfied in the current state""" 47 | # check for positive clauses 48 | for clause in self.precond_pos: 49 | if self.substitute(clause, args) not in kb.clauses: 50 | return False 51 | # check for negative clauses 52 | for clause in self.precond_neg: 53 | if self.substitute(clause, args) in kb.clauses: 54 | return False 55 | return True 56 | 57 | def act(self, kb, args): 58 | """Executes the action on the state's kb""" 59 | # check if the preconditions are satisfied 60 | if not self.check_precond(kb, args): 61 | raise Exception("Action pre-conditions not satisfied") 62 | # remove negative literals 63 | for clause in self.effect_rem: 64 | kb.retract(self.substitute(clause, args)) 65 | # add positive literals 66 | for clause in self.effect_add: 67 | kb.tell(self.substitute(clause, args)) 68 | -------------------------------------------------------------------------------- /problem/search.py: -------------------------------------------------------------------------------- 1 | """Search (Chapters 3-4) 2 | 3 | The way to use this code is to subclass Problem to create a class of problems, 4 | then create problem instances and solve them with calls to the various search 5 | functions.""" 6 | 7 | from .utils import ( 8 | is_in, memoize, print_table, Stack, FIFOQueue, PriorityQueue, name 9 | ) 10 | 11 | import sys 12 | 13 | infinity = float('inf') 14 | 15 | # ______________________________________________________________________________ 16 | 17 | 18 | class Problem: 19 | 20 | """The abstract class for a formal problem. You should subclass 21 | this and implement the methods actions and result, and possibly 22 | __init__, goal_test, and path_cost. Then you will create instances 23 | of your subclass and solve them with the various search functions.""" 24 | 25 | def __init__(self, initial, goal=None): 26 | """The constructor specifies the initial state, and possibly a goal 27 | state, if there is a unique goal. Your subclass's constructor can add 28 | other arguments.""" 29 | self.initial = initial 30 | self.goal = goal 31 | 32 | def actions(self, state): 33 | """Return the actions that can be executed in the given 34 | state. The result would typically be a list, but if there are 35 | many actions, consider yielding them one at a time in an 36 | iterator, rather than building them all at once.""" 37 | raise NotImplementedError 38 | 39 | def result(self, state, action): 40 | """Return the state that results from executing the given 41 | action in the given state. The action must be one of 42 | self.actions(state).""" 43 | raise NotImplementedError 44 | 45 | def goal_test(self, state): 46 | """Return True if the state is a goal. The default method compares the 47 | state to self.goal or checks for state in self.goal if it is a 48 | list, as specified in the constructor. Override this method if 49 | checking against a single self.goal is not enough.""" 50 | if isinstance(self.goal, list): 51 | return is_in(state, self.goal) 52 | else: 53 | return state == self.goal 54 | 55 | def path_cost(self, c, state1, action, state2): 56 | """Return the cost of a solution path that arrives at state2 from 57 | state1 via action, assuming cost c to get up to state1. If the problem 58 | is such that the path doesn't matter, this function will only look at 59 | state2. If the path does matter, it will consider c and maybe state1 60 | and action. The default method costs 1 for every step in the path.""" 61 | return c + 1 62 | 63 | def value(self, state): 64 | """For optimization problems, each state has a value. Hill-climbing 65 | and related algorithms try to maximize this value.""" 66 | raise NotImplementedError 67 | # ______________________________________________________________________________ 68 | 69 | 70 | class Node: 71 | 72 | """A node in a search tree. Contains a pointer to the parent (the node 73 | that this is a successor of) and to the actual state for this node. Note 74 | that if a state is arrived at by two paths, then there are two nodes with 75 | the same state. Also includes the action that got us to this state, and 76 | the total path_cost (also known as g) to reach the node. Other functions 77 | may add an f and h value; see best_first_graph_search and astar_search for 78 | an explanation of how the f and h values are handled. You will not need to 79 | subclass this class.""" 80 | 81 | def __init__(self, state, parent=None, action=None, path_cost=0): 82 | "Create a search tree Node, derived from a parent by an action." 83 | self.state = state 84 | self.parent = parent 85 | self.action = action 86 | self.path_cost = path_cost 87 | self.depth = 0 88 | if parent: 89 | self.depth = parent.depth + 1 90 | 91 | def __repr__(self): 92 | return "" % (self.state,) 93 | 94 | def __lt__(self, node): 95 | return self.state < node.state 96 | 97 | def expand(self, problem): 98 | "List the nodes reachable in one step from this node." 99 | return [self.child_node(problem, action) 100 | for action in problem.actions(self.state)] 101 | 102 | def child_node(self, problem, action): 103 | "[Figure 3.10]" 104 | next = problem.result(self.state, action) 105 | return Node(next, self, action, 106 | problem.path_cost(self.path_cost, self.state, 107 | action, next)) 108 | 109 | def solution(self): 110 | "Return the sequence of actions to go from the root to this node." 111 | return [node.action for node in self.path()[1:]] 112 | 113 | def path(self): 114 | "Return a list of nodes forming the path from the root to this node." 115 | node, path_back = self, [] 116 | while node: 117 | path_back.append(node) 118 | node = node.parent 119 | return list(reversed(path_back)) 120 | 121 | # We want for a queue of nodes in breadth_first_search or 122 | # astar_search to have no duplicated states, so we treat nodes 123 | # with the same state as equal. [Problem: this may not be what you 124 | # want in other contexts.] 125 | 126 | def __eq__(self, other): 127 | return isinstance(other, Node) and self.state == other.state 128 | 129 | def __hash__(self): 130 | return hash(self.state) 131 | 132 | # ______________________________________________________________________________ 133 | # Uninformed Search algorithms 134 | 135 | 136 | def tree_search(problem, frontier): 137 | """Search through the successors of a problem to find a goal. 138 | The argument frontier should be an empty queue. 139 | Don't worry about repeated paths to a state. [Figure 3.7]""" 140 | frontier.append(Node(problem.initial)) 141 | while frontier: 142 | node = frontier.pop() 143 | if problem.goal_test(node.state): 144 | return node 145 | frontier.extend(node.expand(problem)) 146 | return None 147 | 148 | 149 | def graph_search(problem, frontier): 150 | """Search through the successors of a problem to find a goal. 151 | The argument frontier should be an empty queue. 152 | If two paths reach a state, only use the first one. [Figure 3.7]""" 153 | frontier.append(Node(problem.initial)) 154 | explored = set() 155 | while frontier: 156 | node = frontier.pop() 157 | if problem.goal_test(node.state): 158 | return node 159 | explored.add(node.state) 160 | frontier.extend(child for child in node.expand(problem) 161 | if child.state not in explored and 162 | child not in frontier) 163 | return None 164 | 165 | 166 | def breadth_first_tree_search(problem): 167 | "Search the shallowest nodes in the search tree first." 168 | return tree_search(problem, FIFOQueue()) 169 | 170 | 171 | def depth_first_tree_search(problem): 172 | "Search the deepest nodes in the search tree first." 173 | return tree_search(problem, Stack()) 174 | 175 | 176 | def depth_first_graph_search(problem): 177 | "Search the deepest nodes in the search tree first." 178 | return graph_search(problem, Stack()) 179 | 180 | 181 | def breadth_first_search(problem): 182 | "[Figure 3.11]" 183 | node = Node(problem.initial) 184 | if problem.goal_test(node.state): 185 | return node 186 | frontier = FIFOQueue() 187 | frontier.append(node) 188 | explored = set() 189 | while frontier: 190 | node = frontier.pop() 191 | explored.add(node.state) 192 | for child in node.expand(problem): 193 | if child.state not in explored and child not in frontier: 194 | if problem.goal_test(child.state): 195 | return child 196 | frontier.append(child) 197 | return None 198 | 199 | 200 | def best_first_graph_search(problem, f): 201 | """Search the nodes with the lowest f scores first. 202 | You specify the function f(node) that you want to minimize; for example, 203 | if f is a heuristic estimate to the goal, then we have greedy best 204 | first search; if f is node.depth then we have breadth-first search. 205 | There is a subtlety: the line "f = memoize(f, 'f')" means that the f 206 | values will be cached on the nodes as they are computed. So after doing 207 | a best first search you can examine the f values of the path returned.""" 208 | f = memoize(f, 'f') 209 | node = Node(problem.initial) 210 | if problem.goal_test(node.state): 211 | return node 212 | frontier = PriorityQueue(min, f) 213 | frontier.append(node) 214 | explored = set() 215 | while frontier: 216 | node = frontier.pop() 217 | if problem.goal_test(node.state): 218 | return node 219 | explored.add(node.state) 220 | for child in node.expand(problem): 221 | if child.state not in explored and child not in frontier: 222 | frontier.append(child) 223 | elif child in frontier: 224 | incumbent = frontier[child] 225 | if f(child) < f(incumbent): 226 | # del frontier[incumbent] 227 | frontier.append(child) 228 | return None 229 | 230 | 231 | def uniform_cost_search(problem): 232 | "[Figure 3.14]" 233 | return best_first_graph_search(problem, lambda node: node.path_cost) 234 | 235 | 236 | def depth_limited_search(problem, limit=50): 237 | "[Figure 3.17]" 238 | def recursive_dls(node, problem, limit): 239 | if problem.goal_test(node.state): 240 | return node 241 | elif limit == 0: 242 | return 'cutoff' 243 | else: 244 | cutoff_occurred = False 245 | for child in node.expand(problem): 246 | result = recursive_dls(child, problem, limit - 1) 247 | if result == 'cutoff': 248 | cutoff_occurred = True 249 | elif result is not None: 250 | return result 251 | return 'cutoff' if cutoff_occurred else None 252 | 253 | # Body of depth_limited_search: 254 | return recursive_dls(Node(problem.initial), problem, limit) 255 | 256 | 257 | def iterative_deepening_search(problem): 258 | "[Figure 3.18]" 259 | for depth in range(sys.maxsize): 260 | result = depth_limited_search(problem, depth) 261 | if result != 'cutoff': 262 | return result 263 | 264 | # ______________________________________________________________________________ 265 | # Informed (Heuristic) Search 266 | 267 | greedy_best_first_graph_search = best_first_graph_search 268 | # Greedy best-first search is accomplished by specifying f(n) = h(n). 269 | 270 | 271 | def astar_search(problem, h=None): 272 | """A* search is best-first graph search with f(n) = g(n)+h(n). 273 | You need to specify the h function when you call astar_search, or 274 | else in your Problem subclass.""" 275 | h = memoize(h or problem.h, 'h') 276 | return best_first_graph_search(problem, lambda n: n.path_cost + h(n)) 277 | 278 | # ______________________________________________________________________________ 279 | # Other search algorithms 280 | 281 | 282 | def recursive_best_first_search(problem, h=None): 283 | "[Figure 3.26]" 284 | h = memoize(h or problem.h, 'h') 285 | 286 | def RBFS(problem, node, flimit): 287 | if problem.goal_test(node.state): 288 | return node, 0 # (The second value is immaterial) 289 | successors = node.expand(problem) 290 | if len(successors) == 0: 291 | return None, infinity 292 | for s in successors: 293 | s.f = max(s.path_cost + h(s), node.f) 294 | while True: 295 | # Order by lowest f value 296 | successors.sort(key=lambda x: x.f) 297 | best = successors[0] 298 | if best.f > flimit: 299 | return None, best.f 300 | if len(successors) > 1: 301 | alternative = successors[1].f 302 | else: 303 | alternative = infinity 304 | result, best.f = RBFS(problem, best, min(flimit, alternative)) 305 | if result is not None: 306 | return result, best.f 307 | 308 | node = Node(problem.initial) 309 | node.f = h(node) 310 | result, bestf = RBFS(problem, node, infinity) 311 | return result 312 | 313 | # ______________________________________________________________________________ 314 | 315 | # Code to compare searchers on various problems. 316 | 317 | 318 | class InstrumentedProblem(Problem): 319 | 320 | """Delegates to a problem, and keeps statistics.""" 321 | 322 | def __init__(self, problem): 323 | self.problem = problem 324 | self.succs = self.goal_tests = self.states = 0 325 | self.found = None 326 | 327 | def actions(self, state): 328 | self.succs += 1 329 | return self.problem.actions(state) 330 | 331 | def result(self, state, action): 332 | self.states += 1 333 | return self.problem.result(state, action) 334 | 335 | def goal_test(self, state): 336 | self.goal_tests += 1 337 | result = self.problem.goal_test(state) 338 | if result: 339 | self.found = state 340 | return result 341 | 342 | def path_cost(self, c, state1, action, state2): 343 | return self.problem.path_cost(c, state1, action, state2) 344 | 345 | def value(self, state): 346 | return self.problem.value(state) 347 | 348 | def __getattr__(self, attr): 349 | return getattr(self.problem, attr) 350 | 351 | def __repr__(self): 352 | return '<%4d/%4d/%4d/%s>' % (self.succs, self.goal_tests, 353 | self.states, str(self.found)[:4]) 354 | 355 | 356 | def compare_searchers(problems, header, 357 | searchers=[breadth_first_tree_search, 358 | breadth_first_search, 359 | depth_first_graph_search, 360 | iterative_deepening_search, 361 | depth_limited_search, 362 | recursive_best_first_search]): 363 | def do(searcher, problem): 364 | p = InstrumentedProblem(problem) 365 | searcher(p) 366 | return p 367 | table = [[name(s)] + [do(s, p) for p in problems] for s in searchers] 368 | print_table(table, header) 369 | -------------------------------------------------------------------------------- /problem/utils.py: -------------------------------------------------------------------------------- 1 | """Provides some utilities widely used by other modules""" 2 | 3 | import bisect 4 | import collections 5 | import collections.abc 6 | import functools 7 | import operator 8 | import os.path 9 | import random 10 | import math 11 | 12 | import heapq 13 | from collections import defaultdict 14 | 15 | # ______________________________________________________________________________ 16 | # Functions on Sequences and Iterables 17 | 18 | 19 | def sequence(iterable): 20 | "Coerce iterable to sequence, if it is not already one." 21 | return (iterable if isinstance(iterable, collections.abc.Sequence) 22 | else tuple(iterable)) 23 | 24 | 25 | def removeall(item, seq): 26 | """Return a copy of seq (or string) with all occurences of item removed.""" 27 | if isinstance(seq, str): 28 | return seq.replace(item, '') 29 | else: 30 | return [x for x in seq if x != item] 31 | 32 | 33 | def unique(seq): # TODO: replace with set 34 | """Remove duplicate elements from seq. Assumes hashable elements.""" 35 | return list(set(seq)) 36 | 37 | 38 | def count(seq): 39 | """Count the number of items in sequence that are interpreted as true.""" 40 | return sum(bool(x) for x in seq) 41 | 42 | 43 | def product(numbers): 44 | """Return the product of the numbers, e.g. product([2, 3, 10]) == 60""" 45 | result = 1 46 | for x in numbers: 47 | result *= x 48 | return result 49 | 50 | 51 | def first(iterable, default=None): 52 | "Return the first element of an iterable or the next element of a generator; or default." 53 | try: 54 | return iterable[0] 55 | except IndexError: 56 | return default 57 | except TypeError: 58 | return next(iterable, default) 59 | 60 | 61 | def is_in(elt, seq): 62 | """Similar to (elt in seq), but compares with 'is', not '=='.""" 63 | return any(x is elt for x in seq) 64 | 65 | # ______________________________________________________________________________ 66 | # argmin and argmax 67 | 68 | identity = lambda x: x 69 | 70 | argmin = min 71 | argmax = max 72 | 73 | 74 | def argmin_random_tie(seq, key=identity): 75 | """Return a minimum element of seq; break ties at random.""" 76 | return argmin(shuffled(seq), key=key) 77 | 78 | 79 | def argmax_random_tie(seq, key=identity): 80 | "Return an element with highest fn(seq[i]) score; break ties at random." 81 | return argmax(shuffled(seq), key=key) 82 | 83 | 84 | def shuffled(iterable): 85 | "Randomly shuffle a copy of iterable." 86 | items = list(iterable) 87 | random.shuffle(items) 88 | return items 89 | 90 | 91 | 92 | # ______________________________________________________________________________ 93 | # Statistical and mathematical functions 94 | 95 | 96 | def histogram(values, mode=0, bin_function=None): 97 | """Return a list of (value, count) pairs, summarizing the input values. 98 | Sorted by increasing value, or if mode=1, by decreasing count. 99 | If bin_function is given, map it over values first.""" 100 | if bin_function: 101 | values = map(bin_function, values) 102 | 103 | bins = {} 104 | for val in values: 105 | bins[val] = bins.get(val, 0) + 1 106 | 107 | if mode: 108 | return sorted(list(bins.items()), key=lambda x: (x[1], x[0]), 109 | reverse=True) 110 | else: 111 | return sorted(bins.items()) 112 | 113 | 114 | def dotproduct(X, Y): 115 | """Return the sum of the element-wise product of vectors X and Y.""" 116 | return sum(x * y for x, y in zip(X, Y)) 117 | 118 | 119 | def element_wise_product(X, Y): 120 | """Return vector as an element-wise product of vectors X and Y""" 121 | assert len(X) == len(Y) 122 | return [x * y for x, y in zip(X, Y)] 123 | 124 | 125 | def matrix_multiplication(X_M, *Y_M): 126 | """Return a matrix as a matrix-multiplication of X_M and arbitary number of matrices *Y_M""" 127 | 128 | def _mat_mult(X_M, Y_M): 129 | """Return a matrix as a matrix-multiplication of two matrices X_M and Y_M 130 | matrix_multiplication([[1, 2, 3], 131 | [2, 3, 4]], 132 | [[3, 4], 133 | [1, 2], 134 | [1, 0]]) 135 | [[8, 8],[13, 14]] 136 | """ 137 | assert len(X_M[0]) == len(Y_M) 138 | 139 | result = [[0 for i in range(len(Y_M[0]))] for j in range(len(X_M))] 140 | for i in range(len(X_M)): 141 | for j in range(len(Y_M[0])): 142 | for k in range(len(Y_M)): 143 | result[i][j] += X_M[i][k] * Y_M[k][j] 144 | return result 145 | 146 | result = X_M 147 | for Y in Y_M: 148 | result = _mat_mult(result, Y) 149 | 150 | return result 151 | 152 | 153 | def vector_to_diagonal(v): 154 | """Converts a vector to a diagonal matrix with vector elements 155 | as the diagonal elements of the matrix""" 156 | diag_matrix = [[0 for i in range(len(v))] for j in range(len(v))] 157 | for i in range(len(v)): 158 | diag_matrix[i][i] = v[i] 159 | 160 | return diag_matrix 161 | 162 | 163 | def vector_add(a, b): 164 | """Component-wise addition of two vectors.""" 165 | return tuple(map(operator.add, a, b)) 166 | 167 | 168 | 169 | def scalar_vector_product(X, Y): 170 | """Return vector as a product of a scalar and a vector""" 171 | return [X * y for y in Y] 172 | 173 | 174 | def scalar_matrix_product(X, Y): 175 | return [scalar_vector_product(X, y) for y in Y] 176 | 177 | 178 | def inverse_matrix(X): 179 | """Inverse a given square matrix of size 2x2""" 180 | assert len(X) == 2 181 | assert len(X[0]) == 2 182 | det = X[0][0] * X[1][1] - X[0][1] * X[1][0] 183 | assert det != 0 184 | inv_mat = scalar_matrix_product(1.0/det, [[X[1][1], -X[0][1]], [-X[1][0], X[0][0]]]) 185 | 186 | return inv_mat 187 | 188 | 189 | def probability(p): 190 | "Return true with probability p." 191 | return p > random.uniform(0.0, 1.0) 192 | 193 | 194 | def weighted_sample_with_replacement(seq, weights, n): 195 | """Pick n samples from seq at random, with replacement, with the 196 | probability of each element in proportion to its corresponding 197 | weight.""" 198 | sample = weighted_sampler(seq, weights) 199 | 200 | return [sample() for _ in range(n)] 201 | 202 | 203 | def weighted_sampler(seq, weights): 204 | "Return a random-sample function that picks from seq weighted by weights." 205 | totals = [] 206 | for w in weights: 207 | totals.append(w + totals[-1] if totals else w) 208 | 209 | return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]))] 210 | 211 | 212 | def rounder(numbers, d=4): 213 | "Round a single number, or sequence of numbers, to d decimal places." 214 | if isinstance(numbers, (int, float)): 215 | return round(numbers, d) 216 | else: 217 | constructor = type(numbers) # Can be list, set, tuple, etc. 218 | return constructor(rounder(n, d) for n in numbers) 219 | 220 | 221 | def num_or_str(x): 222 | """The argument is a string; convert to a number if 223 | possible, or strip it. 224 | """ 225 | try: 226 | return int(x) 227 | except ValueError: 228 | try: 229 | return float(x) 230 | except ValueError: 231 | return str(x).strip() 232 | 233 | 234 | def normalize(dist): 235 | """Multiply each number by a constant such that the sum is 1.0""" 236 | if isinstance(dist, dict): 237 | total = sum(dist.values()) 238 | for key in dist: 239 | dist[key] = dist[key] / total 240 | assert 0 <= dist[key] <= 1, "Probabilities must be between 0 and 1." 241 | return dist 242 | total = sum(dist) 243 | return [(n / total) for n in dist] 244 | 245 | 246 | def clip(x, lowest, highest): 247 | """Return x clipped to the range [lowest..highest].""" 248 | return max(lowest, min(x, highest)) 249 | 250 | 251 | def sigmoid(x): 252 | """Return activation value of x with sigmoid function""" 253 | return 1/(1 + math.exp(-x)) 254 | 255 | 256 | def step(x): 257 | """Return activation value of x with sign function""" 258 | return 1 if x >= 0 else 0 259 | 260 | try: # math.isclose was added in Python 3.5; but we might be in 3.4 261 | from math import isclose 262 | except ImportError: 263 | def isclose(a, b, rel_tol=1e-09, abs_tol=0.0): 264 | "Return true if numbers a and b are close to each other." 265 | return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) 266 | 267 | # ______________________________________________________________________________ 268 | # Misc Functions 269 | 270 | 271 | # TODO: Use functools.lru_cache memoization decorator 272 | 273 | 274 | def memoize(fn, slot=None): 275 | """Memoize fn: make it remember the computed value for any argument list. 276 | If slot is specified, store result in that slot of first argument. 277 | If slot is false, store results in a dictionary.""" 278 | if slot: 279 | def memoized_fn(obj, *args): 280 | if hasattr(obj, slot): 281 | return getattr(obj, slot) 282 | else: 283 | val = fn(obj, *args) 284 | setattr(obj, slot, val) 285 | return val 286 | else: 287 | def memoized_fn(*args): 288 | if args not in memoized_fn.cache: 289 | memoized_fn.cache[args] = fn(*args) 290 | return memoized_fn.cache[args] 291 | 292 | memoized_fn.cache = {} 293 | 294 | return memoized_fn 295 | 296 | 297 | def name(obj): 298 | "Try to find some reasonable name for the object." 299 | return (getattr(obj, 'name', 0) or getattr(obj, '__name__', 0) or 300 | getattr(getattr(obj, '__class__', 0), '__name__', 0) or 301 | str(obj)) 302 | 303 | 304 | def isnumber(x): 305 | "Is x a number?" 306 | return hasattr(x, '__int__') 307 | 308 | 309 | def issequence(x): 310 | "Is x a sequence?" 311 | return isinstance(x, collections.abc.Sequence) 312 | 313 | 314 | def print_table(table, header=None, sep=' ', numfmt='%g'): 315 | """Print a list of lists as a table, so that columns line up nicely. 316 | header, if specified, will be printed as the first row. 317 | numfmt is the format for all numbers; you might want e.g. '%6.2f'. 318 | (If you want different formats in different columns, 319 | don't use print_table.) sep is the separator between columns.""" 320 | justs = ['rjust' if isnumber(x) else 'ljust' for x in table[0]] 321 | 322 | if header: 323 | table.insert(0, header) 324 | 325 | table = [[numfmt.format(x) if isnumber(x) else x for x in row] 326 | for row in table] 327 | 328 | sizes = list( 329 | map(lambda seq: max(map(len, seq)), 330 | list(zip(*[map(str, row) for row in table])))) 331 | 332 | for row in table: 333 | print(sep.join(getattr( 334 | str(x), j)(size) for (j, size, x) in zip(justs, sizes, row))) 335 | 336 | 337 | def AIMAFile(components, mode='r'): 338 | "Open a file based at the AIMA root directory." 339 | aima_root = os.path.dirname(__file__) 340 | 341 | aima_file = os.path.join(aima_root, *components) 342 | 343 | return open(aima_file) 344 | 345 | 346 | def DataFile(name, mode='r'): 347 | "Return a file in the AIMA /aimacode-data directory." 348 | return AIMAFile(['aimacode-data', name], mode) 349 | 350 | 351 | # ______________________________________________________________________________ 352 | # Expressions 353 | 354 | # See https://docs.python.org/3/reference/expressions.html#operator-precedence 355 | # See https://docs.python.org/3/reference/datamodel.html#special-method-names 356 | 357 | class Expr(object): 358 | """A mathematical expression with an operator and 0 or more arguments. 359 | op is a str like '+' or 'sin'; args are Expressions. 360 | Expr('x') or Symbol('x') creates a symbol (a nullary Expr). 361 | Expr('-', x) creates a unary; Expr('+', x, 1) creates a binary.""" 362 | 363 | def __init__(self, op, *args): 364 | self.op = str(op) 365 | self.args = args 366 | self.__hash = None 367 | 368 | # Operator overloads 369 | def __neg__(self): return Expr('-', self) 370 | def __pos__(self): return Expr('+', self) 371 | def __invert__(self): return Expr('~', self) 372 | def __add__(self, rhs): return Expr('+', self, rhs) 373 | def __sub__(self, rhs): return Expr('-', self, rhs) 374 | def __mul__(self, rhs): return Expr('*', self, rhs) 375 | def __pow__(self, rhs): return Expr('**',self, rhs) 376 | def __mod__(self, rhs): return Expr('%', self, rhs) 377 | def __and__(self, rhs): return Expr('&', self, rhs) 378 | def __xor__(self, rhs): return Expr('^', self, rhs) 379 | def __rshift__(self, rhs): return Expr('>>', self, rhs) 380 | def __lshift__(self, rhs): return Expr('<<', self, rhs) 381 | def __truediv__(self, rhs): return Expr('/', self, rhs) 382 | def __floordiv__(self, rhs): return Expr('//', self, rhs) 383 | def __matmul__(self, rhs): return Expr('@', self, rhs) 384 | 385 | def __or__(self, rhs): 386 | "Allow both P | Q, and P |'==>'| Q." 387 | if isinstance(rhs, Expression): 388 | return Expr('|', self, rhs) 389 | else: 390 | return PartialExpr(rhs, self) 391 | 392 | # Reverse operator overloads 393 | def __radd__(self, lhs): return Expr('+', lhs, self) 394 | def __rsub__(self, lhs): return Expr('-', lhs, self) 395 | def __rmul__(self, lhs): return Expr('*', lhs, self) 396 | def __rdiv__(self, lhs): return Expr('/', lhs, self) 397 | def __rpow__(self, lhs): return Expr('**', lhs, self) 398 | def __rmod__(self, lhs): return Expr('%', lhs, self) 399 | def __rand__(self, lhs): return Expr('&', lhs, self) 400 | def __rxor__(self, lhs): return Expr('^', lhs, self) 401 | def __ror__(self, lhs): return Expr('|', lhs, self) 402 | def __rrshift__(self, lhs): return Expr('>>', lhs, self) 403 | def __rlshift__(self, lhs): return Expr('<<', lhs, self) 404 | def __rtruediv__(self, lhs): return Expr('/', lhs, self) 405 | def __rfloordiv__(self, lhs): return Expr('//', lhs, self) 406 | def __rmatmul__(self, lhs): return Expr('@', lhs, self) 407 | 408 | def __call__(self, *args): 409 | "Call: if 'f' is a Symbol, then f(0) == Expr('f', 0)." 410 | if self.args: 411 | raise ValueError('can only do a call for a Symbol, not an Expr') 412 | else: 413 | return Expr(self.op, *args) 414 | 415 | # Equality and repr 416 | def __eq__(self, other): 417 | "'x == y' evaluates to True or False; does not build an Expr." 418 | return (isinstance(other, Expr) 419 | and self.op == other.op 420 | and self.args == other.args) 421 | 422 | def __hash__(self): 423 | self.__hash = self.__hash or hash(self.op) ^ hash(self.args) 424 | return self.__hash 425 | 426 | def __repr__(self): 427 | op = self.op 428 | args = [str(arg) for arg in self.args] 429 | if op.isidentifier(): # f(x) or f(x, y) 430 | return '{}({})'.format(op, ', '.join(args)) if args else op 431 | elif len(args) == 1: # -x or -(x + 1) 432 | return op + args[0] 433 | else: # (x - y) 434 | opp = (' ' + op + ' ') 435 | return '(' + opp.join(args) + ')' 436 | 437 | # An 'Expression' is either an Expr or a Number. 438 | # Symbol is not an explicit type; it is any Expr with 0 args. 439 | 440 | Number = (int, float, complex) 441 | Expression = (Expr, Number) 442 | 443 | 444 | def Symbol(name): 445 | "A Symbol is just an Expr with no args." 446 | return Expr(name) 447 | 448 | 449 | def symbols(names): 450 | "Return a tuple of Symbols; names is a comma/whitespace delimited str." 451 | return tuple(Symbol(name) for name in names.replace(',', ' ').split()) 452 | 453 | 454 | def subexpressions(x): 455 | "Yield the subexpressions of an Expression (including x itself)." 456 | yield x 457 | if isinstance(x, Expr): 458 | for arg in x.args: 459 | yield from subexpressions(arg) 460 | 461 | 462 | def arity(expression): 463 | "The number of sub-expressions in this expression." 464 | if isinstance(expression, Expr): 465 | return len(expression.args) 466 | else: # expression is a number 467 | return 0 468 | 469 | # For operators that are not defined in Python, we allow new InfixOps: 470 | 471 | 472 | class PartialExpr: 473 | """Given 'P |'==>'| Q, first form PartialExpr('==>', P), then combine with Q.""" 474 | def __init__(self, op, lhs): self.op, self.lhs = op, lhs 475 | def __or__(self, rhs): return Expr(self.op, self.lhs, rhs) 476 | def __repr__(self): return "PartialExpr('{}', {})".format(self.op, self.lhs) 477 | 478 | 479 | def expr(x): 480 | """Shortcut to create an Expression. x is a str in which: 481 | - identifiers are automatically defined as Symbols. 482 | - ==> is treated as an infix |'==>'|, as are <== and <=>. 483 | If x is already an Expression, it is returned unchanged. Example: 484 | >>> expr('P & Q ==> Q') 485 | ((P & Q) ==> Q) 486 | """ 487 | if isinstance(x, str): 488 | return eval(expr_handle_infix_ops(x), defaultkeydict(Symbol)) 489 | else: 490 | return x 491 | 492 | infix_ops = '==> <== <=>'.split() 493 | 494 | 495 | def expr_handle_infix_ops(x): 496 | """Given a str, return a new str with ==> replaced by |'==>'|, etc. 497 | >>> expr_handle_infix_ops('P ==> Q') 498 | "P |'==>'| Q" 499 | """ 500 | for op in infix_ops: 501 | x = x.replace(op, '|' + repr(op) + '|') 502 | return x 503 | 504 | 505 | class defaultkeydict(collections.defaultdict): 506 | """Like defaultdict, but the default_factory is a function of the key. 507 | >>> d = defaultkeydict(len); d['four'] 508 | 4 509 | """ 510 | def __missing__(self, key): 511 | self[key] = result = self.default_factory(key) 512 | return result 513 | 514 | 515 | # ______________________________________________________________________________ 516 | # Queues: Stack, FIFOQueue, PriorityQueue 517 | 518 | # TODO: Possibly use queue.Queue, queue.PriorityQueue 519 | # TODO: Priority queues may not belong here -- see treatment in search.py 520 | 521 | 522 | class Queue: 523 | 524 | """Queue is an abstract class/interface. There are three types: 525 | Stack(): A Last In First Out Queue. 526 | FIFOQueue(): A First In First Out Queue. 527 | PriorityQueue(order, f): Queue in sorted order (default min-first). 528 | Each type supports the following methods and functions: 529 | q.append(item) -- add an item to the queue 530 | q.extend(items) -- equivalent to: for item in items: q.append(item) 531 | q.pop() -- return the top item from the queue 532 | len(q) -- number of items in q (also q.__len()) 533 | item in q -- does q contain item? 534 | Note that isinstance(Stack(), Queue) is false, because we implement stacks 535 | as lists. If Python ever gets interfaces, Queue will be an interface.""" 536 | 537 | def __init__(self): 538 | raise NotImplementedError 539 | 540 | def extend(self, items): 541 | for item in items: 542 | self.append(item) 543 | 544 | 545 | def Stack(): 546 | """Return an empty list, suitable as a Last-In-First-Out Queue.""" 547 | return [] 548 | 549 | 550 | class FIFOQueue(Queue): 551 | 552 | """A First-In-First-Out Queue.""" 553 | 554 | def __init__(self): 555 | self.A = [] 556 | self.start = 0 557 | 558 | def append(self, item): 559 | self.A.append(item) 560 | 561 | def __len__(self): 562 | return len(self.A) - self.start 563 | 564 | def extend(self, items): 565 | self.A.extend(items) 566 | 567 | def pop(self): 568 | e = self.A[self.start] 569 | self.start += 1 570 | if self.start > 5 and self.start > len(self.A) / 2: 571 | self.A = self.A[self.start:] 572 | self.start = 0 573 | return e 574 | 575 | def __contains__(self, item): 576 | return item in self.A[self.start:] 577 | 578 | 579 | class PriorityQueue(Queue): 580 | """A queue in which the minimum element (as determined by f and 581 | order) is returned first. Also supports dict-like lookup. 582 | 583 | MODIFIED FROM AIMA VERSION 584 | - Use heapq 585 | - Use an additional dict to track membership 586 | - remove __delitem__ (AIMA version contains error) 587 | """ 588 | 589 | def __init__(self, order=None, f=lambda x: x): 590 | self.A = [] 591 | self._A = defaultdict(lambda: 0) 592 | self.f = f 593 | 594 | def append(self, item): 595 | heapq.heappush(self.A, (self.f(item), item)) 596 | self._A[item] += 1 597 | 598 | def __len__(self): 599 | return len(self.A) 600 | 601 | def pop(self): 602 | _, item = heapq.heappop(self.A) 603 | self._A[item] -= 1 604 | return item 605 | 606 | def __contains__(self, item): 607 | return self._A[item] > 0 608 | 609 | def __getitem__(self, key): 610 | if self._A[key] > 0: 611 | return key 612 | 613 | # ______________________________________________________________________________ 614 | # Useful Shorthands 615 | 616 | 617 | class Bool(int): 618 | """Just like `bool`, except values display as 'T' and 'F' instead of 'True' and 'False'""" 619 | __str__ = __repr__ = lambda self: 'T' if self else 'F' 620 | 621 | T = Bool(True) 622 | F = Bool(False) 623 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import dnc_arity_list as dnc 2 | import numpy as np 3 | from utils import running_avg, flat, save, _variable 4 | import utils as u 5 | import torch 6 | import torch.nn as nn 7 | from problem import generators_v2 as gen 8 | import torch.optim as optim 9 | import time, random 10 | from visualize import logger as sl 11 | import os 12 | import losses as L 13 | from arg import args 14 | 15 | random.seed() 16 | batch_size = 1 17 | dnc_args = {'num_layers': 2, 18 | 'num_read_heads': 2, 19 | 'hidden_size': 250, 20 | 'num_write_heads': 1, 21 | 'memory_size': 100, #50 22 | 'batch_size': batch_size} 23 | 24 | 25 | def generate_data_spec(args, num_ents=2, solve=True): 26 | if args.typed is 1: 27 | ix_size = [4, args.max_ents, 4, args.max_ents, 4, args.max_ents] 28 | encoding = 2 29 | else: 30 | ix_state = args.max_ents * 3 31 | ix_size = [ix_state, ix_state, ix_state] 32 | encoding = 1 33 | return {'num_plane': num_ents, 'num_cargo': num_ents, 'num_airport': num_ents, 34 | 'one_hot_size': ix_size, 'plan_phase': num_ents * 3, 'cuda': args.cuda, 35 | 'batch_size': 1, 'encoding': encoding, 'solve': solve, 'mapping': None} 36 | 37 | 38 | def setupLSTM(args): 39 | data = gen.AirCargoData(**generate_data_spec(args)) 40 | dnc_args['output_size'] = data.nn_in_size # output has no phase component 41 | dnc_args['word_len'] = data.nn_out_size 42 | print(dnc_args) 43 | # input_size = self.output_size + word_len * num_read_heads 44 | Dnc = dnc.VanillaLSTM(batch_size=1, num_layers=2, input_size=data.nn_in_size, 45 | output_size=data.nn_out_size, hidden_size=250, num_reads=2) 46 | previous_out, (ho1, hc1), (ho2, hc2) = Dnc.init_state() 47 | if args.opt == 'adam': 48 | optimizer = optim.Adam([{'params': Dnc.parameters()}, {'params': ho1}, 49 | {'params': hc1}, {'params': hc2}], lr=args.lr) 50 | else: 51 | optimizer = optim.SGD([{'params': Dnc.parameters()}, {'params': ho1}, 52 | {'params': hc1}, {'params': hc2}], lr=args.lr) 53 | 54 | lstm_state = (previous_out, (ho1, hc1), (ho2, hc2)) 55 | return data, Dnc, optimizer, lstm_state 56 | 57 | 58 | def setupDNC(args): 59 | """ 60 | Loader for files or setup new DNC and optimizer 61 | :param args: 62 | :return: 63 | """ 64 | if args.algo == 'lstm': 65 | return setupLSTM(args) 66 | data = gen.AirCargoData(**generate_data_spec(args)) 67 | dnc_args['output_size'] = data.nn_in_size # output has no phase component 68 | dnc_args['word_len'] = data.nn_out_size 69 | print('dnc_args:\n', dnc_args, '\n') 70 | if args.load == '': 71 | Dnc = dnc.DNC(**dnc_args) 72 | if args.opt == 'adam': 73 | optimizer = optim.Adam(Dnc.parameters(), lr=args.lr) 74 | elif args.opt == 'sgd': 75 | optimizer = optim.SGD(Dnc.parameters(), lr=args.lr) 76 | else: 77 | optimizer = None 78 | else: 79 | model_path, optim_path = u.get_chkpt(args.load) 80 | print('loading', model_path) 81 | Dnc = dnc.DNC(**dnc_args) 82 | Dnc.load_state_dict(torch.load(model_path)) 83 | 84 | optimizer = optim.Adam(Dnc.parameters(), lr=args.lr) 85 | if os.path.exists(optim_path): 86 | optimizer.load_state_dict(torch.load(optim_path)) 87 | 88 | if args.cuda is True: 89 | Dnc = Dnc.cuda() 90 | lstm_state = Dnc.init_rnn() 91 | return data, Dnc, optimizer, lstm_state 92 | 93 | 94 | def tick(n_total, n_correct, truth, pred): 95 | n_total += 1 96 | n_correct += 1 if truth == pred else 0 97 | sl.global_step += 1 98 | return n_total, n_correct 99 | 100 | 101 | def train_qa2(args, data, DNC, optimizer): 102 | """ 103 | I am jacks liver. This is a sanity test 104 | 105 | 0 - describe state. 106 | 1 - describe goal. 107 | 2 - do actions. 108 | 3 - ask some questions 109 | :param args: 110 | :return: 111 | """ 112 | criterion = nn.CrossEntropyLoss() 113 | cum_correct, cum_total = [], [] 114 | 115 | for trial in range(args.iters): 116 | phase_masks = data.make_new_problem() 117 | n_total, n_correct, loss = 0, 0, 0 118 | dnc_state = DNC.init_state(grad=False) 119 | optimizer.zero_grad() 120 | 121 | for phase_idx in phase_masks: 122 | if phase_idx == 0 or phase_idx == 1: 123 | inputs = _variable(data.getitem_combined()) 124 | logits, dnc_state = DNC(inputs, dnc_state) 125 | else: 126 | final_moves = data.get_actions(mode='one') 127 | if final_moves == []: 128 | break 129 | data.send_action(final_moves[0]) 130 | mask = data.phase_oh[2].unsqueeze(0) 131 | inputs2 = _variable(torch.cat([mask, data.vec_to_ix(final_moves[0])], 1)) 132 | logits, dnc_state = DNC(inputs2, dnc_state) 133 | 134 | for _ in range(args.num_tests): 135 | # ask where is ---? 136 | if args.zero_at == 'step': 137 | optimizer.zero_grad() 138 | masked_input, mask_chunk, ground_truth = data.masked_input() 139 | logits, dnc_state = DNC(_variable(masked_input), dnc_state) 140 | expanded_logits = data.ix_input_to_ixs(logits) 141 | 142 | # losses 143 | lstep = L.action_loss(expanded_logits, ground_truth, criterion, log=True) 144 | if args.opt_at == 'problem': 145 | loss += lstep 146 | else: 147 | lstep.backward(retain_graph=args.ret_graph) 148 | optimizer.step() 149 | loss = lstep 150 | 151 | # update counters 152 | prediction = u.get_prediction(expanded_logits, [3, 4]) 153 | n_total, n_correct = tick(n_total, n_correct, mask_chunk, prediction) 154 | 155 | if args.opt_at == 'problem': 156 | loss.backward(retain_graph=args.ret_graph) 157 | optimizer.step() 158 | sl.writer.add_scalar('losses.end', loss.data[0], sl.global_step) 159 | 160 | cum_total.append(n_total) 161 | cum_correct.append(n_correct) 162 | sl.writer.add_scalar('recall.pct_correct', n_correct / n_total, sl.global_step) 163 | print("trial: {}, step:{}, accy {:0.4f}, cum_score {:0.4f}, loss: {:0.4f}".format( 164 | trial, sl.global_step, n_correct / n_total, running_avg(cum_correct, cum_total), loss.data[0])) 165 | return DNC, optimizer, dnc_state, running_avg(cum_correct, cum_total) 166 | 167 | 168 | def random_seq(args, data, DNC, lstm_state, optimizer): 169 | pass 170 | 171 | 172 | def train_rl(args, data, DNC, lstm_state, optimizer): 173 | """ 174 | 175 | :param args: 176 | :param data: 177 | :param DNC: a tuple of value and action networks 178 | :param lstm_state: 179 | :param optimizer: 180 | :return: 181 | """ 182 | for trial in range(args.iters): 183 | start_prob = time.time() 184 | phase_masks = data.make_new_problem() 185 | 186 | 187 | pass 188 | 189 | 190 | def train_plan(args, data, DNC, lstm_state, optimizer): 191 | """ 192 | Things to test after some iterations: 193 | - on planning phase and on 194 | 195 | with goals - chose a goal and work toward that 196 | :param args: 197 | :return: 198 | """ 199 | criterion = nn.CrossEntropyLoss().cuda() if args.cuda is True else nn.CrossEntropyLoss() 200 | cum_correct, cum_total, prob_times, n_success = [], [], [], 0 201 | penalty = 1.1 202 | 203 | for trial in range(args.iters): 204 | start_prob = time.time() 205 | phase_masks = data.make_new_problem() 206 | n_total, n_correct, prev_action, loss, stats = 0, 0, None, 0, [] 207 | dnc_state = DNC.init_state(grad=False) 208 | lstm_state = DNC.init_rnn(grad=False) # lstm_state, 209 | optimizer.zero_grad() 210 | 211 | for phase_idx in phase_masks: 212 | 213 | if phase_idx == 0 or phase_idx == 1: 214 | inputs = _variable(data.getitem_combined()) 215 | logits, dnc_state, lstm_state = DNC(inputs, lstm_state, dnc_state) 216 | _, prev_action = data.strip_ix_mask(logits) 217 | 218 | elif phase_idx == 2: 219 | mask = _variable(data.getmask()) 220 | inputs = torch.cat([mask, prev_action], 1) 221 | logits, dnc_state, lstm_state = DNC(inputs, lstm_state, dnc_state) 222 | _, prev_action = data.strip_ix_mask(logits) 223 | 224 | else: 225 | # sample from best moves 226 | actions_star, all_actions = data.get_actions(mode='both') 227 | if not actions_star: 228 | break 229 | if args.zero_at == 'step': 230 | optimizer.zero_grad() 231 | 232 | mask = data.getmask() 233 | prev_action = prev_action.cuda() if args.cuda is True else prev_action 234 | pr = u.depackage(prev_action) 235 | 236 | final_inputs = _variable(torch.cat([mask, pr], 1)) 237 | logits, dnc_state, lstm_state = DNC(final_inputs, lstm_state, dnc_state) 238 | exp_logits = data.ix_input_to_ixs(logits) 239 | 240 | guided = random.random() < args.beta 241 | # thing 1 242 | if guided: # guided loss 243 | final_action, lstep = L.naive_loss(exp_logits, actions_star, criterion, log=True) 244 | else: # pick own move 245 | final_action, lstep = L.naive_loss(exp_logits, all_actions, criterion, log=True) 246 | 247 | # penalty for todo tests this !!!! 248 | action_own = u.get_prediction(exp_logits) 249 | if args.penalty and not [tuple(flat(t)) for t in all_actions]: 250 | final_loss = lstep * _variable([args.penalty]) 251 | else: 252 | final_loss = lstep 253 | 254 | if args.opt_at == 'problem': 255 | loss += final_loss 256 | else: 257 | 258 | final_loss.backward(retain_graph=args.ret_graph) 259 | if args.clip: 260 | torch.nn.utils.clip_grad_norm(DNC.parameters(), args.clip) 261 | optimizer.step() 262 | loss = lstep 263 | 264 | data.send_action(final_action) 265 | 266 | if (trial + 1) % args.show_details == 0: 267 | action_accs = u.human_readable_res(data, all_actions, actions_star, 268 | action_own, guided, lstep.data[0]) 269 | stats.append(action_accs) 270 | n_total, _ = tick(n_total, n_correct, action_own, flat(final_action)) 271 | n_correct += 1 if action_own in [tuple(flat(t)) for t in actions_star] else 0 272 | prev_action = data.vec_to_ix(final_action) 273 | 274 | if stats: 275 | arr = np.array(stats) 276 | correct = len([1 for i in list(arr.sum(axis=1)) if i == len(stats[0])]) / len(stats) 277 | sl.log_acc(list(arr.mean(axis=0)), correct) 278 | 279 | if args.opt_at == 'problem': 280 | floss = loss / n_total 281 | floss.backward(retain_graph=args.ret_graph) 282 | if args.clip: 283 | torch.nn.utils.clip_grad_norm(DNC.parameters(), args.clip) 284 | optimizer.step() 285 | sl.writer.add_scalar('losses.end', floss.data[0], sl.global_step) 286 | 287 | n_success += 1 if n_correct / n_total > args.passing else 0 288 | cum_total.append(n_total) 289 | cum_correct.append(n_correct) 290 | sl.add_scalar('recall.pct_correct', n_correct / n_total, sl.global_step) 291 | print("trial {}, step {} trial accy: {}/{}, {:0.2f}, running total {}/{}, running avg {:0.4f}, loss {:0.4f} ".format( 292 | trial, sl.global_step, n_correct, n_total, n_correct / n_total, n_success, trial, 293 | running_avg(cum_correct, cum_total), loss.data[0] 294 | )) 295 | end_prob = time.time() 296 | prob_times.append(start_prob - end_prob) 297 | print("solved {} out of {} -> {}".format(n_success, args.iters, n_success / args.iters)) 298 | return DNC, optimizer, lstm_state, running_avg(cum_correct, cum_total) 299 | 300 | 301 | def train_manager(args, train_fn): 302 | """ 303 | 304 | :param args: args object. see arg.py or run.py -h for details 305 | :param train_fn: the training function - 306 | :return: 307 | """ 308 | datspec = generate_data_spec(args) 309 | print('\nInitial Spec', datspec) 310 | 311 | _, DNC, optimizer, lstm_state = setupDNC(args) 312 | start_ents, score, global_epoch = args.n_init_start, 0, args.start_epoch 313 | print('\nDnc structure', DNC) 314 | 315 | for problem_size in range(args.max_ents): 316 | test_size = problem_size + start_ents 317 | passing = False 318 | data_spec = generate_data_spec(args, num_ents=test_size, solve=test_size * 3) 319 | data = gen.AirCargoData(**data_spec) 320 | 321 | print("beginning new training Size: {}".format(test_size)) 322 | for train_epoch in range(args.n_phases): 323 | ep_start = time.time() 324 | global_epoch += 1 325 | print("\nStarting Epoch {}".format(train_epoch)) 326 | 327 | DNC, optimizer, lstm_state, score = train_fn(args, data, DNC, lstm_state, optimizer) 328 | if (train_epoch + 1) % args.checkpoint_every and args.save != '': 329 | save(DNC, optimizer, lstm_state, args, global_epoch) 330 | 331 | ep_end = time.time() 332 | ttl_s = ep_end - ep_start 333 | print('finished epoch: {}, score: {}, ttl-time: {:0.4f}, time/prob: {:0.4f}'.format( 334 | train_epoch, score, ttl_s, ttl_s / args.iters 335 | )) 336 | if score > args.passing: 337 | print('model_successful: {}, {} '.format(score, train_epoch)) 338 | print('----------------------WOO!!--------------------------') 339 | passing = True 340 | break 341 | 342 | if passing is False: 343 | print("Training has FAILED for problem of size: {}, after {} epochs of {} phases".format( 344 | test_size, args.max_ents, args.n_phases 345 | )) 346 | print("final score was {}".format(score)) 347 | break 348 | 349 | 350 | if __name__== "__main__": 351 | print(args) 352 | if args.act == 'plan': 353 | train_manager(args, train_plan) 354 | elif args.act == 'qa': 355 | train_manager(args, train_qa2) 356 | elif args.act == 'clean': 357 | pass 358 | else: 359 | print("wrong action") -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | 2 | # apt-get install -y supervisor 3 | 4 | # mkdir /etc/supervisor/conf.d/ 5 | # cp tf.conf /etc/supervisor/conf.d/tensorboard.conf 6 | pip install tensorboardX 7 | pip install tensorflow 8 | pip install git+https://github.com/lanpa/tensorboard-pytorch 9 | 10 | # tensordboard logdir /output/runs 11 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | import unittest, copy, time 2 | import torch 3 | from problem import generators_v2 as gen 4 | from dnc_arity_list import DNC 5 | import problem.my_air_cargo_problems as mac 6 | from problem.lp_utils import ( 7 | decode_state, 8 | ) 9 | import run 10 | from visualize import wuddido as viz 11 | 12 | 13 | def test_solve_with_logic(data): 14 | print('\nWITH LOGICAL HEURISTICS') 15 | print('GOAL', data.current_problem.goal, '\n') 16 | start = time.time() 17 | f = 0 18 | for n in range(20): 19 | log_actions = data.best_logic(data.get_raw_actions(mode='all')) 20 | if log_actions != []: 21 | print('chosen', log_actions[0]) 22 | data.send_action(data.expr_to_vec(log_actions[0])) 23 | else: 24 | f = n 25 | break 26 | end = time.time() 27 | print('DONE in {} steps, {:0.4f} s'.format(f, end - start)) 28 | 29 | 30 | def test_solve_with_algo(data): 31 | print('\nWITH ASTAR') 32 | print('GOAL', data.current_problem.goal, '\n') 33 | f = 0 34 | start2 = time.time() 35 | for n in range(20): 36 | actions = data.get_raw_actions(mode='best') 37 | # print(actions) 38 | if actions != []: 39 | print('chosen', actions[0]) 40 | data.send_action(data.expr_to_vec(actions[0])) 41 | else: 42 | f = n 43 | break 44 | end2 = time.time() 45 | print('DONE in {} steps, {:0.4f} s'.format(f, end2 - start2 )) 46 | 47 | 48 | sample_args = {'num_plane': 2, 'num_cargo': 2, 49 | 'num_airport': 2, 50 | 'one_hot_size': [4,6, 4, 6, 4, 6], 51 | 'plan_phase': 2 * 3, 52 | 'cuda': False, 'batch_size': 1, 53 | 'encoding': 2, 'solve': True, 'mapping': None} 54 | 55 | 56 | class Misc(unittest.TestCase): 57 | def setUp(self): 58 | self.dataspec = {'solve': True, 'mapping': None, 59 | 'num_plane': 2, 'one_hot_size': [4, 6, 4, 6, 4, 6], 60 | 'num_airport': 2, 'plan_phase': 6, 61 | 'encoding': 2, 'batch_size': 1, 'num_cargo': 2} 62 | self.dataspec2 = {'solve': True, 'mapping': None, 63 | 'num_plane': 3, 'one_hot_size': [4, 6, 4, 6, 4, 6], 64 | 'num_airport': 3, 'plan_phase': 6, 65 | 'encoding': 2, 'batch_size': 1, 'num_cargo': 3} 66 | 67 | def est_cache(self): 68 | data = gen.AirCargoData(**self.dataspec) 69 | data.make_new_problem() 70 | test_solve_with_logic(copy.deepcopy(data)) 71 | test_solve_with_algo(copy.deepcopy(data)) 72 | print('\n\n ROUND 2') 73 | data.make_new_problem() 74 | test_solve_with_logic(copy.deepcopy(data)) 75 | test_solve_with_algo(copy.deepcopy(data)) 76 | 77 | def est_searches(self): 78 | problem = mac.air_cargo_p1() 79 | ds = decode_state(problem.initial, problem.state_map) 80 | 81 | 82 | class TestVis(unittest.TestCase): 83 | def setUp(self): 84 | self.folder = '1512692566_clip_cont2_40' 85 | iter = 109 # 86 | self.base = './models/{}/checkpts/{}/dnc_model.pkl' 87 | dict1 = torch.load(self.base.format(self.folder, iter)) 88 | 89 | self.data = gen.AirCargoData(**sample_args) 90 | 91 | args = run.dnc_args.copy() 92 | args['output_size'] = self.data.nn_in_size 93 | args['word_len'] = self.data.nn_out_size 94 | 95 | self.Dnc = DNC(**args) 96 | self.Dnc.load_state_dict(dict1) 97 | pass 98 | 99 | def test_show_state(self): 100 | rand_vec = torch.randn(39, 1) 101 | viz.ix_to_color(rand_vec) 102 | 103 | def test_run(self): 104 | record = viz.recorded_step(self.data, self.Dnc) 105 | viz.make_usage_viz(record) 106 | 107 | 108 | if __name__ == '__main__': 109 | unittest.main() 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /tf.conf: -------------------------------------------------------------------------------- 1 | logdir /output/runs -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from visualize import logger as sl 6 | import utils as u 7 | import losses as l 8 | 9 | def tick(n_total, n_correct, truth, pred): 10 | n_total += 1 11 | n_correct += 1 if truth == pred else 0 12 | return n_total, n_correct 13 | 14 | class TrainPlanning: 15 | def __init__(self, problem, optim, args): 16 | self.ProblemGen = problem 17 | self.args = args 18 | self.optimizer = optim 19 | self.current_prob = [] 20 | self.reset() 21 | 22 | def reset(self): 23 | self.num_correct, self.total_moves, self.prev_action = 0, 0, None 24 | self.n_success = 0 25 | 26 | def phase_fn(self, idx): 27 | if idx == 0: return self.state_phase 28 | if idx == 1: return self.goal_phase 29 | if idx == 2: return self.plan_phase 30 | if idx == 3: return self.response_phase 31 | 32 | def state_phase(self, Dnc, dnc_state, inputs, mask): 33 | pass 34 | 35 | def goal_phase(self, Dnc, dnc_state, inputs, mask): 36 | pass 37 | 38 | def plan_phase(self, Dnc, dnc_state, inputs, mask): 39 | inputs = Variable(torch.cat([mask_, data.ix_input_to_ixs(prev_action)], 1)) 40 | prev_action, dnc_state = Dnc(inputs, dnc_state) 41 | pass 42 | 43 | def response_phase(self, Dnc, dnc_state, inputs, mask): 44 | sl.global_step += 1 45 | self.total_moves += 1 46 | final_inputs = Variable(torch.cat([mask, self.ProblemGen.ix_input_to_ixs(prev_action)], 1)) 47 | 48 | logits, dnc_state = Dnc(final_inputs, dnc_state) 49 | expanded_logits = self.ProblemGen.ix_input_to_ixs(logits) 50 | # 51 | chosen_act_own, loss_own = L.naive_loss(expanded_logits, all_actions, loss_fn) 52 | chosen_act_star, loss_star = L.naive_loss(expanded_logits, targets_star, loss_fn, log_itr=sl.global_step) 53 | 54 | # set next input to be the networks current action ... 55 | if random.random() < self.args.beta: 56 | loss = loss_star 57 | final_action = chosen_act_star 58 | else: 59 | loss = loss_own 60 | final_action = chosen_act_own 61 | self.num_correct += 1 if chosen_act_own == final_action else 0 62 | return Dnc, dnc_state, final_action 63 | 64 | def train(self, Dnc): 65 | for n in range(self.args.iters): 66 | self.current_prob = self.ProblemGen.make_new_problem() 67 | dnc_state = Dnc.init_state(grad=False) 68 | self.optimizer.zero_grad() 69 | 70 | for idx, _ in enumerate(self.current_prob): 71 | inputs, mask = self.ProblemGen.getitem() 72 | phase_fn = self.phase_fn(idx) 73 | phase_fn(Dnc, dnc_state, inputs, mask) 74 | pass 75 | 76 | def step(self): 77 | pass 78 | 79 | def end_problem(self): 80 | print("solved {} out of {} -> {}".format(self.n_success, self.args.iters, self.n_success / self.args.iters)) 81 | pass 82 | 83 | 84 | 85 | def play_qa_readable(args, data, DNC): 86 | criterion = nn.CrossEntropyLoss() 87 | cum_correct, cum_total = [], [] 88 | 89 | for trial in range(args.iters): 90 | phase_masks = data.make_new_problem() 91 | n_total, n_correct, loss = 0, 0, 0 92 | dnc_state = DNC.init_state(grad=False) 93 | 94 | 95 | for phase_idx in phase_masks: 96 | if phase_idx == 0 or phase_idx == 1: 97 | 98 | inputs, msk = data.getitem() 99 | print(data.human_readable(inputs, msk)) 100 | 101 | inputs = Variable(torch.cat([msk, inputs], 1)) 102 | logits, dnc_state = DNC(inputs, dnc_state) 103 | else: 104 | final_moves = data.get_actions(mode='one') 105 | if final_moves == []: 106 | break 107 | data.send_action(final_moves[0]) 108 | mask = data.phase_oh[2].unsqueeze(0) 109 | vec = data.vec_to_ix(final_moves[0]) 110 | print('\n') 111 | print(data.human_readable(vec, mask)) 112 | 113 | inputs2 = Variable(torch.cat([mask, vec], 1)) 114 | logits, dnc_state = DNC(inputs2, dnc_state) 115 | 116 | for _ in range(args.num_tests): 117 | # ask where is ---? 118 | 119 | masked_input, mask_chunk, ground_truth = data.masked_input() 120 | print("Context:", data.human_readable(ground_truth)) 121 | print("Q:") 122 | 123 | logits, dnc_state = DNC(Variable(masked_input), dnc_state) 124 | expanded_logits = data.ix_input_to_ixs(logits) 125 | 126 | #losses 127 | lstep = l.action_loss(expanded_logits, ground_truth, criterion, log=True) 128 | 129 | #update counters 130 | prediction = u.get_prediction(expanded_logits, [3, 4]) 131 | print("A:") 132 | n_total, n_correct = tick(n_total, n_correct, mask_chunk, prediction) 133 | print("correct:", mask_chunk == prediction) 134 | 135 | 136 | cum_total.append(n_total) 137 | cum_correct.append(n_correct) 138 | sl.writer.add_scalar('recall.pct_correct', n_correct / n_total, sl.global_step) 139 | print("trial: {}, step:{}, accy {:0.4f}, cum_score {:0.4f}, loss: {:0.4f}".format( 140 | trial, sl.global_step, n_correct / n_total, u.running_avg(cum_correct, cum_total), loss.data[0])) 141 | return DNC, dnc_state, u.running_avg(cum_correct, cum_total) 142 | 143 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.cuda 3 | from torch.autograd import Variable 4 | import os, glob 5 | from arg import args 6 | 7 | eps = 10e-6 8 | MODEL_NAME = 'dnc_model.pkl' 9 | OPTIM_NAME = 'optimizer.pkl' 10 | 11 | 12 | def _variable(xs, **kwargs): 13 | if args.cuda is True: 14 | return Variable(xs, **kwargs).cuda() 15 | else: 16 | return Variable(xs, **kwargs) 17 | 18 | 19 | def repackage(xs): 20 | """Wraps hidden states in new Variables, to detach them from their history.""" 21 | if type(xs) == Variable: 22 | return Variable(xs.data) 23 | else: 24 | return tuple(repackage(v) for v in xs) 25 | 26 | 27 | def depackage(xs): 28 | """Wraps hidden states in new Variables, to detach them from their history.""" 29 | if type(xs) == Variable: 30 | return xs.data 31 | elif type(xs) == torch.Tensor: 32 | return xs 33 | elif type(xs) == torch.cuda.FloatTensor: 34 | return xs 35 | else: 36 | return tuple(depackage(v) for v in xs) 37 | 38 | 39 | def running_avg(cum_correct, cum_total_move, last=-100): 40 | return sum(cum_correct[last:]) / sum(cum_total_move[last:]) 41 | 42 | 43 | def flat(container): 44 | acc = [] 45 | for i in container: 46 | if isinstance(i, (list, tuple)): 47 | for j in flat(i): 48 | acc.append(j) 49 | else: 50 | acc.append(i) 51 | return acc 52 | 53 | 54 | def show(tensors, m=""): 55 | print("--") 56 | if type(tensors) == torch.Tensor: 57 | print(m, tensors.size()) 58 | else: 59 | print(m) 60 | [print(t.size()) for t in tensors] 61 | 62 | 63 | def clean_runs(dirs, lim=10): 64 | def convert_bytes(num): 65 | for x in ['bytes', 'KB', 'MB']: 66 | if num < 1024.0: 67 | return "%3.1f %s" % (num, x) 68 | num /= 1024.0 69 | for d in glob.glob(dirs + '*'): 70 | if convert_bytes(os.stat(d).st_size) < lim: 71 | os.remove(d) 72 | 73 | 74 | def get_chkpt(name, idx=0): 75 | search = './models/*{}/checkpts/*'.format(name) 76 | chkpt = sorted(glob.glob(search))[-1] + '/' 77 | return chkpt + MODEL_NAME, chkpt + OPTIM_NAME 78 | 79 | 80 | def dnc_checksum(state): 81 | return [state[i].data.sum() for i in range(6)] 82 | 83 | 84 | def save(model, optimizer, lstm_state, args, global_epoch): 85 | chkpt = '{}{}/{}/'.format(args.base_dir, 'checkpts', global_epoch) 86 | os.mkdir(chkpt) 87 | 88 | torch.save(model.state_dict(), chkpt + MODEL_NAME) 89 | torch.save(model, chkpt + 'dnc_model_full.pkl') 90 | torch.save(lstm_state, chkpt + 'lstm_state.pkl') 91 | torch.save(optimizer.state_dict(), chkpt + OPTIM_NAME) 92 | torch.save(optimizer, chkpt + 'optimizer_full.pkl') 93 | print("Saving ... file...{}, chkpt_num:{}".format(args.base_dir, global_epoch)) 94 | 95 | 96 | def get_prediction(expanded_logits, idxs='all'): 97 | max_idxs = [] 98 | if idxs == 'all': 99 | idxs = range(len(expanded_logits)) 100 | for idx in idxs: 101 | _, pidx = expanded_logits[idx].data.topk(1) 102 | max_idxs.append(pidx.squeeze()[0]) 103 | return tuple(max_idxs) 104 | 105 | 106 | def closest_action(pred, actions): 107 | best, chosen_action = 0, None 108 | for action in [flat(a) for a in actions]: 109 | scores = [1 if pred[i] == action[i] else 0 for i in range(len(pred))] 110 | if sum(scores) >= best: 111 | best, chosen_action = sum(scores), scores 112 | return chosen_action 113 | 114 | 115 | def human_readable_res(Data, all_actions, best_actions, pred, guided, loss_data): 116 | 117 | base_prob = 1 / len(all_actions) 118 | action_own = [pred[0], (pred[1], pred[2]), (pred[3], pred[4]), (pred[5], pred[6])] 119 | correct = action_own in best_actions 120 | 121 | action_der = closest_action(pred, best_actions) 122 | best_move_exprs = [Data.vec_to_expr(t) for t in best_actions] 123 | # all_move_exprs = [Data.vec_to_expr(t) for t in all_actions] 124 | chos_move, crest = Data.vec_to_expr(action_own) 125 | 126 | # print("all {}".format(', '.join(["{} {}".format(m[0], m[1]) for m in all_move_exprs]))) 127 | print("best {}".format(', '.join(["{} {}".format(m[0], m[1]) for m in best_move_exprs]))) 128 | print("chosen: {} {}, guided {}, prob {:0.2f}, T? {}---loss {:0.4f}".format( 129 | chos_move, crest, guided, base_prob, correct, loss_data 130 | )) 131 | return action_der 132 | 133 | 134 | def interface_part(num_reads, W): 135 | partition = [num_reads* W, num_reads, W, 1, W, W, num_reads, 1, 1, num_reads * 3] 136 | ds = [] 137 | cntr = 0 138 | for idx in partition: 139 | tn = [cntr, cntr + idx] 140 | ds.append(tn) 141 | cntr += idx 142 | return ds 143 | 144 | -------------------------------------------------------------------------------- /visualize/logger.py: -------------------------------------------------------------------------------- 1 | # from tensorboardX import SummaryWriter 2 | import torchvision.utils as vutils 3 | from torch.autograd import Variable 4 | from arg import writer 5 | 6 | global_step = 0 7 | log_step = 1 8 | # state = ["access_ouptut", "memory", "read_weights", "write_wghts", "link_matrix", "link_weights", "usage"] 9 | interface = ['read_keys', 'read_str', 'write_key', 'write_str', 10 | 'erase_vec', 'write_vec', 'free_gates', 'alloc_gate', 'write_gate', 'read_modes'] 11 | losses_desc = ['action', 'ent1-type', 'ent1', 'ent2-type', 'ent2', 'ent3-type', 'ent3'] 12 | losses_desc2 = ['action', 'ent1', 'ent2', 'ent3'] 13 | 14 | 15 | def to_log(tnsr): 16 | if writer: 17 | clone = tnsr.clone() 18 | if type(clone) == Variable: 19 | clone = clone.data 20 | sizes = list(clone.size()) 21 | if len(sizes) > 1 and not all((s == 1 for s in sizes)): 22 | return clone.cpu().squeeze().numpy() 23 | else: 24 | return clone.cpu().numpy() 25 | 26 | 27 | def log_if(name, data, f=None): 28 | if writer: 29 | if global_step % log_step == 0: 30 | if f is None: 31 | writer.add_histogram(name, data, global_step, bins='sturges') 32 | else: 33 | f(name, data, global_step) 34 | 35 | 36 | def log_interface(interface_vec, step): 37 | if writer: 38 | for name, data in zip(interface, interface_vec): 39 | writer.add_histogram("interface." + name, to_log(data), step, bins='sturges') 40 | 41 | 42 | def log_model(model): 43 | if writer: 44 | for name, param in model.named_parameters(): 45 | writer.add_histogram(name, to_log(param), global_step, bins='sturges') 46 | 47 | 48 | def log_loss(losses, loss): 49 | if global_step % log_step == 0 and writer: 50 | type_descs = losses_desc2 if len(losses) == 4 else losses_desc 51 | for name, param in zip(type_descs, losses): 52 | writer.add_scalar("lossess." + name, to_log(param), global_step) 53 | writer.add_scalar('losses.total', loss.clone().cpu().data[0], global_step) 54 | 55 | 56 | def log_acc(accs, total): 57 | if global_step % log_step == 0 and writer: 58 | type_descs = losses_desc2 if len(accs) == 4 else losses_desc 59 | for name, param in zip(type_descs, accs): 60 | writer.add_scalar("acc." + name, param, global_step) 61 | writer.add_scalar('acc.total', total, global_step) 62 | 63 | def add_scalar(*args, **kwdargs): 64 | if writer: 65 | writer.add_scalar(*args, **kwdargs) 66 | 67 | 68 | def log_loss_qa(ent, inst, loss): 69 | if global_step % log_step == 0 and writer: 70 | writer.add_scalar('losses.ent1-type', ent.clone().cpu().data[0], global_step) 71 | writer.add_scalar('losses.ent1', inst.clone().cpu().data[0], global_step) 72 | writer.add_scalar('losses.total', loss.clone().cpu().data[0], global_step) 73 | 74 | 75 | def log_state(state): 76 | if global_step % log_step == 0 and len(state) == 8 and writer: 77 | out, mem, r_wghts, w_wghts, links, l_wghts, usage, hidden = state 78 | writer.add_histogram("state.access_ouptut", to_log(out), global_step, bins='sturges') 79 | writer.add_histogram("state.memory", to_log(mem), global_step, bins='sturges') 80 | writer.add_histogram("state.read_weights", to_log(r_wghts), global_step, bins='sturges') 81 | writer.add_histogram("state.write_wghts", to_log(w_wghts), global_step, bins='sturges') 82 | writer.add_histogram("state.link_matrix", to_log(links), global_step, bins='sturges') 83 | writer.add_histogram("state.link_weights", to_log(l_wghts), global_step, bins='sturges') 84 | writer.add_histogram("state.usage", to_log(usage), global_step, bins='sturges') 85 | -------------------------------------------------------------------------------- /visualize/wuddido.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.figure as fig 3 | import matplotlib 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import losses as L 8 | from utils import depackage, _variable 9 | import copy 10 | from itertools import accumulate 11 | from torch.autograd import Variable 12 | import plotly.plotly as py 13 | import plotly.graph_objs as go 14 | from matplotlib.ticker import FuncFormatter, MaxNLocator 15 | 16 | """ 17 | During the state phase, the network must learn to write each tuple to seperate mem location 18 | During the goal phase, the goal phase, the goal tuples are recorded. 19 | --todo-- the triple stored at each location can be recovered by logistic reg. decoder 20 | --todo-- 21 | During state phase read modes - 22 | 23 | Extended Data Figure 2 24 | """ 25 | 26 | c_map = plt.get_cmap('jet') 27 | 28 | 29 | def ix_to_color(ix_vec): 30 | """ 31 | 32 | :param ix_vec: 33 | :return: Image 34 | """ 35 | ix_np = ix_vec.numpy() 36 | 37 | 38 | def recorded_step(data, DNC): 39 | """ 40 | state phase answer phase 41 | m w r 42 | e w r 43 | m w r 44 | w r 45 | l 46 | o 47 | c 48 | ------------time-------------> 49 | :param state: 50 | :return: 51 | """ 52 | criterion = nn.CrossEntropyLoss() 53 | 54 | phase_masks = data.make_new_problem() 55 | state = DNC.init_state(grad=False) 56 | lstm_state = DNC.init_rnn(grad=False) 57 | 58 | prev_action, time_steps = None, [] 59 | # time_steps.append(copy.copy(state) + [-1]) 60 | 61 | for phase_idx in phase_masks: 62 | if phase_idx == 0 or phase_idx == 1: 63 | inputs = Variable(data.getitem_combined()) 64 | logits, state, lstm_state = DNC(inputs, lstm_state, state) 65 | # 66 | _, prev_action = data.strip_ix_mask(logits) 67 | time_steps.append({'state': copy.copy(state), 'input': inputs.data, 68 | 'outputs': logits.data, 'phase': phase_idx}) 69 | 70 | elif phase_idx == 2: 71 | inputs = torch.cat([Variable(data.getmask()), prev_action], 1) 72 | logits, dnc_state, lstm_state = DNC(inputs, lstm_state, state) 73 | # 74 | _, prev_action = data.strip_ix_mask(logits) 75 | time_steps.append({'state': copy.copy(state), 'input': inputs.data, 76 | 'outputs': logits.data, 'phase': phase_idx}) 77 | else: 78 | _, all_actions = data.get_actions(mode='both') 79 | if data.goals_idx == {}: 80 | break 81 | mask, pr = data.getmask(), depackage(prev_action) 82 | final_inputs = Variable(torch.cat([mask, pr], 1)) 83 | # 84 | logits, state, lstm_state = DNC(final_inputs, lstm_state, state) 85 | exp_logits = data.ix_input_to_ixs(logits) 86 | # time_steps.append(copy.copy(state) + [phase_idx]) 87 | final_action, _ = L.naive_loss(exp_logits, all_actions, criterion) 88 | 89 | print(final_action) 90 | # send action to Data Generator, and set locally 91 | data.send_action(final_action) 92 | prev_action = data.vec_to_ix(final_action) 93 | time_steps.append({'state': copy.copy(state), 'input': final_inputs.data, 94 | 'outputs': logits.data, 'phase': phase_idx}) 95 | return time_steps 96 | 97 | 98 | def make_usage_viz(states): 99 | """ 100 | 101 | :param states: 102 | :return: 103 | """ 104 | # access, memory, read_wghts, write_wghts, \ 105 | # link, link_wghts, usage, phase_idx = state 106 | 107 | mem_size = list(states[0]['state'][1].size())[1] 108 | map, writes = [], [] 109 | phases = [c['phase'] for c in states] 110 | 111 | for idx, state in enumerate(states): 112 | 113 | # usage Vectors to calculate write positions 114 | current_usage = state['state'][6][0].data 115 | prev_usage = states[idx-1]['state'][6][0].data if idx > 0 else current_usage 116 | 117 | # max diff in usages => 118 | diffs = current_usage - prev_usage 119 | diff_idx, w = max(enumerate(diffs)) 120 | 121 | # decode max position 122 | map.append(np.abs(diffs.numpy())) 123 | write_vec = state['state'][1][0][diff_idx].data.numpy() 124 | print(write_vec) 125 | 126 | counts = dict((x, phases.count(x)) for x in set(phases)) 127 | pos = [0] + list(accumulate(counts.values()))[:-1] 128 | ax = plt.gca() 129 | 130 | # fig = matplotlib.figure(figsize=(3, 3)) 131 | rotated = np.rot90(np.asarray(map)) 132 | 133 | ax.xaxis.set_ticks(pos) 134 | ax.xaxis.set_ticklabels(list(range(4))) 135 | 136 | plot = plt.imshow(rotated, cmap='hot', interpolation='nearest', 137 | extent=[0, len(states), 0, mem_size]) 138 | 139 | plt.colorbar() 140 | plt.show() 141 | 142 | pass 143 | 144 | 145 | 146 | 147 | --------------------------------------------------------------------------------